Skip to content

Commit 3635731

Browse files
committed
Address PR review comments
1 parent a7b097b commit 3635731

File tree

6 files changed

+93
-87
lines changed

6 files changed

+93
-87
lines changed

modal-go/client.go

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,12 @@ var clientProfile Profile
7676
// client is the default Modal client that talks to the control plane.
7777
var client pb.ModalClientClient
7878

79-
// clients is a map of server URL => client.
80-
// The us-east client talks to the control plane; all other clients talk to input planes.
81-
var clients = map[string]pb.ModalClientClient{}
79+
// clients is a map of server URL to input-plane client.
80+
var inputPlaneClients = map[string]pb.ModalClientClient{}
81+
82+
// authToken is the auth token received from the control plane on the first request, and sent with all
83+
// subsequent requests to both the control plane and the input plane.
84+
var authToken string
8285

8386
func init() {
8487
defaultConfig, _ = readConfigFile()
@@ -113,17 +116,19 @@ func InitializeClient(options ClientOptions) error {
113116
return err
114117
}
115118

116-
// getOrCreateClient returns a client for the given server URL, creating it if it doesn't exist.
117-
func getOrCreateClient(serverURL string) (pb.ModalClientClient, error) {
118-
if client, ok := clients[serverURL]; ok {
119+
// getOrCreateInputPlaneClient returns a client for the given server URL, creating it if it doesn't exist.
120+
func getOrCreateInputPlaneClient(serverURL string) (pb.ModalClientClient, error) {
121+
if client, ok := inputPlaneClients[serverURL]; ok {
119122
return client, nil
120123
}
121124

122-
_, client, err := newClient(nil)
125+
profile := clientProfile
126+
profile.ServerURL = serverURL
127+
_, client, err := newClient(profile)
123128
if err != nil {
124129
return nil, err
125130
}
126-
clients[serverURL] = client
131+
inputPlaneClients[serverURL] = client
127132
return client, nil
128133
}
129134

@@ -150,6 +155,7 @@ func newClient(profile Profile) (*grpc.ClientConn, pb.ModalClientClient, error)
150155
grpc.MaxCallSendMsgSize(maxMessageSize),
151156
),
152157
grpc.WithChainUnaryInterceptor(
158+
authTokenInterceptor(),
153159
retryInterceptor(),
154160
timeoutInterceptor(),
155161
),
@@ -176,6 +182,37 @@ func clientContext(ctx context.Context) (context.Context, error) {
176182
), nil
177183
}
178184

185+
// authTokenInterceptor handles sending and receiving the "x-modal-auth-token" header.
186+
// We receive an auth token from the control plane on our first request. We then include that auth token in every
187+
// subsequent request to both the control plane and the input plane.
188+
func authTokenInterceptor() grpc.UnaryClientInterceptor {
189+
return func(
190+
ctx context.Context,
191+
method string,
192+
req, reply any,
193+
cc *grpc.ClientConn,
194+
inv grpc.UnaryInvoker,
195+
opts ...grpc.CallOption,
196+
) error {
197+
var headers, trailers metadata.MD
198+
// Add authToken to outgoing context if it's set
199+
if authToken != "" {
200+
ctx = metadata.AppendToOutgoingContext(ctx, "x-modal-auth-token", authToken)
201+
}
202+
opts = append(opts, grpc.Header(&headers), grpc.Trailer(&trailers))
203+
err := inv(ctx, method, req, reply, cc, opts...)
204+
// If we're talking to the control plane, and no auth token was sent, it will return one.
205+
// The python server returns it in the trailers, the worker returns it in the headers.
206+
if val, ok := headers["x-modal-auth-token"]; ok {
207+
authToken = val[0]
208+
} else if val, ok := trailers["x-modal-auth-token"]; ok {
209+
authToken = val[0]
210+
}
211+
212+
return err
213+
}
214+
}
215+
179216
func timeoutInterceptor() grpc.UnaryClientInterceptor {
180217
return func(
181218
ctx context.Context,

modal-go/function.go

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@ import (
1010
"encoding/base64"
1111
"errors"
1212
"fmt"
13-
"google.golang.org/grpc"
14-
"google.golang.org/grpc/metadata"
1513
"net/http"
1614
"time"
1715

@@ -53,13 +51,12 @@ func FunctionLookup(ctx context.Context, appName string, name string, options *L
5351
return nil, err
5452
}
5553

56-
var header, trailer metadata.MD
5754
resp, err := client.FunctionGet(ctx, pb.FunctionGetRequest_builder{
5855
AppName: appName,
5956
ObjectTag: name,
6057
Namespace: pb.DeploymentNamespace_DEPLOYMENT_NAMESPACE_WORKSPACE,
6158
EnvironmentName: environmentName(options.Environment),
62-
}.Build(), grpc.Header(&header), grpc.Trailer(&trailer))
59+
}.Build())
6360

6461
if status, ok := status.FromError(err); ok && status.Code() == codes.NotFound {
6562
return nil, NotFoundError{fmt.Sprintf("function '%s/%s' not found", appName, name)}
@@ -68,15 +65,6 @@ func FunctionLookup(ctx context.Context, appName string, name string, options *L
6865
return nil, err
6966
}
7067

71-
// Attach x-modal-auth-token to all future requests.
72-
authTokenArray := header.Get("x-modal-auth-token")
73-
if len(authTokenArray) == 0 {
74-
authTokenArray = trailer.Get("x-modal-auth-token")
75-
}
76-
if len(authTokenArray) > 0 {
77-
authToken := authTokenArray[0]
78-
ctx = metadata.AppendToOutgoingContext(ctx, "x-modal-auth-token", authToken)
79-
}
8068
var inputPlaneUrl *string
8169
if meta := resp.GetHandleMetadata(); meta != nil {
8270
if url := meta.GetInputPlaneUrl(); url != "" {

modal-go/invocation.go

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,6 @@ type invocation interface {
1616
retry(retryCount uint32) error
1717
}
1818

19-
// getOutput is a function type that takes a timeout and returns a FunctionGetOutputsItem or nil, and an error.
20-
type getOutput func(timeout time.Duration) (*pb.FunctionGetOutputsItem, error)
21-
2219
// controlPlaneInvocation implements the invocation interface.
2320
type controlPlaneInvocation struct {
2421
FunctionCallId string
@@ -118,7 +115,7 @@ func createInputPlaneInvocation(ctx context.Context, inputPlaneURL string, funct
118115
Idx: 0,
119116
Input: input,
120117
}.Build()
121-
client, err := getOrCreateClient(inputPlaneURL)
118+
client, err := getOrCreateInputPlaneClient(inputPlaneURL)
122119
if err != nil {
123120
return nil, err
124121
}
@@ -171,7 +168,13 @@ func (i *inputPlaneInvocation) retry(retryCount uint32) error {
171168
return nil
172169
}
173170

174-
// pollFunctionOutput repeatedly fetches the output for a given function call, waiting until a result is available or a timeout occurs.
171+
// getOutput is a function type that takes a timeout and returns a FunctionGetOutputsItem or nil, and an error.
172+
// Used by `pollForOutputs` to fetch from either the control plane or the input plane, depending on the implementation.
173+
type getOutput func(timeout time.Duration) (*pb.FunctionGetOutputsItem, error)
174+
175+
// pollFunctionOutput repeatedly tries to fetch an output using the provided `getOutput` function, and the specified
176+
// timeout value. We use a timeout value of 55 seconds if the caller does not specify a timeout value, or if the
177+
// specified timeout value is greater than 55 seconds.
175178
func pollFunctionOutput(ctx context.Context, getOutput getOutput, timeout *time.Duration) (any, error) {
176179
startTime := time.Now()
177180
pollTimeout := outputsTimeout

modal-js/src/client.ts

Lines changed: 11 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import { getProfile, type Profile } from "./config";
1616
const defaultProfile = getProfile(process.env["MODAL_PROFILE"]);
1717

1818
// CLIENT VERSION: Behaves like this Python SDK version
19-
const clientVersion = "1.0.0"
19+
const clientVersion = "1.0.0";
2020

2121
let modalAuthToken: string | undefined;
2222

@@ -45,46 +45,29 @@ function authMiddleware(profile: Profile): ClientMiddleware {
4545
options.metadata.set("x-modal-auth-token", modalAuthToken);
4646
}
4747

48+
// We receive an auth token from the control plane on our first request. We then include that auth token in every
49+
// subsequent request to both the control plane and the input plane. The python server returns it in the trailers,
50+
// the worker returns it in the headers.
4851
const prevOnHeader = options.onHeader;
4952
options.onHeader = (header) => {
5053
const token = header.get("x-modal-auth-token");
5154
if (token) {
5255
modalAuthToken = token;
5356
}
54-
if (prevOnHeader) {
55-
prevOnHeader(header);
56-
}
57+
prevOnHeader?.(header);
5758
};
5859
const prevOnTrailer = options.onTrailer;
5960
options.onTrailer = (trailer) => {
6061
const token = trailer.get("x-modal-auth-token");
6162
if (token) {
6263
modalAuthToken = token;
6364
}
64-
if (prevOnTrailer) {
65-
prevOnTrailer(trailer);
66-
}
65+
prevOnTrailer?.(trailer);
6766
};
6867
return yield* call.next(call.request, options);
6968
};
7069
}
7170

72-
/** gRPC client middleware to add auth token to request. */
73-
function inputPlaneAuthMiddleware(): ClientMiddleware {
74-
return async function* inputPlaneAuthMiddleware<Request, Response>(
75-
call: ClientMiddlewareCall<Request, Response>,
76-
options: CallOptions,
77-
) {
78-
options.metadata ??= new Metadata();
79-
options.metadata.set("x-modal-client-type", String(ClientType.CLIENT_TYPE_LIBMODAL));
80-
options.metadata.set("x-modal-client-version", clientVersion);
81-
if (modalAuthToken) {
82-
options.metadata.set("x-modal-auth-token", modalAuthToken);
83-
}
84-
return yield* call.next(call.request, options);
85-
};
86-
}
87-
8871
type TimeoutOptions = {
8972
/** Timeout for this call, interpreted as a duration in milliseconds */
9073
timeout?: number;
@@ -248,35 +231,27 @@ const retryMiddleware: ClientMiddleware<RetryOptions> =
248231
const inputPlaneClients: Record<string, ReturnType<typeof createClient>> = {};
249232

250233
/** Returns a client for the given server URL, creating it if it doesn't exist. */
251-
export const getOrCreateClient = (
234+
export const getOrCreateInputPlaneClient = (
252235
serverUrl: string,
253236
): ReturnType<typeof createClient> => {
254237
const client = inputPlaneClients[serverUrl];
255238
if (client) {
256239
return client;
257240
}
258-
const channel = createClientChannel(serverUrl);
259-
const newClient = createClientFactory()
260-
.use(inputPlaneAuthMiddleware())
261-
.use(retryMiddleware)
262-
.use(timeoutMiddleware)
263-
.create(ModalClientDefinition, channel);
241+
const profile = { ...clientProfile, serverUrl };
242+
const newClient = createClient(profile);
264243
inputPlaneClients[serverUrl] = newClient;
265244
return newClient;
266245
};
267246

268-
function createClientChannel(serverUrl: string) {
247+
function createClient(profile: Profile) {
269248
// Channels don't do anything until you send a request on them.
270249
// Ref: https://github.com/modal-labs/modal-client/blob/main/modal/_utils/grpc_utils.py
271-
return createChannel(serverUrl, undefined, {
250+
const channel = createChannel(profile.serverUrl, undefined, {
272251
"grpc.max_receive_message_length": 100 * 1024 * 1024,
273252
"grpc.max_send_message_length": 100 * 1024 * 1024,
274253
"grpc-node.flow_control_window": 64 * 1024 * 1024,
275254
});
276-
}
277-
278-
function createClient(profile: Profile) {
279-
const channel = createClientChannel(profile.serverUrl);
280255
return createClientFactory()
281256
.use(authMiddleware(profile))
282257
.use(retryMiddleware)

modal-js/src/invocation.ts

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import {
1111
GenericResult_GenericStatus,
1212
ModalClientClient,
1313
} from "../proto/modal_proto/api";
14-
import { client, getOrCreateClient } from "./client";
14+
import { client, getOrCreateInputPlaneClient } from "./client";
1515
import { FunctionTimeoutError, InternalFailure, RemoteError } from "./errors";
1616
import { loads } from "./pickle";
1717

@@ -29,14 +29,6 @@ export interface Invocation {
2929
retry(retryCount: number): Promise<void>;
3030
}
3131

32-
/**
33-
* Signature of a function that fetches a single output. Used by `pollForOutputs` to fetch from either
34-
* the control plane or the input plane, depending on the implementation.
35-
*/
36-
type GetOutput = (
37-
timeoutMillis: number,
38-
) => Promise<FunctionGetOutputsItem | undefined>;
39-
4032
/**
4133
* Implementation of Invocation which sends inputs to the control plane.
4234
*/
@@ -97,19 +89,15 @@ export class ControlPlaneInvocation implements Invocation {
9789
async #getOutput(
9890
timeoutMillis: number,
9991
): Promise<FunctionGetOutputsItem | undefined> {
100-
try {
101-
const response = await client.functionGetOutputs({
102-
functionCallId: this.functionCallId,
103-
maxValues: 1,
104-
timeout: timeoutMillis / 1000, // Backend needs seconds
105-
lastEntryId: "0-0",
106-
clearOnSuccess: true,
107-
requestedAt: timeNowSeconds(),
108-
});
109-
return response.outputs ? response.outputs[0] : undefined;
110-
} catch (err) {
111-
throw new Error(`FunctionGetOutputs failed: ${err}`);
112-
}
92+
const response = await client.functionGetOutputs({
93+
functionCallId: this.functionCallId,
94+
maxValues: 1,
95+
timeout: timeoutMillis / 1000, // Backend needs seconds
96+
lastEntryId: "0-0",
97+
clearOnSuccess: true,
98+
requestedAt: timeNowSeconds(),
99+
});
100+
return response.outputs ? response.outputs[0] : undefined;
113101
}
114102

115103
async retry(retryCount: number): Promise<void> {
@@ -160,12 +148,12 @@ export class InputPlaneInvocation implements Invocation {
160148
) {
161149
const functionPutInputsItem = {
162150
idx: 0,
163-
input: input,
151+
input,
164152
};
165-
const client = getOrCreateClient(inputPlaneUrl);
153+
const client = getOrCreateInputPlaneClient(inputPlaneUrl);
166154
// Single input sync invocation
167155
const attemptStartResponse = await client.attemptStart({
168-
functionId: functionId,
156+
functionId,
169157
input: functionPutInputsItem,
170158
});
171159
return new InputPlaneInvocation(
@@ -212,6 +200,19 @@ function timeNowSeconds() {
212200
return Date.now() / 1e3;
213201
}
214202

203+
/**
204+
* Signature of a function that fetches a single output using the given timeout. Used by `pollForOutputs` to fetch
205+
* from either the control plane or the input plane, depending on the implementation.
206+
*/
207+
type GetOutput = (
208+
timeoutMillis: number,
209+
) => Promise<FunctionGetOutputsItem | undefined>;
210+
211+
/***
212+
* Repeatedly tries to fetch an output using the provided `getOutput` function, and the specified timeout value.
213+
* We use a timeout value of 55 seconds if the caller does not specify a timeout value, or if the specified timeout
214+
* value is greater than 55 seconds.
215+
*/
215216
async function pollFunctionOutput(
216217
getOutput: GetOutput,
217218
timeout?: number, // in milliseconds

test-support/libmodal_test_support.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@ def sleep(t: int) -> None:
1919
def bytelength(buf: bytes) -> int:
2020
return len(buf)
2121

22+
2223
@app.function(min_containers=1, experimental_options={"input_plane_region": "us-west"})
2324
def input_plane(s: str) -> str:
2425
return "output: " + s
2526

27+
2628
@app.cls(min_containers=1)
2729
class EchoCls:
2830
@modal.method()

0 commit comments

Comments
 (0)