Skip to content

Commit a998493

Browse files
committed
Add support for input plane invocation
1 parent 23a98f7 commit a998493

File tree

5 files changed

+217
-42
lines changed

5 files changed

+217
-42
lines changed

modal-js/src/client.ts

Lines changed: 60 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ import {
1313
import { ClientType, ModalClientDefinition } from "../proto/modal_proto/api";
1414
import { type Profile, profile } from "./config";
1515

16+
let modalAuthToken: string | undefined;
17+
1618
/** gRPC client middleware to add auth token to request. */
1719
function authMiddleware(profile: Profile): ClientMiddleware {
1820
return async function* authMiddleware<Request, Response>(
@@ -27,6 +29,30 @@ function authMiddleware(profile: Profile): ClientMiddleware {
2729
options.metadata.set("x-modal-client-version", "1.0.0"); // CLIENT VERSION: Behaves like this Python SDK version
2830
options.metadata.set("x-modal-token-id", profile.tokenId);
2931
options.metadata.set("x-modal-token-secret", profile.tokenSecret);
32+
if (modalAuthToken) {
33+
options.metadata.set("x-modal-auth-token", modalAuthToken);
34+
}
35+
36+
const prevOnHeader = options.onHeader;
37+
options.onHeader = (header) => {
38+
const token = header.get("x-modal-auth-token");
39+
if (token) {
40+
modalAuthToken = token;
41+
}
42+
if (prevOnHeader) {
43+
prevOnHeader(header);
44+
}
45+
};
46+
const prevOnTrailer = options.onTrailer;
47+
options.onTrailer = (trailer) => {
48+
const token = trailer.get("x-modal-auth-token");
49+
if (token) {
50+
modalAuthToken = token;
51+
}
52+
if (prevOnTrailer) {
53+
prevOnTrailer(trailer);
54+
}
55+
};
3056
return yield* call.next(call.request, options);
3157
};
3258
}
@@ -190,15 +216,37 @@ const retryMiddleware: ClientMiddleware<RetryOptions> =
190216
}
191217
};
192218

193-
// Ref: https://github.com/modal-labs/modal-client/blob/main/modal/_utils/grpc_utils.py
194-
const channel = createChannel(profile.serverUrl, undefined, {
195-
"grpc.max_receive_message_length": 100 * 1024 * 1024,
196-
"grpc.max_send_message_length": 100 * 1024 * 1024,
197-
"grpc-node.flow_control_window": 64 * 1024 * 1024,
198-
});
199-
200-
export const client = createClientFactory()
201-
.use(authMiddleware(profile))
202-
.use(retryMiddleware)
203-
.use(timeoutMiddleware)
204-
.create(ModalClientDefinition, channel);
219+
/**
220+
* Map of server URL => client.
221+
* The us-east client talks to the control plane; all other clients talk to input planes.
222+
*/
223+
const clients: Record<string, ReturnType<typeof createClient>> = {};
224+
225+
/** Returns a client for the given server URL, creating it if it doesn't exist. */
226+
export const getOrCreateClient = (
227+
serverURL: string,
228+
): ReturnType<typeof createClient> => {
229+
if (serverURL in clients) {
230+
return clients[serverURL];
231+
}
232+
233+
clients[serverURL] = createClient(serverURL);
234+
return clients[serverURL];
235+
};
236+
237+
const createClient = (serverURL: string) => {
238+
// Ref: https://github.com/modal-labs/modal-client/blob/main/modal/_utils/grpc_utils.py
239+
const channel = createChannel(serverURL, undefined, {
240+
"grpc.max_receive_message_length": 100 * 1024 * 1024,
241+
"grpc.max_send_message_length": 100 * 1024 * 1024,
242+
"grpc-node.flow_control_window": 64 * 1024 * 1024,
243+
});
244+
245+
return createClientFactory()
246+
.use(authMiddleware(profile))
247+
.use(retryMiddleware)
248+
.use(timeoutMiddleware)
249+
.create(ModalClientDefinition, channel);
250+
};
251+
252+
export const client = getOrCreateClient(profile.serverUrl);

modal-js/src/function.ts

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,11 @@ import { environmentName } from "./config";
1515
import { InternalFailure, NotFoundError } from "./errors";
1616
import { dumps } from "./pickle";
1717
import { ClientError, Status } from "nice-grpc";
18-
import { ControlPlaneInvocation } from "./invocation";
18+
import {
19+
ControlPlaneInvocation,
20+
InputPlaneInvocation,
21+
Invocation,
22+
} from "./invocation";
1923

2024
// From: modal/_utils/blob_utils.py
2125
const maxObjectSizeBytes = 2 * 1024 * 1024; // 2 MiB
@@ -27,11 +31,13 @@ const maxSystemRetries = 8;
2731
export class Function_ {
2832
readonly functionId: string;
2933
readonly methodName: string | undefined;
34+
readonly inputPlaneUrl: string | undefined;
3035

3136
/** @ignore */
32-
constructor(functionId: string, methodName?: string) {
37+
constructor(functionId: string, methodName?: string, inputPlaneUrl?: string) {
3338
this.functionId = functionId;
3439
this.methodName = methodName;
40+
this.inputPlaneUrl = inputPlaneUrl;
3541
}
3642

3743
static async lookup(
@@ -46,7 +52,11 @@ export class Function_ {
4652
namespace: DeploymentNamespace.DEPLOYMENT_NAMESPACE_WORKSPACE,
4753
environmentName: environmentName(options.environment),
4854
});
49-
return new Function_(resp.functionId);
55+
return new Function_(
56+
resp.functionId,
57+
undefined,
58+
resp.handleMetadata?.inputPlaneUrl,
59+
);
5060
} catch (err) {
5161
if (err instanceof ClientError && err.code === Status.NOT_FOUND)
5262
throw new NotFoundError(`Function '${appName}/${name}' not found`);
@@ -60,11 +70,7 @@ export class Function_ {
6070
kwargs: Record<string, any> = {},
6171
): Promise<any> {
6272
const input = await this.#createInput(args, kwargs);
63-
const invocation = await ControlPlaneInvocation.create(
64-
this.functionId,
65-
input,
66-
FunctionCallInvocationType.FUNCTION_CALL_INVOCATION_TYPE_SYNC,
67-
);
73+
const invocation = await this.#createRemoteInvocation(input);
6874
// TODO(ryan): Add tests for retries.
6975
let retryCount = 0;
7076
while (true) {
@@ -81,6 +87,22 @@ export class Function_ {
8187
}
8288
}
8389

90+
async #createRemoteInvocation(input: FunctionInput): Promise<Invocation> {
91+
if (this.inputPlaneUrl) {
92+
return await InputPlaneInvocation.create(
93+
this.inputPlaneUrl,
94+
this.functionId,
95+
input,
96+
);
97+
}
98+
99+
return await ControlPlaneInvocation.create(
100+
this.functionId,
101+
input,
102+
FunctionCallInvocationType.FUNCTION_CALL_INVOCATION_TYPE_SYNC,
103+
);
104+
}
105+
84106
// Spawn a single input into a remote function.
85107
async spawn(
86108
args: any[] = [],

modal-js/src/invocation.ts

Lines changed: 115 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,17 @@ import {
22
DataFormat,
33
FunctionCallInvocationType,
44
FunctionCallType,
5-
FunctionGetOutputsResponse,
5+
FunctionGetOutputsItem,
66
FunctionInput,
77
FunctionMapResponse,
8+
FunctionPutInputsItem,
89
FunctionRetryInputsItem,
910
GeneratorDone,
1011
GenericResult,
1112
GenericResult_GenericStatus,
13+
ModalClientClient,
1214
} from "../proto/modal_proto/api";
13-
import { client } from "./client";
15+
import { client, getOrCreateClient } from "./client";
1416
import { FunctionTimeoutError, InternalFailure, RemoteError } from "./errors";
1517
import { loads } from "./pickle";
1618

@@ -33,6 +35,14 @@ export interface Invocation {
3335
retry(retryCount: number): Promise<void>;
3436
}
3537

38+
/**
39+
* Signature of a function that fetches a single output. Used by `pollForOutputs` to fetch from either
40+
* the control plane or the input plane, depending on the implementation.
41+
*/
42+
type GetOutput = (
43+
timeoutMillis: number,
44+
) => Promise<FunctionGetOutputsItem | undefined>;
45+
3646
/**
3747
* Implementation of Invocation which sends inputs to the control plane.
3848
*/
@@ -77,7 +87,28 @@ export class ControlPlaneInvocation implements Invocation {
7787
}
7888

7989
async await(timeout?: number): Promise<any> {
80-
return await pollControlPlaneForOutput(this.functionCallId, timeout);
90+
return await pollControlPlaneForOutput(
91+
(timeoutMillis: number) => this.#getOutput(timeoutMillis),
92+
timeout,
93+
);
94+
}
95+
96+
async #getOutput(
97+
timeoutMillis: number,
98+
): Promise<FunctionGetOutputsItem | undefined> {
99+
try {
100+
const response = await client.functionGetOutputs({
101+
functionCallId: this.functionCallId,
102+
maxValues: 1,
103+
timeout: timeoutMillis / 1000, // Backend needs seconds
104+
lastEntryId: "0-0",
105+
clearOnSuccess: true,
106+
requestedAt: timeNowSeconds(),
107+
});
108+
return response.outputs ? response.outputs[0] : undefined;
109+
} catch (err) {
110+
throw new Error(`FunctionGetOutputs failed: ${err}`);
111+
}
81112
}
82113

83114
async retry(retryCount: number): Promise<void> {
@@ -118,12 +149,88 @@ export class ControlPlaneInvocation implements Invocation {
118149
}
119150
}
120151

152+
/**
153+
* Implementation of Invocation which sends inputs to the input plane.
154+
*/
155+
export class InputPlaneInvocation implements Invocation {
156+
private readonly client: ModalClientClient;
157+
private readonly functionId: string;
158+
private readonly input: FunctionPutInputsItem;
159+
private attemptToken: string;
160+
161+
constructor(
162+
client: ModalClientClient,
163+
functionId: string,
164+
input: FunctionPutInputsItem,
165+
attemptToken: string,
166+
) {
167+
this.client = client;
168+
this.functionId = functionId;
169+
this.input = input;
170+
this.attemptToken = attemptToken;
171+
}
172+
173+
static async create(
174+
inputPlaneUrl: string,
175+
functionId: string,
176+
input: FunctionInput,
177+
) {
178+
const functionPutInputsItem = {
179+
idx: 0,
180+
input: input,
181+
};
182+
const client = getOrCreateClient(inputPlaneUrl);
183+
// Single input sync invocation
184+
const attemptStartResponse = await client.attemptStart({
185+
functionId: functionId,
186+
input: functionPutInputsItem,
187+
});
188+
return new InputPlaneInvocation(
189+
client,
190+
functionId,
191+
functionPutInputsItem,
192+
attemptStartResponse.attemptToken,
193+
);
194+
}
195+
196+
async await(timeout?: number): Promise<any> {
197+
return await pollControlPlaneForOutput(
198+
(timeoutMillis: number) => this.#getOutput(timeoutMillis),
199+
timeout,
200+
);
201+
}
202+
203+
async #getOutput(
204+
timeoutMillis: number,
205+
): Promise<FunctionGetOutputsItem | undefined> {
206+
try {
207+
const response = await this.client.attemptAwait({
208+
attemptToken: this.attemptToken,
209+
requestedAt: timeNowSeconds(),
210+
timeoutSecs: timeoutMillis / 1000,
211+
});
212+
return response.output;
213+
} catch (err) {
214+
throw new Error(`AttemptAwait failed: ${err}`);
215+
}
216+
}
217+
218+
async retry(_retryCount: number): Promise<void> {
219+
const attemptRetryResponse = await this.client.attemptRetry({
220+
functionId: this.functionId,
221+
input: this.input,
222+
attemptToken: this.attemptToken,
223+
});
224+
this.attemptToken = attemptRetryResponse.attemptToken;
225+
}
226+
}
227+
121228
function timeNowSeconds() {
122229
return Date.now() / 1e3;
123230
}
124231

125-
export async function pollControlPlaneForOutput(
126-
functionCallId: string,
232+
async function pollControlPlaneForOutput(
233+
getOutput: GetOutput,
127234
timeout?: number, // in milliseconds
128235
): Promise<any> {
129236
const startTime = Date.now();
@@ -133,23 +240,9 @@ export async function pollControlPlaneForOutput(
133240
}
134241

135242
while (true) {
136-
let response: FunctionGetOutputsResponse;
137-
try {
138-
response = await client.functionGetOutputs({
139-
functionCallId: functionCallId,
140-
maxValues: 1,
141-
timeout: pollTimeout / 1000, // Backend needs seconds
142-
lastEntryId: "0-0",
143-
clearOnSuccess: true,
144-
requestedAt: timeNowSeconds(),
145-
});
146-
} catch (err) {
147-
throw new Error(`FunctionGetOutputs failed: ${err}`);
148-
}
149-
150-
const outputs = response.outputs;
151-
if (outputs.length > 0) {
152-
return await processResult(outputs[0].result, outputs[0].dataFormat);
243+
const output = await getOutput(pollTimeout);
244+
if (output) {
245+
return await processResult(output.result, output.dataFormat);
153246
}
154247

155248
if (timeout !== undefined) {

modal-js/test/function.test.ts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,12 @@ test("FunctionNotFound", async () => {
3434
);
3535
await expect(promise).rejects.toThrowError(NotFoundError);
3636
});
37+
38+
test("FunctionCallInputPlane", async () => {
39+
const function_ = await Function_.lookup(
40+
"libmodal-test-support",
41+
"input_plane",
42+
);
43+
const result = await function_.remote(["hello"]);
44+
expect(result).toBe("output: hello");
45+
});

test-support/libmodal_test_support.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ def sleep(t: int) -> None:
1919
def bytelength(buf: bytes) -> int:
2020
return len(buf)
2121

22+
@app.function(min_containers=1, experimental_options={"input_plane_region": "us-west"})
23+
def input_plane(s: str) -> str:
24+
return "output: " + s
2225

2326
@app.cls(min_containers=1)
2427
class EchoCls:

0 commit comments

Comments
 (0)