Skip to content
22 changes: 18 additions & 4 deletions packages/inference/src/tasks/custom/request.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
import type { InferenceTask, Options, RequestArgs } from "../../types";
import { makeRequestOptions } from "../../lib/makeRequestOptions";
import type { InferenceTask, Options, RequestArgs } from "../../types";

/**
 * Wraps a response payload of type `T` together with the context of the HTTP
 * request that produced it, so callers can issue follow-up requests to the
 * same endpoint (e.g. polling an async provider queue with the same URL and
 * auth headers).
 */
export interface ResponseWrapper<T> {
	/** The parsed response body (JSON object or Blob, depending on Content-Type). */
	data: T;
	/** The URL and headers of the original request, reusable for follow-up calls. */
	requestContext: {
		url: string;
		headers: Record<string, string>;
	};
}

/**
* Primitive to make custom calls to the inference provider
Expand All @@ -11,8 +19,10 @@ export async function request<T>(
task?: InferenceTask;
/** Is chat completion compatible */
chatCompletion?: boolean;
/** Whether to include request context in the response */
withRequestContext?: boolean;
}
): Promise<T> {
): Promise<Options extends { withRequestContext: true } ? ResponseWrapper<T> : T> {
const { url, info } = await makeRequestOptions(args, options);
const response = await (options?.fetch ?? fetch)(url, info);

Expand All @@ -39,9 +49,13 @@ export async function request<T>(
throw new Error(message ?? "An error occurred while fetching the blob");
}

const requestContext = { url, headers: info.headers as Record<string, string> };

if (response.headers.get("Content-Type")?.startsWith("application/json")) {
return await response.json();
const data = await response.json();
return options?.withRequestContext ? ({ data, requestContext } as unknown as T) : (data as T);
}

return (await response.blob()) as T;
const blob = await response.blob();
return options?.withRequestContext ? ({ data: blob, requestContext } as unknown as T) : (blob as T);
}
19 changes: 10 additions & 9 deletions packages/inference/src/tasks/cv/textToVideo.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import type { BaseArgs, InferenceProvider, Options } from "../../types";
import type { TextToVideoInput } from "@huggingface/tasks";
import { request } from "../custom/request";
import { omit } from "../../utils/omit";
import { isUrl } from "../../lib/isUrl";
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import { typedInclude } from "../../utils/typedInclude";
import { makeRequestOptions } from "../../lib/makeRequestOptions";
import { isUrl } from "../../lib/isUrl";
import { pollFalResponse, type FalAiQueueOutput } from "../../providers/fal-ai";
import type { BaseArgs, InferenceProvider, Options } from "../../types";
import { omit } from "../../utils/omit";
import { typedInclude } from "../../utils/typedInclude";
import { request, type ResponseWrapper } from "../custom/request";

export type TextToVideoArgs = BaseArgs & TextToVideoInput;

Expand Down Expand Up @@ -35,13 +34,15 @@ export async function textToVideo(args: TextToVideoArgs, options?: Options): Pro
args.provider === "fal-ai" || args.provider === "replicate" || args.provider === "novita"
? { ...omit(args, ["inputs", "parameters"]), ...args.parameters, prompt: args.inputs }
: args;
const res = await request<FalAiQueueOutput | ReplicateOutput | NovitaOutput>(payload, {
const response = await request<ResponseWrapper<FalAiQueueOutput | ReplicateOutput | NovitaOutput>>(payload, {
...options,
task: "text-to-video",
withRequestContext: true,
});

const { data: res, requestContext } = response;
if (args.provider === "fal-ai") {
const { url, info } = await makeRequestOptions(args, { ...options, task: "text-to-video" });
return await pollFalResponse(res as FalAiQueueOutput, url, info.headers as Record<string, string>);
return await pollFalResponse(res as FalAiQueueOutput, requestContext.url, requestContext.headers);
} else if (args.provider === "novita") {
const isValidOutput =
typeof res === "object" &&
Expand Down