Skip to content

Commit 90b362b

Browse files
committed
Add go implementation
1 parent 99f9f77 commit 90b362b

File tree

3 files changed

+232
-147
lines changed

3 files changed

+232
-147
lines changed

modal-go/function.go

Lines changed: 36 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,15 @@ import (
88
"crypto/md5"
99
"crypto/sha256"
1010
"encoding/base64"
11+
"errors"
1112
"fmt"
12-
"io"
1313
"net/http"
1414
"time"
1515

1616
pickle "github.com/kisielk/og-rek"
1717
pb "github.com/modal-labs/libmodal/modal-go/proto/modal_proto"
1818
"google.golang.org/grpc/codes"
1919
"google.golang.org/grpc/status"
20-
"google.golang.org/protobuf/proto"
2120
)
2221

2322
// From: modal/_utils/blob_utils.py
@@ -26,6 +25,9 @@ const maxObjectSizeBytes int = 2 * 1024 * 1024 // 2 MiB
2625
// From: modal-client/modal/_utils/function_utils.py
2726
const outputsTimeout time.Duration = time.Second * 55
2827

28+
// From: client/modal/_functions.py
29+
const maxSystemRetries = 8
30+
2931
func timeNowSeconds() float64 {
3032
return float64(time.Now().UnixNano()) / 1e9
3133
}
@@ -85,7 +87,7 @@ func pickleDeserialize(buffer []byte) (any, error) {
8587
}
8688

8789
// Serializes inputs, make a function call and return its ID
88-
func (f *Function) execFunctionCall(args []any, kwargs map[string]any, invocationType pb.FunctionCallInvocationType) (*string, error) {
90+
func (f *Function) createInput(args []any, kwargs map[string]any) (*pb.FunctionInput, error) {
8991
payload, err := pickleSerialize(pickle.Tuple{args, kwargs})
9092
if err != nil {
9193
return nil, err
@@ -102,132 +104,57 @@ func (f *Function) execFunctionCall(args []any, kwargs map[string]any, invocatio
102104
argsBlobId = &blobId
103105
}
104106

105-
// Single input sync invocation
106-
var functionInputs []*pb.FunctionPutInputsItem
107-
functionInputItem := pb.FunctionPutInputsItem_builder{
108-
Idx: 0,
109-
Input: pb.FunctionInput_builder{
110-
Args: argsBytes,
111-
ArgsBlobId: argsBlobId,
112-
DataFormat: pb.DataFormat_DATA_FORMAT_PICKLE,
113-
MethodName: f.MethodName,
114-
}.Build(),
115-
}.Build()
116-
functionInputs = append(functionInputs, functionInputItem)
117-
118-
functionMapResponse, err := client.FunctionMap(f.ctx, pb.FunctionMapRequest_builder{
119-
FunctionId: f.FunctionId,
120-
FunctionCallType: pb.FunctionCallType_FUNCTION_CALL_TYPE_UNARY,
121-
FunctionCallInvocationType: invocationType,
122-
PipelinedInputs: functionInputs,
123-
}.Build())
124-
if err != nil {
125-
return nil, fmt.Errorf("FunctionMap error: %w", err)
126-
}
127-
128-
functionCallId := functionMapResponse.GetFunctionCallId()
129-
return &functionCallId, nil
107+
return pb.FunctionInput_builder{
108+
Args: argsBytes,
109+
ArgsBlobId: argsBlobId,
110+
DataFormat: pb.DataFormat_DATA_FORMAT_PICKLE,
111+
MethodName: f.MethodName,
112+
}.Build(), nil
130113
}
131114

132115
// Remote executes a single input on a remote Function.
133116
func (f *Function) Remote(args []any, kwargs map[string]any) (any, error) {
134-
invocationType := pb.FunctionCallInvocationType_FUNCTION_CALL_INVOCATION_TYPE_SYNC
135-
functionCallId, err := f.execFunctionCall(args, kwargs, invocationType)
117+
input, err := f.createInput(args, kwargs)
136118
if err != nil {
137119
return nil, err
138120
}
139-
140-
return pollFunctionOutput(f.ctx, *functionCallId, nil)
141-
}
142-
143-
// Spawn starts running a single input on a remote function.
144-
func (f *Function) Spawn(args []any, kwargs map[string]any) (*FunctionCall, error) {
145-
invocationType := pb.FunctionCallInvocationType_FUNCTION_CALL_INVOCATION_TYPE_ASYNC
146-
functionCallId, err := f.execFunctionCall(args, kwargs, invocationType)
121+
invocation, err := CreateControlPlaneInvocation(f.ctx, f.FunctionId, input, pb.FunctionCallInvocationType_FUNCTION_CALL_INVOCATION_TYPE_SYNC)
147122
if err != nil {
148123
return nil, err
149124
}
150-
functionCall := FunctionCall{
151-
FunctionCallId: *functionCallId,
152-
ctx: f.ctx,
153-
}
154-
return &functionCall, nil
155-
}
156-
157-
// Poll for ouputs for a given FunctionCall ID.
158-
func pollFunctionOutput(ctx context.Context, functionCallId string, timeout *time.Duration) (any, error) {
159-
startTime := time.Now()
160-
pollTimeout := outputsTimeout
161-
if timeout != nil {
162-
// Refresh backend call once per outputsTimeout.
163-
pollTimeout = min(*timeout, outputsTimeout)
164-
}
165-
125+
// TODO(ryan): Add tests for retries.
126+
retryCount := uint32(0)
166127
for {
167-
response, err := client.FunctionGetOutputs(ctx, pb.FunctionGetOutputsRequest_builder{
168-
FunctionCallId: functionCallId,
169-
MaxValues: 1,
170-
Timeout: float32(pollTimeout.Seconds()),
171-
LastEntryId: "0-0",
172-
ClearOnSuccess: true,
173-
RequestedAt: timeNowSeconds(),
174-
}.Build())
175-
if err != nil {
176-
return nil, fmt.Errorf("FunctionGetOutputs failed: %w", err)
177-
}
178-
179-
// Output serialization may fail if any of the output items can't be deserialized
180-
// into a supported Go type. Users are expected to serialize outputs correctly.
181-
outputs := response.GetOutputs()
182-
if len(outputs) > 0 {
183-
return processResult(ctx, outputs[0].GetResult(), outputs[0].GetDataFormat())
128+
output, err := invocation.AwaitOutput(nil)
129+
if err == nil {
130+
return output, nil
184131
}
185-
186-
if timeout != nil {
187-
remainingTime := *timeout - time.Since(startTime)
188-
if remainingTime <= 0 {
189-
message := fmt.Sprintf("Timeout exceeded: %.1fs", timeout.Seconds())
190-
return nil, FunctionTimeoutError{message}
132+
if errors.As(err, &InternalFailure{}) && retryCount <= maxSystemRetries {
133+
if retryErr := invocation.Retry(retryCount); retryErr != nil {
134+
return nil, retryErr
191135
}
192-
pollTimeout = min(outputsTimeout, remainingTime)
136+
retryCount++
137+
continue
193138
}
139+
return nil, err
194140
}
195141
}
196142

197-
// processResult processes the result from an invocation.
198-
func processResult(ctx context.Context, result *pb.GenericResult, dataFormat pb.DataFormat) (any, error) {
199-
if result == nil {
200-
return nil, RemoteError{"Received null result from invocation"}
143+
// Spawn starts running a single input on a remote function.
144+
func (f *Function) Spawn(args []any, kwargs map[string]any) (*FunctionCall, error) {
145+
input, err := f.createInput(args, kwargs)
146+
if err != nil {
147+
return nil, err
201148
}
202-
203-
var data []byte
204-
var err error
205-
switch result.WhichDataOneof() {
206-
case pb.GenericResult_Data_case:
207-
data = result.GetData()
208-
case pb.GenericResult_DataBlobId_case:
209-
data, err = blobDownload(ctx, result.GetDataBlobId())
210-
if err != nil {
211-
return nil, err
212-
}
213-
case pb.GenericResult_DataOneof_not_set_case:
214-
data = nil
149+
invocation, err := CreateControlPlaneInvocation(f.ctx, f.FunctionId, input, pb.FunctionCallInvocationType_FUNCTION_CALL_INVOCATION_TYPE_SYNC)
150+
if err != nil {
151+
return nil, err
215152
}
216-
217-
switch result.GetStatus() {
218-
case pb.GenericResult_GENERIC_STATUS_TIMEOUT:
219-
return nil, FunctionTimeoutError{result.GetException()}
220-
case pb.GenericResult_GENERIC_STATUS_INTERNAL_FAILURE:
221-
return nil, InternalFailure{result.GetException()}
222-
case pb.GenericResult_GENERIC_STATUS_SUCCESS:
223-
// Proceed to the block below this switch statement.
224-
default:
225-
// In this case, `result.GetData()` may have a pickled user code exception with traceback
226-
// from Python. We ignore this and only take the string representation.
227-
return nil, RemoteError{result.GetException()}
153+
functionCall := FunctionCall{
154+
FunctionCallId: invocation.FunctionCallId,
155+
ctx: f.ctx,
228156
}
229-
230-
return deserializeDataFormat(data, dataFormat)
157+
return &functionCall, nil
231158
}
232159

233160
// blobUpload uploads a blob to storage and returns its ID.
@@ -272,40 +199,3 @@ func blobUpload(ctx context.Context, data []byte) (string, error) {
272199
return "", fmt.Errorf("missing upload URL in BlobCreate response")
273200
}
274201
}
275-
276-
// blobDownload downloads a blob by its ID.
277-
func blobDownload(ctx context.Context, blobId string) ([]byte, error) {
278-
resp, err := client.BlobGet(ctx, pb.BlobGetRequest_builder{
279-
BlobId: blobId,
280-
}.Build())
281-
if err != nil {
282-
return nil, err
283-
}
284-
s3resp, err := http.Get(resp.GetDownloadUrl())
285-
if err != nil {
286-
return nil, fmt.Errorf("failed to download blob: %w", err)
287-
}
288-
defer s3resp.Body.Close()
289-
buf, err := io.ReadAll(s3resp.Body)
290-
if err != nil {
291-
return nil, fmt.Errorf("failed to read blob data: %w", err)
292-
}
293-
return buf, nil
294-
}
295-
296-
func deserializeDataFormat(data []byte, dataFormat pb.DataFormat) (any, error) {
297-
switch dataFormat {
298-
case pb.DataFormat_DATA_FORMAT_PICKLE:
299-
return pickleDeserialize(data)
300-
case pb.DataFormat_DATA_FORMAT_ASGI:
301-
return nil, fmt.Errorf("ASGI data format is not supported in Go")
302-
case pb.DataFormat_DATA_FORMAT_GENERATOR_DONE:
303-
var done pb.GeneratorDone
304-
if err := proto.Unmarshal(data, &done); err != nil {
305-
return nil, fmt.Errorf("failed to unmarshal GeneratorDone: %w", err)
306-
}
307-
return &done, nil
308-
default:
309-
return nil, fmt.Errorf("unsupported data format: %s", dataFormat.String())
310-
}
311-
}

modal-go/function_call.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ func (fc *FunctionCall) Get(options *FunctionCallGetOptions) (any, error) {
4141
options = &FunctionCallGetOptions{}
4242
}
4343
ctx := fc.ctx
44-
return pollFunctionOutput(ctx, fc.FunctionCallId, options.Timeout)
44+
invocation := ControlPlaneInvocationFromFunctionCallId(ctx, fc.FunctionCallId)
45+
return invocation.AwaitOutput(options.Timeout)
4546
}
4647

4748
// FunctionCallCancelOptions are options for cancelling Function Calls.

0 commit comments

Comments
 (0)