Skip to content

Commit 291fad9

Browse files
committed
Add InvocationStrategy for sending inputs to control plane
1 parent a447216 commit 291fad9

File tree

2 files changed

+209
-148
lines changed

2 files changed

+209
-148
lines changed

modal-js/src/function.ts

Lines changed: 27 additions & 148 deletions
Original file line numberDiff line numberDiff line change
@@ -5,45 +5,35 @@ import { createHash } from "node:crypto";
55
import {
66
DataFormat,
77
DeploymentNamespace,
8-
FunctionCallInvocationType,
9-
FunctionCallType,
10-
FunctionGetOutputsResponse,
11-
GeneratorDone,
12-
GenericResult,
13-
GenericResult_GenericStatus,
8+
FunctionPutInputsItem,
149
} from "../proto/modal_proto/api";
1510
import type { LookupOptions } from "./app";
1611
import { client } from "./client";
1712
import { FunctionCall } from "./function_call";
1813
import { environmentName } from "./config";
19-
import {
20-
InternalFailure,
21-
NotFoundError,
22-
RemoteError,
23-
FunctionTimeoutError,
24-
} from "./errors";
25-
import { dumps, loads } from "./pickle";
14+
import { NotFoundError } from "./errors";
15+
import { dumps } from "./pickle";
2616
import { ClientError, Status } from "nice-grpc";
17+
import {
18+
ControlPlaneStrategy,
19+
InvocationStrategy,
20+
pollControlPlaneForOutput,
21+
} from "./invocation_strategy";
2722

2823
// From: modal/_utils/blob_utils.py
2924
const maxObjectSizeBytes = 2 * 1024 * 1024; // 2 MiB
3025

31-
// From: modal-client/modal/_utils/function_utils.py
32-
const outputsTimeout = 55 * 1000;
33-
34-
function timeNowSeconds() {
35-
return Date.now() / 1e3;
36-
}
37-
3826
/** Represents a deployed Modal Function, which can be invoked remotely. */
3927
export class Function_ {
4028
readonly functionId: string;
4129
readonly methodName: string | undefined;
30+
private readonly invocationStrategy: InvocationStrategy;
4231

4332
/** @ignore */
4433
constructor(functionId: string, methodName?: string) {
4534
this.functionId = functionId;
4635
this.methodName = methodName;
36+
this.invocationStrategy = new ControlPlaneStrategy(this.functionId);
4737
}
4838

4939
static async lookup(
@@ -71,32 +61,24 @@ export class Function_ {
7161
args: any[] = [],
7262
kwargs: Record<string, any> = {},
7363
): Promise<any> {
74-
const functionCallId = await this.#execFunctionCall(
75-
args,
76-
kwargs,
77-
FunctionCallInvocationType.FUNCTION_CALL_INVOCATION_TYPE_SYNC,
78-
);
79-
return await pollFunctionOutput(functionCallId);
64+
const input = await this.#createInput(args, kwargs);
65+
return await this.invocationStrategy.remote(input);
8066
}
8167

8268
// Spawn a single input into a remote function.
8369
async spawn(
8470
args: any[] = [],
8571
kwargs: Record<string, any> = {},
8672
): Promise<FunctionCall> {
87-
const functionCallId = await this.#execFunctionCall(
88-
args,
89-
kwargs,
90-
FunctionCallInvocationType.FUNCTION_CALL_INVOCATION_TYPE_SYNC,
91-
);
73+
const input = await this.#createInput(args, kwargs);
74+
const functionCallId = await this.invocationStrategy.spawn(input);
9275
return new FunctionCall(functionCallId);
9376
}
9477

95-
async #execFunctionCall(
78+
async #createInput(
9679
args: any[] = [],
9780
kwargs: Record<string, any> = {},
98-
invocationType: FunctionCallInvocationType = FunctionCallInvocationType.FUNCTION_CALL_INVOCATION_TYPE_SYNC,
99-
): Promise<string> {
81+
): Promise<FunctionPutInputsItem> {
10082
const payload = dumps([args, kwargs]);
10183

10284
let argsBlobId: string | undefined = undefined;
@@ -105,97 +87,24 @@ export class Function_ {
10587
}
10688

10789
// Single input sync invocation
108-
const functionMapResponse = await client.functionMap({
109-
functionId: this.functionId,
110-
functionCallType: FunctionCallType.FUNCTION_CALL_TYPE_UNARY,
111-
functionCallInvocationType: invocationType,
112-
pipelinedInputs: [
113-
{
114-
idx: 0,
115-
input: {
116-
args: argsBlobId ? undefined : payload,
117-
argsBlobId,
118-
dataFormat: DataFormat.DATA_FORMAT_PICKLE,
119-
methodName: this.methodName,
120-
},
121-
},
122-
],
123-
});
124-
125-
return functionMapResponse.functionCallId;
90+
return {
91+
idx: 0,
92+
input: {
93+
args: argsBlobId ? undefined : payload,
94+
argsBlobId,
95+
dataFormat: DataFormat.DATA_FORMAT_PICKLE,
96+
methodName: this.methodName,
97+
finalInput: false, // This field isn't specified in the Python client, so it defaults to false.
98+
},
99+
};
126100
}
127101
}
128102

129103
export async function pollFunctionOutput(
130104
functionCallId: string,
131105
timeout?: number, // in milliseconds
132106
): Promise<any> {
133-
const startTime = Date.now();
134-
let pollTimeout = outputsTimeout;
135-
if (timeout !== undefined) {
136-
pollTimeout = Math.min(timeout, outputsTimeout);
137-
}
138-
139-
while (true) {
140-
let response: FunctionGetOutputsResponse;
141-
try {
142-
response = await client.functionGetOutputs({
143-
functionCallId: functionCallId,
144-
maxValues: 1,
145-
timeout: pollTimeout / 1000, // Backend needs seconds
146-
lastEntryId: "0-0",
147-
clearOnSuccess: true,
148-
requestedAt: timeNowSeconds(),
149-
});
150-
} catch (err) {
151-
throw new Error(`FunctionGetOutputs failed: ${err}`);
152-
}
153-
154-
const outputs = response.outputs;
155-
if (outputs.length > 0) {
156-
return await processResult(outputs[0].result, outputs[0].dataFormat);
157-
}
158-
159-
if (timeout !== undefined) {
160-
const remainingTime = timeout - (Date.now() - startTime);
161-
if (remainingTime <= 0) {
162-
const message = `Timeout exceeded: ${(timeout / 1000).toFixed(1)}s`;
163-
throw new FunctionTimeoutError(message);
164-
}
165-
pollTimeout = Math.min(outputsTimeout, remainingTime);
166-
}
167-
}
168-
}
169-
170-
async function processResult(
171-
result: GenericResult | undefined,
172-
dataFormat: DataFormat,
173-
): Promise<unknown> {
174-
if (!result) {
175-
throw new Error("Received null result from invocation");
176-
}
177-
178-
let data = new Uint8Array();
179-
if (result.data !== undefined) {
180-
data = result.data;
181-
} else if (result.dataBlobId) {
182-
data = await blobDownload(result.dataBlobId);
183-
}
184-
185-
switch (result.status) {
186-
case GenericResult_GenericStatus.GENERIC_STATUS_TIMEOUT:
187-
throw new FunctionTimeoutError(`Timeout: ${result.exception}`);
188-
case GenericResult_GenericStatus.GENERIC_STATUS_INTERNAL_FAILURE:
189-
throw new InternalFailure(`Internal failure: ${result.exception}`);
190-
case GenericResult_GenericStatus.GENERIC_STATUS_SUCCESS:
191-
// Proceed to deserialize the data.
192-
break;
193-
default:
194-
// Handle other statuses, e.g., remote error.
195-
throw new RemoteError(`Remote error: ${result.exception}`);
196-
}
197-
198-
return deserializeDataFormat(data, dataFormat);
107+
return pollControlPlaneForOutput(functionCallId, timeout);
199108
}
200109

201110
async function blobUpload(data: Uint8Array): Promise<string> {
@@ -228,33 +137,3 @@ async function blobUpload(data: Uint8Array): Promise<string> {
228137
throw new Error("Missing upload URL in BlobCreate response");
229138
}
230139
}
231-
232-
async function blobDownload(blobId: string): Promise<Uint8Array> {
233-
const resp = await client.blobGet({ blobId });
234-
const s3resp = await fetch(resp.downloadUrl);
235-
if (!s3resp.ok) {
236-
throw new Error(`Failed to download blob: ${s3resp.statusText}`);
237-
}
238-
const buf = await s3resp.arrayBuffer();
239-
return new Uint8Array(buf);
240-
}
241-
242-
function deserializeDataFormat(
243-
data: Uint8Array | undefined,
244-
dataFormat: DataFormat,
245-
): unknown {
246-
if (!data) {
247-
return null; // No data to deserialize.
248-
}
249-
250-
switch (dataFormat) {
251-
case DataFormat.DATA_FORMAT_PICKLE:
252-
return loads(data);
253-
case DataFormat.DATA_FORMAT_ASGI:
254-
throw new Error("ASGI data format is not supported in Go");
255-
case DataFormat.DATA_FORMAT_GENERATOR_DONE:
256-
return GeneratorDone.decode(data);
257-
default:
258-
throw new Error(`Unsupported data format: ${dataFormat}`);
259-
}
260-
}

0 commit comments

Comments
 (0)