Skip to content

Commit 56efc9b

Browse files
committed
Add go implementation
1 parent d338dcf commit 56efc9b

File tree

4 files changed

+170
-39
lines changed

4 files changed

+170
-39
lines changed

modal-go/client.go

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,13 @@ var defaultConfig config
6868
// defaultProfile is resolved at package init from MODAL_PROFILE, ~/.modal.toml, etc.
6969
var defaultProfile Profile
7070

71+
// client is the default Modal client that talks to the control plane.
7172
var client pb.ModalClientClient
7273

74+
// clients is a map of server URL => client.
75+
// The us-east client talks to the control plane; all other clients talk to input planes.
76+
var clients = map[string]pb.ModalClientClient{}
77+
7378
func init() {
7479
var err error
7580
defaultConfig, _ = readConfigFile()
@@ -78,25 +83,39 @@ func init() {
7883
panic(err) // fail fast – credentials are required to proceed
7984
}
8085

81-
_, client, err = newClient(defaultProfile)
86+
client, err = getOrCreateClient(defaultProfile.ServerURL)
8287
if err != nil {
8388
panic(err)
8489
}
8590
}
8691

87-
// newClient dials api.modal.com with auth/timeout/retry interceptors installed.
92+
// getOrCreateClient returns a client for the given server URL, creating it if it doesn't exist.
93+
func getOrCreateClient(serverURL string) (pb.ModalClientClient, error) {
94+
if client, ok := clients[serverURL]; ok {
95+
return client, nil
96+
}
97+
98+
_, client, err := createClient(serverURL)
99+
if err != nil {
100+
return nil, err
101+
}
102+
clients[serverURL] = client
103+
return client, nil
104+
}
105+
106+
// createClient dials the given server URL with auth/timeout/retry interceptors installed.
88107
// It returns (conn, stub). Close the conn when done.
89-
func newClient(profile Profile) (*grpc.ClientConn, pb.ModalClientClient, error) {
108+
func createClient(serverURL string) (*grpc.ClientConn, pb.ModalClientClient, error) {
90109
var target string
91110
var creds credentials.TransportCredentials
92-
if strings.HasPrefix(profile.ServerURL, "https://") {
93-
target = strings.TrimPrefix(profile.ServerURL, "https://")
111+
if strings.HasPrefix(serverURL, "https://") {
112+
target = strings.TrimPrefix(serverURL, "https://")
94113
creds = credentials.NewTLS(&tls.Config{})
95-
} else if strings.HasPrefix(profile.ServerURL, "http://") {
96-
target = strings.TrimPrefix(profile.ServerURL, "http://")
114+
} else if strings.HasPrefix(serverURL, "http://") {
115+
target = strings.TrimPrefix(serverURL, "http://")
97116
creds = insecure.NewCredentials()
98117
} else {
99-
return nil, nil, status.Errorf(codes.InvalidArgument, "invalid server URL: %s", profile.ServerURL)
118+
return nil, nil, status.Errorf(codes.InvalidArgument, "invalid server URL: %s", serverURL)
100119
}
101120

102121
conn, err := grpc.NewClient(

modal-go/function.go

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import (
1010
"encoding/base64"
1111
"errors"
1212
"fmt"
13+
"google.golang.org/grpc"
14+
"google.golang.org/grpc/metadata"
1315
"net/http"
1416
"time"
1517

@@ -34,9 +36,10 @@ func timeNowSeconds() float64 {
3436

3537
// Function references a deployed Modal Function.
3638
type Function struct {
37-
FunctionId string
38-
MethodName *string // used for class methods
39-
ctx context.Context
39+
FunctionId string
40+
MethodName *string // used for class methods
41+
InputPlaneURL *string // if nil, use control plane
42+
ctx context.Context
4043
}
4144

4245
// FunctionLookup looks up an existing Function.
@@ -46,12 +49,13 @@ func FunctionLookup(ctx context.Context, appName string, name string, options *L
4649
}
4750
ctx = clientContext(ctx)
4851

52+
var header, trailer metadata.MD
4953
resp, err := client.FunctionGet(ctx, pb.FunctionGetRequest_builder{
5054
AppName: appName,
5155
ObjectTag: name,
5256
Namespace: pb.DeploymentNamespace_DEPLOYMENT_NAMESPACE_WORKSPACE,
5357
EnvironmentName: environmentName(options.Environment),
54-
}.Build())
58+
}.Build(), grpc.Header(&header), grpc.Trailer(&trailer))
5559

5660
if status, ok := status.FromError(err); ok && status.Code() == codes.NotFound {
5761
return nil, NotFoundError{fmt.Sprintf("function '%s/%s' not found", appName, name)}
@@ -60,7 +64,22 @@ func FunctionLookup(ctx context.Context, appName string, name string, options *L
6064
return nil, err
6165
}
6266

63-
return &Function{FunctionId: resp.GetFunctionId(), ctx: ctx}, nil
67+
// Attach x-modal-auth-token to all future requests.
68+
authTokenArray := header.Get("x-modal-auth-token")
69+
if len(authTokenArray) == 0 {
70+
authTokenArray = trailer.Get("x-modal-auth-token")
71+
}
72+
if len(authTokenArray) > 0 {
73+
authToken := authTokenArray[0]
74+
ctx = metadata.AppendToOutgoingContext(ctx, "x-modal-auth-token", authToken)
75+
}
76+
var inputPlaneUrl *string
77+
if meta := resp.GetHandleMetadata(); meta != nil {
78+
if url := meta.GetInputPlaneUrl(); url != "" {
79+
inputPlaneUrl = &url
80+
}
81+
}
82+
return &Function{FunctionId: resp.GetFunctionId(), InputPlaneURL: inputPlaneUrl, ctx: ctx}, nil
6483
}
6584

6685
// Serialize Go data types to the Python pickle format.
@@ -118,7 +137,7 @@ func (f *Function) Remote(args []any, kwargs map[string]any) (any, error) {
118137
if err != nil {
119138
return nil, err
120139
}
121-
invocation, err := CreateControlPlaneInvocation(f.ctx, f.FunctionId, input, pb.FunctionCallInvocationType_FUNCTION_CALL_INVOCATION_TYPE_SYNC)
140+
invocation, err := f.createRemoteInvocation(input)
122141
if err != nil {
123142
return nil, err
124143
}
@@ -140,6 +159,14 @@ func (f *Function) Remote(args []any, kwargs map[string]any) (any, error) {
140159
}
141160
}
142161

162+
// createRemoteInvocation creates an Invocation using either the input plane or control plane.
163+
func (f *Function) createRemoteInvocation(input *pb.FunctionInput) (Invocation, error) {
164+
if f.InputPlaneURL != nil {
165+
return CreateInputPlaneInvocation(f.ctx, *f.InputPlaneURL, f.FunctionId, input)
166+
}
167+
return CreateControlPlaneInvocation(f.ctx, f.FunctionId, input, pb.FunctionCallInvocationType_FUNCTION_CALL_INVOCATION_TYPE_SYNC)
168+
}
169+
143170
// Spawn starts running a single input on a remote function.
144171
func (f *Function) Spawn(args []any, kwargs map[string]any) (*FunctionCall, error) {
145172
input, err := f.createInput(args, kwargs)

modal-go/invocation.go

Lines changed: 97 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,12 @@ import (
1111
"google.golang.org/protobuf/proto"
1212
)
1313

14+
// GetOutput is a function type that takes a timeout in milliseconds and returns a FunctionGetOutputsItem or nil, and an error.
15+
type GetOutput func(timeout time.Duration) (*pb.FunctionGetOutputsItem, error)
16+
1417
type Invocation interface {
15-
<<<<<<< HEAD
1618
AwaitOutput(timeout *time.Duration) (any, error)
17-
=======
18-
AwaitOutput(timeout ...int) (interface{}, error)
19-
>>>>>>> a38d9c2 (Add go implementation)
20-
Retry(retryCount int) error
19+
Retry(retryCount uint32) error
2120
}
2221

2322
// ControlPlaneInvocation implements the Invocation interface.
@@ -60,12 +59,8 @@ func ControlPlaneInvocationFromFunctionCallId(ctx context.Context, functionCallI
6059
return &ControlPlaneInvocation{FunctionCallId: functionCallId, ctx: ctx}
6160
}
6261

63-
<<<<<<< HEAD
6462
func (c *ControlPlaneInvocation) AwaitOutput(timeout *time.Duration) (any, error) {
65-
=======
66-
func (c *ControlPlaneInvocation) AwaitOutput(timeout *time.Duration) (interface{}, error) {
67-
>>>>>>> a38d9c2 (Add go implementation)
68-
return pollFunctionOutput(c.ctx, c.FunctionCallId, timeout)
63+
return pollFunctionOutput(c.ctx, c.getOutput, timeout)
6964
}
7065

7166
func (c *ControlPlaneInvocation) Retry(retryCount uint32) error {
@@ -88,8 +83,94 @@ func (c *ControlPlaneInvocation) Retry(retryCount uint32) error {
8883
return nil
8984
}
9085

91-
// Poll for outputs for a given FunctionCall ID.
92-
func pollFunctionOutput(ctx context.Context, functionCallId string, timeout *time.Duration) (any, error) {
86+
// getOutput fetches the output for the current function call with a timeout in milliseconds.
87+
func (c *ControlPlaneInvocation) getOutput(timeout time.Duration) (*pb.FunctionGetOutputsItem, error) {
88+
response, err := client.FunctionGetOutputs(c.ctx, pb.FunctionGetOutputsRequest_builder{
89+
FunctionCallId: c.FunctionCallId,
90+
MaxValues: 1,
91+
Timeout: float32(timeout.Seconds()),
92+
LastEntryId: "0-0",
93+
ClearOnSuccess: true,
94+
RequestedAt: timeNowSeconds(),
95+
}.Build())
96+
if err != nil {
97+
return nil, fmt.Errorf("FunctionGetOutputs failed: %w", err)
98+
}
99+
outputs := response.GetOutputs()
100+
if len(outputs) > 0 {
101+
return outputs[0], nil
102+
}
103+
return nil, nil
104+
}
105+
106+
// InputPlaneInvocation implements the Invocation interface for the input plane.
107+
type InputPlaneInvocation struct {
108+
client pb.ModalClientClient
109+
functionId string
110+
input *pb.FunctionPutInputsItem
111+
attemptToken string
112+
ctx context.Context
113+
}
114+
115+
// CreateInputPlaneInvocation creates a new InputPlaneInvocation by starting an attempt.
116+
func CreateInputPlaneInvocation(ctx context.Context, inputPlaneURL string, functionId string, input *pb.FunctionInput) (*InputPlaneInvocation, error) {
117+
functionPutInputsItem := pb.FunctionPutInputsItem_builder{
118+
Idx: 0,
119+
Input: input,
120+
}.Build()
121+
client, err := getOrCreateClient(inputPlaneURL)
122+
if err != nil {
123+
return nil, err
124+
}
125+
attemptStartResp, err := client.AttemptStart(ctx, pb.AttemptStartRequest_builder{
126+
FunctionId: functionId,
127+
Input: functionPutInputsItem,
128+
}.Build())
129+
if err != nil {
130+
return nil, err
131+
}
132+
return &InputPlaneInvocation{
133+
client: client,
134+
functionId: functionId,
135+
input: functionPutInputsItem,
136+
attemptToken: attemptStartResp.GetAttemptToken(),
137+
ctx: ctx,
138+
}, nil
139+
}
140+
141+
// AwaitOutput waits for the output with an optional timeout.
142+
func (i *InputPlaneInvocation) AwaitOutput(timeout *time.Duration) (any, error) {
143+
return pollFunctionOutput(i.ctx, i.getOutput, timeout)
144+
}
145+
146+
// getOutput fetches the output for the current attempt.
147+
func (i *InputPlaneInvocation) getOutput(timeout time.Duration) (*pb.FunctionGetOutputsItem, error) {
148+
resp, err := i.client.AttemptAwait(i.ctx, pb.AttemptAwaitRequest_builder{
149+
AttemptToken: i.attemptToken,
150+
RequestedAt: timeNowSeconds(),
151+
TimeoutSecs: float32(timeout.Seconds()),
152+
}.Build())
153+
if err != nil {
154+
return nil, fmt.Errorf("AttemptAwait failed: %w", err)
155+
}
156+
return resp.GetOutput(), nil
157+
}
158+
159+
// Retry retries the invocation.
160+
func (i *InputPlaneInvocation) Retry(retryCount uint32) error {
161+
resp, err := i.client.AttemptRetry(context.Background(), pb.AttemptRetryRequest_builder{
162+
FunctionId: i.functionId,
163+
Input: i.input,
164+
AttemptToken: i.attemptToken,
165+
}.Build())
166+
if err != nil {
167+
return err
168+
}
169+
i.attemptToken = resp.GetAttemptToken()
170+
return nil
171+
}
172+
173+
func pollFunctionOutput(ctx context.Context, getOutput GetOutput, timeout *time.Duration) (any, error) {
93174
startTime := time.Now()
94175
pollTimeout := outputsTimeout
95176
if timeout != nil {
@@ -98,23 +179,14 @@ func pollFunctionOutput(ctx context.Context, functionCallId string, timeout *tim
98179
}
99180

100181
for {
101-
response, err := client.FunctionGetOutputs(ctx, pb.FunctionGetOutputsRequest_builder{
102-
FunctionCallId: functionCallId,
103-
MaxValues: 1,
104-
Timeout: float32(pollTimeout.Seconds()),
105-
LastEntryId: "0-0",
106-
ClearOnSuccess: true,
107-
RequestedAt: timeNowSeconds(),
108-
}.Build())
182+
output, err := getOutput(pollTimeout)
109183
if err != nil {
110-
return nil, fmt.Errorf("FunctionGetOutputs failed: %w", err)
184+
return nil, err
111185
}
112-
113186
// Output serialization may fail if any of the output items can't be deserialized
114187
// into a supported Go type. Users are expected to serialize outputs correctly.
115-
outputs := response.GetOutputs()
116-
if len(outputs) > 0 {
117-
return processResult(ctx, outputs[0].GetResult(), outputs[0].GetDataFormat())
188+
if output != nil {
189+
return processResult(ctx, output.GetResult(), output.GetDataFormat())
118190
}
119191

120192
if timeout != nil {

modal-go/test/function_test.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,16 @@ func TestFunctionNotFound(t *testing.T) {
4747
_, err := modal.FunctionLookup(context.Background(), "libmodal-test-support", "not_a_real_function", nil)
4848
g.Expect(err).Should(gomega.BeAssignableToTypeOf(modal.NotFoundError{}))
4949
}
50+
51+
func TestFunctionCallInputPlane(t *testing.T) {
52+
t.Parallel()
53+
g := gomega.NewWithT(t)
54+
55+
function, err := modal.FunctionLookup(context.Background(), "libmodal-test-support", "input_plane", nil)
56+
g.Expect(err).ShouldNot(gomega.HaveOccurred())
57+
58+
// Try the same, but with args.
59+
result, err := function.Remote([]any{"hello"}, nil)
60+
g.Expect(err).ShouldNot(gomega.HaveOccurred())
61+
g.Expect(result).Should(gomega.Equal("output: hello"))
62+
}

0 commit comments

Comments
 (0)