Skip to content

Commit 57703ea

Browse files
committed
Add test and fixes
1 parent a0ca9a2 commit 57703ea

File tree

5 files changed

+80
-19
lines changed

5 files changed

+80
-19
lines changed

modal-js/src/client.ts

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ import {
1313
import { ClientType, ModalClientDefinition } from "../proto/modal_proto/api";
1414
import { type Profile, profile } from "./config";
1515

16+
let modalAuthToken: string | undefined;
17+
1618
/** gRPC client middleware to add auth token to request. */
1719
function authMiddleware(profile: Profile): ClientMiddleware {
1820
return async function* authMiddleware<Request, Response>(
@@ -27,6 +29,30 @@ function authMiddleware(profile: Profile): ClientMiddleware {
2729
options.metadata.set("x-modal-client-version", "1.0.0"); // CLIENT VERSION: Behaves like this Python SDK version
2830
options.metadata.set("x-modal-token-id", profile.tokenId);
2931
options.metadata.set("x-modal-token-secret", profile.tokenSecret);
32+
if (modalAuthToken) {
33+
options.metadata.set("x-modal-auth-token", modalAuthToken);
34+
}
35+
36+
const prevOnHeader = options.onHeader;
37+
options.onHeader = (header) => {
38+
const token = header.get("x-modal-auth-token");
39+
if (token) {
40+
modalAuthToken = token;
41+
}
42+
if (prevOnHeader) {
43+
prevOnHeader(header);
44+
}
45+
};
46+
const prevOnTrailer = options.onTrailer;
47+
options.onTrailer = (trailer) => {
48+
const token = trailer.get("x-modal-auth-token");
49+
if (token) {
50+
modalAuthToken = token;
51+
}
52+
if (prevOnTrailer) {
53+
prevOnTrailer(trailer);
54+
}
55+
};
3056
return yield* call.next(call.request, options);
3157
};
3258
}
@@ -223,4 +249,5 @@ const createClient = (serverURL: string) => {
223249
.create(ModalClientDefinition, channel);
224250
};
225251

252+
/** The default Modal client that talks to the control plane. */
226253
export const client = getOrCreateClient(profile.serverUrl);

modal-js/src/function.ts

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,7 @@
22

33
import { createHash } from "node:crypto";
44

5-
import {
6-
DataFormat,
7-
DeploymentNamespace,
8-
FunctionPutInputsItem,
9-
} from "../proto/modal_proto/api";
5+
import { DataFormat, DeploymentNamespace } from "../proto/modal_proto/api";
106
import type { LookupOptions } from "./app";
117
import { client } from "./client";
128
import { FunctionCall } from "./function_call";
@@ -18,6 +14,8 @@ import {
1814
ControlPlaneStrategy,
1915
GetOutput,
2016
getOutputsFromControlPlane,
17+
InputPlaneStrategy,
18+
InvocationStrategy,
2119
pollForOutputs,
2220
} from "./invocation_strategy";
2321

@@ -28,11 +26,13 @@ const maxObjectSizeBytes = 2 * 1024 * 1024; // 2 MiB
2826
export class Function_ {
2927
readonly functionId: string;
3028
readonly methodName: string | undefined;
29+
private readonly inputPlaneUrl: string | undefined;
3130

3231
/** @ignore */
33-
constructor(functionId: string, methodName?: string) {
32+
constructor(functionId: string, methodName?: string, inputPlaneUrl?: string) {
3433
this.functionId = functionId;
3534
this.methodName = methodName;
35+
this.inputPlaneUrl = inputPlaneUrl;
3636
}
3737

3838
static async lookup(
@@ -47,7 +47,11 @@ export class Function_ {
4747
namespace: DeploymentNamespace.DEPLOYMENT_NAMESPACE_WORKSPACE,
4848
environmentName: environmentName(options.environment),
4949
});
50-
return new Function_(resp.functionId);
50+
return new Function_(
51+
resp.functionId,
52+
undefined,
53+
resp.handleMetadata?.inputPlaneUrl,
54+
);
5155
} catch (err) {
5256
if (err instanceof ClientError && err.code === Status.NOT_FOUND)
5357
throw new NotFoundError(`Function '${appName}/${name}' not found`);
@@ -60,8 +64,10 @@ export class Function_ {
6064
args: any[] = [],
6165
kwargs: Record<string, any> = {},
6266
): Promise<any> {
63-
const input = await this.#createInput(args, kwargs);
64-
const invocationStrategy = new ControlPlaneStrategy(this.functionId, input);
67+
const invocationStrategy = await this.#createInvocationStrategy(
68+
args,
69+
kwargs,
70+
);
6571
return await invocationStrategy.remote();
6672
}
6773

@@ -70,16 +76,18 @@ export class Function_ {
7076
args: any[] = [],
7177
kwargs: Record<string, any> = {},
7278
): Promise<FunctionCall> {
73-
const input = await this.#createInput(args, kwargs);
74-
const invocationStrategy = new ControlPlaneStrategy(this.functionId, input);
79+
const invocationStrategy = await this.#createInvocationStrategy(
80+
args,
81+
kwargs,
82+
);
7583
const functionCallId = await invocationStrategy.spawn();
7684
return new FunctionCall(functionCallId);
7785
}
7886

79-
async #createInput(
87+
async #createInvocationStrategy(
8088
args: any[] = [],
8189
kwargs: Record<string, any> = {},
82-
): Promise<FunctionPutInputsItem> {
90+
): Promise<InvocationStrategy> {
8391
const payload = dumps([args, kwargs]);
8492

8593
let argsBlobId: string | undefined = undefined;
@@ -88,7 +96,7 @@ export class Function_ {
8896
}
8997

9098
// Single input sync invocation
91-
return {
99+
const input = {
92100
idx: 0,
93101
input: {
94102
args: argsBlobId ? undefined : payload,
@@ -98,6 +106,11 @@ export class Function_ {
98106
finalInput: false, // This field isn't specified in the Python client, so it defaults to false.
99107
},
100108
};
109+
if (this.inputPlaneUrl) {
110+
return new InputPlaneStrategy(this.functionId, this.inputPlaneUrl, input);
111+
}
112+
113+
return new ControlPlaneStrategy(this.functionId, input);
101114
}
102115
}
103116

modal-js/src/invocation_strategy.ts

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,19 +91,27 @@ export class ControlPlaneStrategy implements InvocationStrategy {
9191
export class InputPlaneStrategy implements InvocationStrategy {
9292
private readonly functionId: string;
9393
private readonly client: ModalClientClient;
94+
private readonly input: FunctionPutInputsItem;
9495
private attemptToken: string | undefined;
9596

96-
constructor(functionId: string, inputPlaneUrl: string) {
97+
constructor(
98+
functionId: string,
99+
inputPlaneUrl: string,
100+
input: FunctionPutInputsItem,
101+
) {
97102
this.client = getOrCreateClient(inputPlaneUrl);
98103
this.functionId = functionId;
104+
this.input = input;
99105
}
100106

101-
async remote(input: FunctionPutInputsItem): Promise<FunctionGetOutputsItem> {
102-
await this.#attemptStart(input);
103-
return await pollForOutputs(this.#attemptAwait);
107+
async remote(): Promise<FunctionGetOutputsItem> {
108+
await this.#attemptStart(this.input);
109+
return await pollForOutputs((timeoutMillis: number) =>
110+
this.#attemptAwait(timeoutMillis),
111+
);
104112
}
105113

106-
async spawn(_input: FunctionPutInputsItem): Promise<string> {
114+
async spawn(): Promise<string> {
107115
throw new Error("Spawn operations are not supported by the input plane.");
108116
}
109117

modal-js/test/function.test.ts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,12 @@ test("FunctionNotFound", async () => {
3434
);
3535
await expect(promise).rejects.toThrowError(NotFoundError);
3636
});
37+
38+
test("FunctionCallInputPlane", async () => {
39+
const function_ = await Function_.lookup(
40+
"libmodal-test-support",
41+
"input_plane",
42+
);
43+
const result = await function_.remote(["hello"]);
44+
expect(result).toBe("output: hello");
45+
});

test-support/libmodal_test_support.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,7 @@ class EchoClsParametrized:
3434
@modal.method()
3535
def echo_parameter(self) -> str:
3636
return "output: " + self.name
37+
38+
@app.function(min_containers=1, experimental_options={"input_plane_region": "us-west"})
39+
def input_plane(s: str) -> str:
40+
return "output: " + s

0 commit comments

Comments
 (0)