Skip to content

Commit 83f8b96

Browse files
committed
return request context in request function
1 parent a93de9c commit 83f8b96

File tree

2 files changed

+28
-13
lines changed

2 files changed

+28
-13
lines changed

packages/inference/src/tasks/custom/request.ts

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
1-
import type { InferenceTask, Options, RequestArgs } from "../../types";
21
import { makeRequestOptions } from "../../lib/makeRequestOptions";
2+
import type { InferenceTask, Options, RequestArgs } from "../../types";
3+
4+
export interface ResponseWrapper<T> {
5+
data: T;
6+
requestContext: {
7+
url: string;
8+
headers: Record<string, string>;
9+
};
10+
}
311

412
/**
513
* Primitive to make custom calls to the inference provider
@@ -11,8 +19,10 @@ export async function request<T>(
1119
task?: InferenceTask;
1220
/** Is chat completion compatible */
1321
chatCompletion?: boolean;
22+
/** Whether to include request context in the response */
23+
withRequestContext?: boolean;
1424
}
15-
): Promise<T> {
25+
): Promise<Options extends { withRequestContext: true } ? ResponseWrapper<T> : T> {
1626
const { url, info } = await makeRequestOptions(args, options);
1727
const response = await (options?.fetch ?? fetch)(url, info);
1828

@@ -39,9 +49,13 @@ export async function request<T>(
3949
throw new Error(message ?? "An error occurred while fetching the blob");
4050
}
4151

52+
const requestContext = { url, headers: info.headers as Record<string, string> };
53+
4254
if (response.headers.get("Content-Type")?.startsWith("application/json")) {
43-
return await response.json();
55+
const data = await response.json();
56+
return options?.withRequestContext ? ({ data, requestContext } as unknown as T) : (data as T);
4457
}
4558

46-
return (await response.blob()) as T;
59+
const blob = await response.blob();
60+
return options?.withRequestContext ? ({ data: blob, requestContext } as unknown as T) : (blob as T);
4761
}

packages/inference/src/tasks/cv/textToVideo.ts

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1-
import type { BaseArgs, InferenceProvider, Options } from "../../types";
21
import type { TextToVideoInput } from "@huggingface/tasks";
3-
import { request } from "../custom/request";
4-
import { omit } from "../../utils/omit";
5-
import { isUrl } from "../../lib/isUrl";
62
import { InferenceOutputError } from "../../lib/InferenceOutputError";
7-
import { typedInclude } from "../../utils/typedInclude";
8-
import { makeRequestOptions } from "../../lib/makeRequestOptions";
3+
import { isUrl } from "../../lib/isUrl";
94
import { pollFalResponse, type FalAiQueueOutput } from "../../providers/fal-ai";
5+
import type { BaseArgs, InferenceProvider, Options } from "../../types";
6+
import { omit } from "../../utils/omit";
7+
import { typedInclude } from "../../utils/typedInclude";
8+
import { request, type ResponseWrapper } from "../custom/request";
109

1110
export type TextToVideoArgs = BaseArgs & TextToVideoInput;
1211

@@ -35,13 +34,15 @@ export async function textToVideo(args: TextToVideoArgs, options?: Options): Pro
3534
args.provider === "fal-ai" || args.provider === "replicate" || args.provider === "novita"
3635
? { ...omit(args, ["inputs", "parameters"]), ...args.parameters, prompt: args.inputs }
3736
: args;
38-
const res = await request<FalAiQueueOutput | ReplicateOutput | NovitaOutput>(payload, {
37+
const response = await request<ResponseWrapper<FalAiQueueOutput | ReplicateOutput | NovitaOutput>>(payload, {
3938
...options,
4039
task: "text-to-video",
40+
withRequestContext: true,
4141
});
42+
43+
const { data: res, requestContext } = response;
4244
if (args.provider === "fal-ai") {
43-
const { url, info } = await makeRequestOptions(args, { ...options, task: "text-to-video" });
44-
return await pollFalResponse(res as FalAiQueueOutput, url, info.headers as Record<string, string>);
45+
return await pollFalResponse(res as FalAiQueueOutput, requestContext.url, requestContext.headers);
4546
} else if (args.provider === "novita") {
4647
const isValidOutput =
4748
typeof res === "object" &&

0 commit comments

Comments
 (0)