@@ -8,16 +8,15 @@ import (
88 "crypto/md5"
99 "crypto/sha256"
1010 "encoding/base64"
11+ "errors"
1112 "fmt"
12- "io"
1313 "net/http"
1414 "time"
1515
1616 pickle "github.com/kisielk/og-rek"
1717 pb "github.com/modal-labs/libmodal/modal-go/proto/modal_proto"
1818 "google.golang.org/grpc/codes"
1919 "google.golang.org/grpc/status"
20- "google.golang.org/protobuf/proto"
2120)
2221
2322// From: modal/_utils/blob_utils.py
@@ -26,6 +25,9 @@ const maxObjectSizeBytes int = 2 * 1024 * 1024 // 2 MiB
2625// From: modal-client/modal/_utils/function_utils.py
2726const outputsTimeout time.Duration = time .Second * 55
2827
28+ // From: client/modal/_functions.py
29+ const maxSystemRetries = 8 // compared against retryCount in Remote's InternalFailure retry loop
30+
2931func timeNowSeconds () float64 {
3032 return float64 (time .Now ().UnixNano ()) / 1e9 // current Unix time as fractional seconds (nanoseconds / 1e9)
3133}
@@ -85,7 +87,7 @@ func pickleDeserialize(buffer []byte) (any, error) {
8587}
8688
8789// createInput serializes args and kwargs and returns the FunctionInput built from them.
88- func (f * Function ) execFunctionCall (args []any , kwargs map [string ]any , invocationType pb. FunctionCallInvocationType ) (* string , error ) {
90+ func (f * Function ) createInput (args []any , kwargs map [string ]any ) (* pb. FunctionInput , error ) {
8991 payload , err := pickleSerialize (pickle.Tuple {args , kwargs })
9092 if err != nil {
9193 return nil , err
@@ -102,132 +104,57 @@ func (f *Function) execFunctionCall(args []any, kwargs map[string]any, invocatio
102104 argsBlobId = & blobId
103105 }
104106
105- // Single input sync invocation
106- var functionInputs []* pb.FunctionPutInputsItem
107- functionInputItem := pb.FunctionPutInputsItem_builder {
108- Idx : 0 ,
109- Input : pb.FunctionInput_builder {
110- Args : argsBytes ,
111- ArgsBlobId : argsBlobId ,
112- DataFormat : pb .DataFormat_DATA_FORMAT_PICKLE ,
113- MethodName : f .MethodName ,
114- }.Build (),
115- }.Build ()
116- functionInputs = append (functionInputs , functionInputItem )
117-
118- functionMapResponse , err := client .FunctionMap (f .ctx , pb.FunctionMapRequest_builder {
119- FunctionId : f .FunctionId ,
120- FunctionCallType : pb .FunctionCallType_FUNCTION_CALL_TYPE_UNARY ,
121- FunctionCallInvocationType : invocationType ,
122- PipelinedInputs : functionInputs ,
123- }.Build ())
124- if err != nil {
125- return nil , fmt .Errorf ("FunctionMap error: %w" , err )
126- }
127-
128- functionCallId := functionMapResponse .GetFunctionCallId ()
129- return & functionCallId , nil
107+ return pb.FunctionInput_builder {
108+ Args : argsBytes ,
109+ ArgsBlobId : argsBlobId ,
110+ DataFormat : pb .DataFormat_DATA_FORMAT_PICKLE ,
111+ MethodName : f .MethodName ,
112+ }.Build (), nil
130113}
131114
132115// Remote executes a single input on a remote Function and returns its output. It creates a SYNC control-plane invocation, awaits the output, and calls invocation.Retry when the error matches InternalFailure and retryCount <= maxSystemRetries; any other error, or a failed Retry, is returned to the caller.
133116func (f * Function ) Remote (args []any , kwargs map [string ]any ) (any , error ) {
134- invocationType := pb .FunctionCallInvocationType_FUNCTION_CALL_INVOCATION_TYPE_SYNC
135- functionCallId , err := f .execFunctionCall (args , kwargs , invocationType )
117+ input , err := f .createInput (args , kwargs )
136118 if err != nil {
137119 return nil , err
138120 }
139-
140- return pollFunctionOutput (f .ctx , * functionCallId , nil )
141- }
142-
143- // Spawn starts running a single input on a remote function.
144- func (f * Function ) Spawn (args []any , kwargs map [string ]any ) (* FunctionCall , error ) {
145- invocationType := pb .FunctionCallInvocationType_FUNCTION_CALL_INVOCATION_TYPE_ASYNC
146- functionCallId , err := f .execFunctionCall (args , kwargs , invocationType )
121+ invocation , err := CreateControlPlaneInvocation (f .ctx , f .FunctionId , input , pb .FunctionCallInvocationType_FUNCTION_CALL_INVOCATION_TYPE_SYNC )
147122 if err != nil {
148123 return nil , err
149124 }
150- functionCall := FunctionCall {
151- FunctionCallId : * functionCallId ,
152- ctx : f .ctx ,
153- }
154- return & functionCall , nil
155- }
156-
157- // Poll for ouputs for a given FunctionCall ID.
158- func pollFunctionOutput (ctx context.Context , functionCallId string , timeout * time.Duration ) (any , error ) {
159- startTime := time .Now ()
160- pollTimeout := outputsTimeout
161- if timeout != nil {
162- // Refresh backend call once per outputsTimeout.
163- pollTimeout = min (* timeout , outputsTimeout )
164- }
165-
125+ // TODO(ryan): Add tests for retries. NOTE(review): retryCount starts at 0 and the check below uses <=, so Retry can be called up to maxSystemRetries+1 times — confirm this is intended.
126+ retryCount := uint32 (0 )
166127 for {
167- response , err := client .FunctionGetOutputs (ctx , pb.FunctionGetOutputsRequest_builder {
168- FunctionCallId : functionCallId ,
169- MaxValues : 1 ,
170- Timeout : float32 (pollTimeout .Seconds ()),
171- LastEntryId : "0-0" ,
172- ClearOnSuccess : true ,
173- RequestedAt : timeNowSeconds (),
174- }.Build ())
175- if err != nil {
176- return nil , fmt .Errorf ("FunctionGetOutputs failed: %w" , err )
177- }
178-
179- // Output serialization may fail if any of the output items can't be deserialized
180- // into a supported Go type. Users are expected to serialize outputs correctly.
181- outputs := response .GetOutputs ()
182- if len (outputs ) > 0 {
183- return processResult (ctx , outputs [0 ].GetResult (), outputs [0 ].GetDataFormat ())
128+ output , err := invocation .AwaitOutput (nil )
129+ if err == nil {
130+ return output , nil
184131 }
185-
186- if timeout != nil {
187- remainingTime := * timeout - time .Since (startTime )
188- if remainingTime <= 0 {
189- message := fmt .Sprintf ("Timeout exceeded: %.1fs" , timeout .Seconds ())
190- return nil , FunctionTimeoutError {message }
132+ if errors .As (err , & InternalFailure {}) && retryCount <= maxSystemRetries {
133+ if retryErr := invocation .Retry (retryCount ); retryErr != nil {
134+ return nil , retryErr
191135 }
192- pollTimeout = min (outputsTimeout , remainingTime )
136+ retryCount ++
137+ continue
193138 }
139+ return nil , err
194140 }
195141}
196142
197- // processResult processes the result from an invocation.
198- func processResult (ctx context.Context , result * pb.GenericResult , dataFormat pb.DataFormat ) (any , error ) {
199- if result == nil {
200- return nil , RemoteError {"Received null result from invocation" }
143+ // Spawn starts running a single input on a remote function and returns a FunctionCall handle for it without awaiting the output; the call is created with the ASYNC invocation type, matching the implementation this change replaces.
144+ func (f * Function ) Spawn (args []any , kwargs map [string ]any ) (* FunctionCall , error ) {
145+ input , err := f .createInput (args , kwargs )
146+ if err != nil {
147+ return nil , err
201148 }
202-
203- var data []byte
204- var err error
205- switch result .WhichDataOneof () {
206- case pb .GenericResult_Data_case :
207- data = result .GetData ()
208- case pb .GenericResult_DataBlobId_case :
209- data , err = blobDownload (ctx , result .GetDataBlobId ())
210- if err != nil {
211- return nil , err
212- }
213- case pb .GenericResult_DataOneof_not_set_case :
214- data = nil
149+ invocation , err := CreateControlPlaneInvocation (f .ctx , f .FunctionId , input , pb .FunctionCallInvocationType_FUNCTION_CALL_INVOCATION_TYPE_ASYNC )
150+ if err != nil {
151+ return nil , err
215152 }
216-
217- switch result .GetStatus () {
218- case pb .GenericResult_GENERIC_STATUS_TIMEOUT :
219- return nil , FunctionTimeoutError {result .GetException ()}
220- case pb .GenericResult_GENERIC_STATUS_INTERNAL_FAILURE :
221- return nil , InternalFailure {result .GetException ()}
222- case pb .GenericResult_GENERIC_STATUS_SUCCESS :
223- // Proceed to the block below this switch statement.
224- default :
225- // In this case, `result.GetData()` may have a pickled user code exception with traceback
226- // from Python. We ignore this and only take the string representation.
227- return nil , RemoteError {result .GetException ()}
153+ functionCall := FunctionCall {
154+ FunctionCallId : invocation .FunctionCallId ,
155+ ctx : f .ctx ,
228156 }
229-
230- return deserializeDataFormat (data , dataFormat )
157+ return & functionCall , nil
231158}
232159
233160// blobUpload uploads a blob to storage and returns its ID.
@@ -272,40 +199,3 @@ func blobUpload(ctx context.Context, data []byte) (string, error) {
272199 return "" , fmt .Errorf ("missing upload URL in BlobCreate response" )
273200 }
274201}
275-
276- // blobDownload downloads a blob by its ID.
277- func blobDownload (ctx context.Context , blobId string ) ([]byte , error ) {
278- resp , err := client .BlobGet (ctx , pb.BlobGetRequest_builder {
279- BlobId : blobId ,
280- }.Build ())
281- if err != nil {
282- return nil , err
283- }
284- s3resp , err := http .Get (resp .GetDownloadUrl ())
285- if err != nil {
286- return nil , fmt .Errorf ("failed to download blob: %w" , err )
287- }
288- defer s3resp .Body .Close ()
289- buf , err := io .ReadAll (s3resp .Body )
290- if err != nil {
291- return nil , fmt .Errorf ("failed to read blob data: %w" , err )
292- }
293- return buf , nil
294- }
295-
296- func deserializeDataFormat (data []byte , dataFormat pb.DataFormat ) (any , error ) {
297- switch dataFormat {
298- case pb .DataFormat_DATA_FORMAT_PICKLE :
299- return pickleDeserialize (data )
300- case pb .DataFormat_DATA_FORMAT_ASGI :
301- return nil , fmt .Errorf ("ASGI data format is not supported in Go" )
302- case pb .DataFormat_DATA_FORMAT_GENERATOR_DONE :
303- var done pb.GeneratorDone
304- if err := proto .Unmarshal (data , & done ); err != nil {
305- return nil , fmt .Errorf ("failed to unmarshal GeneratorDone: %w" , err )
306- }
307- return & done , nil
308- default :
309- return nil , fmt .Errorf ("unsupported data format: %s" , dataFormat .String ())
310- }
311- }
0 commit comments