Skip to content

Commit 94d225f

Browse files
authored
Add support for input plane invocation (#32)
* Add support for input plane * Address PR review comments * Remove catch
1 parent 7ae30cd commit 94d225f

File tree

9 files changed

+394
-53
lines changed

9 files changed

+394
-53
lines changed

modal-go/client.go

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,16 @@ var defaultProfile Profile
7373
// clientProfile is the actual profile, from defaultProfile + InitializeClient().
7474
var clientProfile Profile
7575

76+
// client is the default Modal client that talks to the control plane.
7677
var client pb.ModalClientClient
7778

79+
// clients is a map of server URL to input-plane client.
80+
var inputPlaneClients = map[string]pb.ModalClientClient{}
81+
82+
// authToken is the auth token received from the control plane on the first request, and sent with all
83+
// subsequent requests to both the control plane and the input plane.
84+
var authToken string
85+
7886
func init() {
7987
defaultConfig, _ = readConfigFile()
8088
defaultProfile = getProfile(os.Getenv("MODAL_PROFILE"))
@@ -108,7 +116,23 @@ func InitializeClient(options ClientOptions) error {
108116
return err
109117
}
110118

111-
// newClient dials api.modal.com with auth/timeout/retry interceptors installed.
119+
// getOrCreateInputPlaneClient returns a client for the given server URL, creating it if it doesn't exist.
120+
func getOrCreateInputPlaneClient(serverURL string) (pb.ModalClientClient, error) {
121+
if client, ok := inputPlaneClients[serverURL]; ok {
122+
return client, nil
123+
}
124+
125+
profile := clientProfile
126+
profile.ServerURL = serverURL
127+
_, client, err := newClient(profile)
128+
if err != nil {
129+
return nil, err
130+
}
131+
inputPlaneClients[serverURL] = client
132+
return client, nil
133+
}
134+
135+
// newClient dials the given server URL with auth/timeout/retry interceptors installed.
112136
// It returns (conn, stub). Close the conn when done.
113137
func newClient(profile Profile) (*grpc.ClientConn, pb.ModalClientClient, error) {
114138
var target string
@@ -131,6 +155,7 @@ func newClient(profile Profile) (*grpc.ClientConn, pb.ModalClientClient, error)
131155
grpc.MaxCallSendMsgSize(maxMessageSize),
132156
),
133157
grpc.WithChainUnaryInterceptor(
158+
authTokenInterceptor(),
134159
retryInterceptor(),
135160
timeoutInterceptor(),
136161
),
@@ -157,6 +182,37 @@ func clientContext(ctx context.Context) (context.Context, error) {
157182
), nil
158183
}
159184

185+
// authTokenInterceptor handles sending and receiving the "x-modal-auth-token" header.
186+
// We receive an auth token from the control plane on our first request. We then include that auth token in every
187+
// subsequent request to both the control plane and the input plane.
188+
func authTokenInterceptor() grpc.UnaryClientInterceptor {
189+
return func(
190+
ctx context.Context,
191+
method string,
192+
req, reply any,
193+
cc *grpc.ClientConn,
194+
inv grpc.UnaryInvoker,
195+
opts ...grpc.CallOption,
196+
) error {
197+
var headers, trailers metadata.MD
198+
// Add authToken to outgoing context if it's set
199+
if authToken != "" {
200+
ctx = metadata.AppendToOutgoingContext(ctx, "x-modal-auth-token", authToken)
201+
}
202+
opts = append(opts, grpc.Header(&headers), grpc.Trailer(&trailers))
203+
err := inv(ctx, method, req, reply, cc, opts...)
204+
// If we're talking to the control plane, and no auth token was sent, it will return one.
205+
// The python server returns it in the trailers, the worker returns it in the headers.
206+
if val, ok := headers["x-modal-auth-token"]; ok {
207+
authToken = val[0]
208+
} else if val, ok := trailers["x-modal-auth-token"]; ok {
209+
authToken = val[0]
210+
}
211+
212+
return err
213+
}
214+
}
215+
160216
func timeoutInterceptor() grpc.UnaryClientInterceptor {
161217
return func(
162218
ctx context.Context,

modal-go/function.go

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,10 @@ func timeNowSeconds() float64 {
3434

3535
// Function references a deployed Modal Function.
3636
type Function struct {
37-
FunctionId string
38-
MethodName *string // used for class methods
39-
ctx context.Context
37+
FunctionId string
38+
MethodName *string // used for class methods
39+
InputPlaneURL *string // if nil, use control plane
40+
ctx context.Context
4041
}
4142

4243
// FunctionLookup looks up an existing Function.
@@ -64,7 +65,13 @@ func FunctionLookup(ctx context.Context, appName string, name string, options *L
6465
return nil, err
6566
}
6667

67-
return &Function{FunctionId: resp.GetFunctionId(), ctx: ctx}, nil
68+
var inputPlaneUrl *string
69+
if meta := resp.GetHandleMetadata(); meta != nil {
70+
if url := meta.GetInputPlaneUrl(); url != "" {
71+
inputPlaneUrl = &url
72+
}
73+
}
74+
return &Function{FunctionId: resp.GetFunctionId(), InputPlaneURL: inputPlaneUrl, ctx: ctx}, nil
6875
}
6976

7077
// Serialize Go data types to the Python pickle format.
@@ -122,7 +129,7 @@ func (f *Function) Remote(args []any, kwargs map[string]any) (any, error) {
122129
if err != nil {
123130
return nil, err
124131
}
125-
invocation, err := createControlPlaneInvocation(f.ctx, f.FunctionId, input, pb.FunctionCallInvocationType_FUNCTION_CALL_INVOCATION_TYPE_SYNC)
132+
invocation, err := f.createRemoteInvocation(input)
126133
if err != nil {
127134
return nil, err
128135
}
@@ -144,6 +151,14 @@ func (f *Function) Remote(args []any, kwargs map[string]any) (any, error) {
144151
}
145152
}
146153

154+
// createRemoteInvocation creates an Invocation using either the input plane or control plane.
155+
func (f *Function) createRemoteInvocation(input *pb.FunctionInput) (invocation, error) {
156+
if f.InputPlaneURL != nil {
157+
return createInputPlaneInvocation(f.ctx, *f.InputPlaneURL, f.FunctionId, input)
158+
}
159+
return createControlPlaneInvocation(f.ctx, f.FunctionId, input, pb.FunctionCallInvocationType_FUNCTION_CALL_INVOCATION_TYPE_SYNC)
160+
}
161+
147162
// Spawn starts running a single input on a remote function.
148163
func (f *Function) Spawn(args []any, kwargs map[string]any) (*FunctionCall, error) {
149164
input, err := f.createInput(args, kwargs)

modal-go/invocation.go

Lines changed: 106 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@ import (
1111
"google.golang.org/protobuf/proto"
1212
)
1313

14+
type invocation interface {
15+
awaitOutput(timeout *time.Duration) (any, error)
16+
retry(retryCount uint32) error
17+
}
18+
1419
// controlPlaneInvocation implements the invocation interface.
1520
type controlPlaneInvocation struct {
1621
FunctionCallId string
@@ -52,7 +57,7 @@ func controlPlaneInvocationFromFunctionCallId(ctx context.Context, functionCallI
5257
}
5358

5459
func (c *controlPlaneInvocation) awaitOutput(timeout *time.Duration) (any, error) {
55-
return pollFunctionOutput(c.ctx, c.FunctionCallId, timeout)
60+
return pollFunctionOutput(c.ctx, c.getOutput, timeout)
5661
}
5762

5863
func (c *controlPlaneInvocation) retry(retryCount uint32) error {
@@ -75,8 +80,102 @@ func (c *controlPlaneInvocation) retry(retryCount uint32) error {
7580
return nil
7681
}
7782

78-
// Poll for outputs for a given FunctionCall ID.
79-
func pollFunctionOutput(ctx context.Context, functionCallId string, timeout *time.Duration) (any, error) {
83+
// getOutput fetches the output for the current function call with a timeout in milliseconds.
84+
func (c *controlPlaneInvocation) getOutput(timeout time.Duration) (*pb.FunctionGetOutputsItem, error) {
85+
response, err := client.FunctionGetOutputs(c.ctx, pb.FunctionGetOutputsRequest_builder{
86+
FunctionCallId: c.FunctionCallId,
87+
MaxValues: 1,
88+
Timeout: float32(timeout.Seconds()),
89+
LastEntryId: "0-0",
90+
ClearOnSuccess: true,
91+
RequestedAt: timeNowSeconds(),
92+
}.Build())
93+
if err != nil {
94+
return nil, fmt.Errorf("FunctionGetOutputs failed: %w", err)
95+
}
96+
outputs := response.GetOutputs()
97+
if len(outputs) > 0 {
98+
return outputs[0], nil
99+
}
100+
return nil, nil
101+
}
102+
103+
// InputPlaneInvocation implements the Invocation interface for the input plane.
104+
type inputPlaneInvocation struct {
105+
client pb.ModalClientClient
106+
functionId string
107+
input *pb.FunctionPutInputsItem
108+
attemptToken string
109+
ctx context.Context
110+
}
111+
112+
// CreateInputPlaneInvocation creates a new InputPlaneInvocation by starting an attempt.
113+
func createInputPlaneInvocation(ctx context.Context, inputPlaneURL string, functionId string, input *pb.FunctionInput) (*inputPlaneInvocation, error) {
114+
functionPutInputsItem := pb.FunctionPutInputsItem_builder{
115+
Idx: 0,
116+
Input: input,
117+
}.Build()
118+
client, err := getOrCreateInputPlaneClient(inputPlaneURL)
119+
if err != nil {
120+
return nil, err
121+
}
122+
attemptStartResp, err := client.AttemptStart(ctx, pb.AttemptStartRequest_builder{
123+
FunctionId: functionId,
124+
Input: functionPutInputsItem,
125+
}.Build())
126+
if err != nil {
127+
return nil, err
128+
}
129+
return &inputPlaneInvocation{
130+
client: client,
131+
functionId: functionId,
132+
input: functionPutInputsItem,
133+
attemptToken: attemptStartResp.GetAttemptToken(),
134+
ctx: ctx,
135+
}, nil
136+
}
137+
138+
// awaitOutput waits for the output with an optional timeout.
139+
func (i *inputPlaneInvocation) awaitOutput(timeout *time.Duration) (any, error) {
140+
return pollFunctionOutput(i.ctx, i.getOutput, timeout)
141+
}
142+
143+
// getOutput fetches the output for the current attempt.
144+
func (i *inputPlaneInvocation) getOutput(timeout time.Duration) (*pb.FunctionGetOutputsItem, error) {
145+
resp, err := i.client.AttemptAwait(i.ctx, pb.AttemptAwaitRequest_builder{
146+
AttemptToken: i.attemptToken,
147+
RequestedAt: timeNowSeconds(),
148+
TimeoutSecs: float32(timeout.Seconds()),
149+
}.Build())
150+
if err != nil {
151+
return nil, fmt.Errorf("AttemptAwait failed: %w", err)
152+
}
153+
return resp.GetOutput(), nil
154+
}
155+
156+
// retry retries the invocation.
157+
func (i *inputPlaneInvocation) retry(retryCount uint32) error {
158+
// We ignore retryCount - it is used only by controlPlaneInvocation.
159+
resp, err := i.client.AttemptRetry(context.Background(), pb.AttemptRetryRequest_builder{
160+
FunctionId: i.functionId,
161+
Input: i.input,
162+
AttemptToken: i.attemptToken,
163+
}.Build())
164+
if err != nil {
165+
return err
166+
}
167+
i.attemptToken = resp.GetAttemptToken()
168+
return nil
169+
}
170+
171+
// getOutput is a function type that takes a timeout and returns a FunctionGetOutputsItem or nil, and an error.
172+
// Used by `pollForOutputs` to fetch from either the control plane or the input plane, depending on the implementation.
173+
type getOutput func(timeout time.Duration) (*pb.FunctionGetOutputsItem, error)
174+
175+
// pollFunctionOutput repeatedly tries to fetch an output using the provided `getOutput` function, and the specified
176+
// timeout value. We use a timeout value of 55 seconds if the caller does not specify a timeout value, or if the
177+
// specified timeout value is greater than 55 seconds.
178+
func pollFunctionOutput(ctx context.Context, getOutput getOutput, timeout *time.Duration) (any, error) {
80179
startTime := time.Now()
81180
pollTimeout := outputsTimeout
82181
if timeout != nil {
@@ -85,23 +184,14 @@ func pollFunctionOutput(ctx context.Context, functionCallId string, timeout *tim
85184
}
86185

87186
for {
88-
response, err := client.FunctionGetOutputs(ctx, pb.FunctionGetOutputsRequest_builder{
89-
FunctionCallId: functionCallId,
90-
MaxValues: 1,
91-
Timeout: float32(pollTimeout.Seconds()),
92-
LastEntryId: "0-0",
93-
ClearOnSuccess: true,
94-
RequestedAt: timeNowSeconds(),
95-
}.Build())
187+
output, err := getOutput(pollTimeout)
96188
if err != nil {
97-
return nil, fmt.Errorf("FunctionGetOutputs failed: %w", err)
189+
return nil, err
98190
}
99-
100191
// Output serialization may fail if any of the output items can't be deserialized
101192
// into a supported Go type. Users are expected to serialize outputs correctly.
102-
outputs := response.GetOutputs()
103-
if len(outputs) > 0 {
104-
return processResult(ctx, outputs[0].GetResult(), outputs[0].GetDataFormat())
193+
if output != nil {
194+
return processResult(ctx, output.GetResult(), output.GetDataFormat())
105195
}
106196

107197
if timeout != nil {

modal-go/test/function_test.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,16 @@ func TestFunctionNotFound(t *testing.T) {
4747
_, err := modal.FunctionLookup(context.Background(), "libmodal-test-support", "not_a_real_function", nil)
4848
g.Expect(err).Should(gomega.BeAssignableToTypeOf(modal.NotFoundError{}))
4949
}
50+
51+
func TestFunctionCallInputPlane(t *testing.T) {
52+
t.Parallel()
53+
g := gomega.NewWithT(t)
54+
55+
function, err := modal.FunctionLookup(context.Background(), "libmodal-test-support", "input_plane", nil)
56+
g.Expect(err).ShouldNot(gomega.HaveOccurred())
57+
58+
// Try the same, but with args.
59+
result, err := function.Remote([]any{"hello"}, nil)
60+
g.Expect(err).ShouldNot(gomega.HaveOccurred())
61+
g.Expect(result).Should(gomega.Equal("output: hello"))
62+
}

modal-js/src/client.ts

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ import { getProfile, type Profile } from "./config";
1515

1616
const defaultProfile = getProfile(process.env["MODAL_PROFILE"]);
1717

18+
let modalAuthToken: string | undefined;
19+
1820
/** gRPC client middleware to add auth token to request. */
1921
function authMiddleware(profile: Profile): ClientMiddleware {
2022
return async function* authMiddleware<Request, Response>(
@@ -36,6 +38,29 @@ function authMiddleware(profile: Profile): ClientMiddleware {
3638
options.metadata.set("x-modal-client-version", "1.0.0"); // CLIENT VERSION: Behaves like this Python SDK version
3739
options.metadata.set("x-modal-token-id", tokenId);
3840
options.metadata.set("x-modal-token-secret", tokenSecret);
41+
if (modalAuthToken) {
42+
options.metadata.set("x-modal-auth-token", modalAuthToken);
43+
}
44+
45+
// We receive an auth token from the control plane on our first request. We then include that auth token in every
46+
// subsequent request to both the control plane and the input plane. The python server returns it in the trailers,
47+
// the worker returns it in the headers.
48+
const prevOnHeader = options.onHeader;
49+
options.onHeader = (header) => {
50+
const token = header.get("x-modal-auth-token");
51+
if (token) {
52+
modalAuthToken = token;
53+
}
54+
prevOnHeader?.(header);
55+
};
56+
const prevOnTrailer = options.onTrailer;
57+
options.onTrailer = (trailer) => {
58+
const token = trailer.get("x-modal-auth-token");
59+
if (token) {
60+
modalAuthToken = token;
61+
}
62+
prevOnTrailer?.(trailer);
63+
};
3964
return yield* call.next(call.request, options);
4065
};
4166
}
@@ -199,6 +224,23 @@ const retryMiddleware: ClientMiddleware<RetryOptions> =
199224
}
200225
};
201226

227+
/** Map of server URL to input-plane client. */
228+
const inputPlaneClients: Record<string, ReturnType<typeof createClient>> = {};
229+
230+
/** Returns a client for the given server URL, creating it if it doesn't exist. */
231+
export const getOrCreateInputPlaneClient = (
232+
serverUrl: string,
233+
): ReturnType<typeof createClient> => {
234+
const client = inputPlaneClients[serverUrl];
235+
if (client) {
236+
return client;
237+
}
238+
const profile = { ...clientProfile, serverUrl };
239+
const newClient = createClient(profile);
240+
inputPlaneClients[serverUrl] = newClient;
241+
return newClient;
242+
};
243+
202244
function createClient(profile: Profile) {
203245
// Channels don't do anything until you send a request on them.
204246
// Ref: https://github.com/modal-labs/modal-client/blob/main/modal/_utils/grpc_utils.py
@@ -207,7 +249,6 @@ function createClient(profile: Profile) {
207249
"grpc.max_send_message_length": 100 * 1024 * 1024,
208250
"grpc-node.flow_control_window": 64 * 1024 * 1024,
209251
});
210-
211252
return createClientFactory()
212253
.use(authMiddleware(profile))
213254
.use(retryMiddleware)

0 commit comments

Comments
 (0)