Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 36 additions & 146 deletions modal-go/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,15 @@ import (
"crypto/md5"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"io"
"net/http"
"time"

pickle "github.com/kisielk/og-rek"
pb "github.com/modal-labs/libmodal/modal-go/proto/modal_proto"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
)

// From: modal/_utils/blob_utils.py
Expand All @@ -26,6 +25,9 @@ const maxObjectSizeBytes int = 2 * 1024 * 1024 // 2 MiB
// From: modal-client/modal/_utils/function_utils.py
const outputsTimeout time.Duration = time.Second * 55

// From: client/modal/_functions.py
const maxSystemRetries = 8

func timeNowSeconds() float64 {
return float64(time.Now().UnixNano()) / 1e9
}
Expand Down Expand Up @@ -85,7 +87,7 @@ func pickleDeserialize(buffer []byte) (any, error) {
}

// Serializes inputs, make a function call and return its ID
func (f *Function) execFunctionCall(args []any, kwargs map[string]any, invocationType pb.FunctionCallInvocationType) (*string, error) {
func (f *Function) createInput(args []any, kwargs map[string]any) (*pb.FunctionInput, error) {
payload, err := pickleSerialize(pickle.Tuple{args, kwargs})
if err != nil {
return nil, err
Expand All @@ -102,132 +104,57 @@ func (f *Function) execFunctionCall(args []any, kwargs map[string]any, invocatio
argsBlobId = &blobId
}

// Single input sync invocation
var functionInputs []*pb.FunctionPutInputsItem
functionInputItem := pb.FunctionPutInputsItem_builder{
Idx: 0,
Input: pb.FunctionInput_builder{
Args: argsBytes,
ArgsBlobId: argsBlobId,
DataFormat: pb.DataFormat_DATA_FORMAT_PICKLE,
MethodName: f.MethodName,
}.Build(),
}.Build()
functionInputs = append(functionInputs, functionInputItem)

functionMapResponse, err := client.FunctionMap(f.ctx, pb.FunctionMapRequest_builder{
FunctionId: f.FunctionId,
FunctionCallType: pb.FunctionCallType_FUNCTION_CALL_TYPE_UNARY,
FunctionCallInvocationType: invocationType,
PipelinedInputs: functionInputs,
}.Build())
if err != nil {
return nil, fmt.Errorf("FunctionMap error: %w", err)
}

functionCallId := functionMapResponse.GetFunctionCallId()
return &functionCallId, nil
return pb.FunctionInput_builder{
Args: argsBytes,
ArgsBlobId: argsBlobId,
DataFormat: pb.DataFormat_DATA_FORMAT_PICKLE,
MethodName: f.MethodName,
}.Build(), nil
}

// Remote executes a single input on a remote Function.
func (f *Function) Remote(args []any, kwargs map[string]any) (any, error) {
invocationType := pb.FunctionCallInvocationType_FUNCTION_CALL_INVOCATION_TYPE_SYNC
functionCallId, err := f.execFunctionCall(args, kwargs, invocationType)
input, err := f.createInput(args, kwargs)
if err != nil {
return nil, err
}

return pollFunctionOutput(f.ctx, *functionCallId, nil)
}

// Spawn starts running a single input on a remote function.
func (f *Function) Spawn(args []any, kwargs map[string]any) (*FunctionCall, error) {
invocationType := pb.FunctionCallInvocationType_FUNCTION_CALL_INVOCATION_TYPE_ASYNC
functionCallId, err := f.execFunctionCall(args, kwargs, invocationType)
invocation, err := createControlPlaneInvocation(f.ctx, f.FunctionId, input, pb.FunctionCallInvocationType_FUNCTION_CALL_INVOCATION_TYPE_SYNC)
if err != nil {
return nil, err
}
functionCall := FunctionCall{
FunctionCallId: *functionCallId,
ctx: f.ctx,
}
return &functionCall, nil
}

// Poll for ouputs for a given FunctionCall ID.
func pollFunctionOutput(ctx context.Context, functionCallId string, timeout *time.Duration) (any, error) {
startTime := time.Now()
pollTimeout := outputsTimeout
if timeout != nil {
// Refresh backend call once per outputsTimeout.
pollTimeout = min(*timeout, outputsTimeout)
}

// TODO(ryan): Add tests for retries.
retryCount := uint32(0)
for {
response, err := client.FunctionGetOutputs(ctx, pb.FunctionGetOutputsRequest_builder{
FunctionCallId: functionCallId,
MaxValues: 1,
Timeout: float32(pollTimeout.Seconds()),
LastEntryId: "0-0",
ClearOnSuccess: true,
RequestedAt: timeNowSeconds(),
}.Build())
if err != nil {
return nil, fmt.Errorf("FunctionGetOutputs failed: %w", err)
}

// Output serialization may fail if any of the output items can't be deserialized
// into a supported Go type. Users are expected to serialize outputs correctly.
outputs := response.GetOutputs()
if len(outputs) > 0 {
return processResult(ctx, outputs[0].GetResult(), outputs[0].GetDataFormat())
output, err := invocation.awaitOutput(nil)
if err == nil {
return output, nil
}

if timeout != nil {
remainingTime := *timeout - time.Since(startTime)
if remainingTime <= 0 {
message := fmt.Sprintf("Timeout exceeded: %.1fs", timeout.Seconds())
return nil, FunctionTimeoutError{message}
if errors.As(err, &InternalFailure{}) && retryCount <= maxSystemRetries {
if retryErr := invocation.retry(retryCount); retryErr != nil {
return nil, retryErr
}
pollTimeout = min(outputsTimeout, remainingTime)
retryCount++
continue
}
return nil, err
}
}

// processResult processes the result from an invocation.
func processResult(ctx context.Context, result *pb.GenericResult, dataFormat pb.DataFormat) (any, error) {
if result == nil {
return nil, RemoteError{"Received null result from invocation"}
// Spawn starts running a single input on a remote function.
func (f *Function) Spawn(args []any, kwargs map[string]any) (*FunctionCall, error) {
input, err := f.createInput(args, kwargs)
if err != nil {
return nil, err
}

var data []byte
var err error
switch result.WhichDataOneof() {
case pb.GenericResult_Data_case:
data = result.GetData()
case pb.GenericResult_DataBlobId_case:
data, err = blobDownload(ctx, result.GetDataBlobId())
if err != nil {
return nil, err
}
case pb.GenericResult_DataOneof_not_set_case:
data = nil
invocation, err := createControlPlaneInvocation(f.ctx, f.FunctionId, input, pb.FunctionCallInvocationType_FUNCTION_CALL_INVOCATION_TYPE_SYNC)
if err != nil {
return nil, err
}

switch result.GetStatus() {
case pb.GenericResult_GENERIC_STATUS_TIMEOUT:
return nil, FunctionTimeoutError{result.GetException()}
case pb.GenericResult_GENERIC_STATUS_INTERNAL_FAILURE:
return nil, InternalFailure{result.GetException()}
case pb.GenericResult_GENERIC_STATUS_SUCCESS:
// Proceed to the block below this switch statement.
default:
// In this case, `result.GetData()` may have a pickled user code exception with traceback
// from Python. We ignore this and only take the string representation.
return nil, RemoteError{result.GetException()}
functionCall := FunctionCall{
FunctionCallId: invocation.FunctionCallId,
ctx: f.ctx,
}

return deserializeDataFormat(data, dataFormat)
return &functionCall, nil
}

// blobUpload uploads a blob to storage and returns its ID.
Expand Down Expand Up @@ -272,40 +199,3 @@ func blobUpload(ctx context.Context, data []byte) (string, error) {
return "", fmt.Errorf("missing upload URL in BlobCreate response")
}
}

// blobDownload downloads a blob by its ID.
func blobDownload(ctx context.Context, blobId string) ([]byte, error) {
resp, err := client.BlobGet(ctx, pb.BlobGetRequest_builder{
BlobId: blobId,
}.Build())
if err != nil {
return nil, err
}
s3resp, err := http.Get(resp.GetDownloadUrl())
if err != nil {
return nil, fmt.Errorf("failed to download blob: %w", err)
}
defer s3resp.Body.Close()
buf, err := io.ReadAll(s3resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read blob data: %w", err)
}
return buf, nil
}

func deserializeDataFormat(data []byte, dataFormat pb.DataFormat) (any, error) {
switch dataFormat {
case pb.DataFormat_DATA_FORMAT_PICKLE:
return pickleDeserialize(data)
case pb.DataFormat_DATA_FORMAT_ASGI:
return nil, fmt.Errorf("ASGI data format is not supported in Go")
case pb.DataFormat_DATA_FORMAT_GENERATOR_DONE:
var done pb.GeneratorDone
if err := proto.Unmarshal(data, &done); err != nil {
return nil, fmt.Errorf("failed to unmarshal GeneratorDone: %w", err)
}
return &done, nil
default:
return nil, fmt.Errorf("unsupported data format: %s", dataFormat.String())
}
}
3 changes: 2 additions & 1 deletion modal-go/function_call.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ func (fc *FunctionCall) Get(options *FunctionCallGetOptions) (any, error) {
options = &FunctionCallGetOptions{}
}
ctx := fc.ctx
return pollFunctionOutput(ctx, fc.FunctionCallId, options.Timeout)
invocation := controlPlaneInvocationFromFunctionCallId(ctx, fc.FunctionCallId)
return invocation.awaitOutput(options.Timeout)
}

// FunctionCallCancelOptions are options for cancelling Function Calls.
Expand Down
Loading