
Commit 022585b

[Inference] top-down injection of providerHelper (#1350)
cc @hanouticelina @Wauplin

Prompted by: https://huggingface.slack.com/archives/C02EMARJ65P/p1744370608755449 (internal)

This PR reduces the number of places where we call `getProviderHelper` by enforcing that the caller passes it as an argument where needed. Because `getProviderHelper` is fallible, calling it from many different places (sometimes deep in the function call chain) is risky and can lead to unexpected bugs. This PR reduces that risk by limiting the number of call sites and instead passing the helper as an argument, top-down. It also has the benefit of clearly indicating which functions depend on provider-specific logic.

---------

Co-authored-by: Celina Hanouti <[email protected]>
1 parent 68c6201 commit 022585b

36 files changed: +107 -65 lines changed
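
To make the new calling pattern concrete, here is a minimal, self-contained TypeScript sketch of the top-down injection this commit enforces. It uses simplified stand-in types rather than the package's real internals: ProviderHelper, makeUrl, buildRequestUrl, and audioClassificationUrl are hypothetical names for illustration only; getProviderHelper, providerHelper, and clientSideRoutingOnly mirror identifiers that appear in the diffs below.

// Illustrative sketch only -- simplified stand-ins, not the actual @huggingface/inference internals.
type InferenceTask = string;

interface ProviderHelper {
	clientSideRoutingOnly: boolean;
	makeUrl(model: string): string; // hypothetical method, for illustration
}

// Fallible lookup: throws for unknown provider/task combinations,
// so it should be called once, near the top of the call chain.
function getProviderHelper(provider: string, task?: InferenceTask): ProviderHelper {
	if (provider !== "hf-inference") {
		throw new Error(`No provider helper for "${provider}" (task: ${task ?? "unknown"})`);
	}
	return {
		clientSideRoutingOnly: false,
		makeUrl: (model) => `https://example.invalid/models/${model}`, // placeholder URL
	};
}

// After this PR, lower-level helpers no longer call getProviderHelper themselves;
// the already-resolved helper is injected as an explicit argument.
function buildRequestUrl(model: string, providerHelper: ProviderHelper): string {
	return providerHelper.makeUrl(model);
}

// Task-level entry point: the single place where the fallible lookup happens.
function audioClassificationUrl(model: string, provider = "hf-inference"): string {
	const providerHelper = getProviderHelper(provider, "audio-classification");
	return buildRequestUrl(model, providerHelper);
}

console.log(audioClassificationUrl("some-org/some-audio-model"));

The diffs below apply this same shape to makeRequestOptions, innerRequest/innerStreamingRequest, and the task-level functions.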

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 6 additions & 5 deletions
@@ -1,7 +1,7 @@
 import { name as packageName, version as packageVersion } from "../../package.json";
 import { HF_HEADER_X_BILL_TO, HF_HUB_URL } from "../config";
 import type { InferenceTask, Options, RequestArgs } from "../types";
-import { getProviderHelper } from "./getProviderHelper";
+import type { getProviderHelper } from "./getProviderHelper";
 import { getProviderModelId } from "./getProviderModelId";
 import { isUrl } from "./isUrl";
 
@@ -20,6 +20,7 @@ export async function makeRequestOptions(
 		data?: Blob | ArrayBuffer;
 		stream?: boolean;
 	},
+	providerHelper: ReturnType<typeof getProviderHelper>,
 	options?: Options & {
 		/** In most cases (unless we pass a endpointUrl) we know the task */
 		task?: InferenceTask;
@@ -28,6 +29,7 @@
 	const { provider: maybeProvider, model: maybeModel } = args;
 	const provider = maybeProvider ?? "hf-inference";
 	const { task } = options ?? {};
+
 	// Validate inputs
 	if (args.endpointUrl && provider !== "hf-inference") {
 		throw new Error(`Cannot use endpointUrl with a third-party provider.`);
@@ -38,7 +40,7 @@
 
 	if (args.endpointUrl) {
 		// No need to have maybeModel, or to load default model for a task
-		return makeRequestOptionsFromResolvedModel(maybeModel ?? args.endpointUrl, args, options);
+		return makeRequestOptionsFromResolvedModel(maybeModel ?? args.endpointUrl, providerHelper, args, options);
 	}
 
 	if (!maybeModel && !task) {
@@ -47,7 +49,6 @@
 
 	// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
 	const hfModel = maybeModel ?? (await loadDefaultModel(task!));
-	const providerHelper = getProviderHelper(provider, task);
 
 	if (providerHelper.clientSideRoutingOnly && !maybeModel) {
 		throw new Error(`Provider ${provider} requires a model ID to be passed directly.`);
@@ -62,7 +63,7 @@
 	});
 
 	// Use the sync version with the resolved model
-	return makeRequestOptionsFromResolvedModel(resolvedModel, args, options);
+	return makeRequestOptionsFromResolvedModel(resolvedModel, providerHelper, args, options);
 }
 
 /**
@@ -71,6 +72,7 @@
 */
 export function makeRequestOptionsFromResolvedModel(
 	resolvedModel: string,
+	providerHelper: ReturnType<typeof getProviderHelper>,
 	args: RequestArgs & {
 		data?: Blob | ArrayBuffer;
 		stream?: boolean;
@@ -85,7 +87,6 @@ export function makeRequestOptionsFromResolvedModel(
 	const provider = maybeProvider ?? "hf-inference";
 
 	const { includeCredentials, task, signal, billTo } = options ?? {};
-	const providerHelper = getProviderHelper(provider, task);
 	const authMethod = (() => {
 		if (providerHelper.clientSideRoutingOnly) {
 			// Closed-source providers require an accessToken (cannot be routed).

packages/inference/src/providers/hf-inference.ts

Lines changed: 6 additions & 6 deletions
@@ -385,13 +385,13 @@ export class HFInferenceQuestionAnsweringTask extends HFInferenceTask implements
 				typeof elem.end === "number" &&
 				typeof elem.score === "number" &&
 				typeof elem.start === "number"
-			)
+			)
 		: typeof response === "object" &&
-			!!response &&
-			typeof response.answer === "string" &&
-			typeof response.end === "number" &&
-			typeof response.score === "number" &&
-			typeof response.start === "number"
+			!!response &&
+			typeof response.answer === "string" &&
+			typeof response.end === "number" &&
+			typeof response.score === "number" &&
+			typeof response.start === "number"
 	) {
 		return Array.isArray(response) ? response[0] : response;
 	}

packages/inference/src/snippets/getInferenceSnippets.ts

Lines changed: 9 additions & 0 deletions
@@ -11,6 +11,7 @@ import type { ChatCompletionInputMessage, GenerationParameters } from "@huggingf
 import { makeRequestOptionsFromResolvedModel } from "../lib/makeRequestOptions";
 import type { InferenceProvider, InferenceTask, RequestArgs } from "../types";
 import { templates } from "./templates.exported";
+import { getProviderHelper } from "../lib/getProviderHelper";
 
 const PYTHON_CLIENTS = ["huggingface_hub", "fal_client", "requests", "openai"] as const;
 const JS_CLIENTS = ["fetch", "huggingface.js", "openai"] as const;
@@ -130,10 +131,18 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
 		inputPreparationFn = prepareConversationalInput;
 		task = "conversational";
 	}
+	let providerHelper: ReturnType<typeof getProviderHelper>;
+	try {
+		providerHelper = getProviderHelper(provider, task);
+	} catch (e) {
+		console.error(`Failed to get provider helper for ${provider} (${task})`, e);
+		return [];
+	}
 	/// Prepare inputs + make request
 	const inputs = inputPreparationFn ? inputPreparationFn(model, opts) : { inputs: getModelInputSnippet(model) };
 	const request = makeRequestOptionsFromResolvedModel(
 		providerModelId ?? model.id,
+		providerHelper,
 		{
 			accessToken: accessToken,
 			provider: provider,

packages/inference/src/tasks/audio/audioClassification.ts

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ export async function audioClassification(
 ): Promise<AudioClassificationOutput> {
 	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-classification");
 	const payload = preparePayload(args);
-	const { data: res } = await innerRequest<AudioClassificationOutput>(payload, {
+	const { data: res } = await innerRequest<AudioClassificationOutput>(payload, providerHelper, {
 		...options,
 		task: "audio-classification",
 	});

packages/inference/src/tasks/audio/audioToAudio.ts

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ export interface AudioToAudioOutput {
 export async function audioToAudio(args: AudioToAudioArgs, options?: Options): Promise<AudioToAudioOutput[]> {
 	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-to-audio");
 	const payload = preparePayload(args);
-	const { data: res } = await innerRequest<AudioToAudioOutput>(payload, {
+	const { data: res } = await innerRequest<AudioToAudioOutput>(payload, providerHelper, {
 		...options,
 		task: "audio-to-audio",
 	});

packages/inference/src/tasks/audio/automaticSpeechRecognition.ts

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ export async function automaticSpeechRecognition(
 ): Promise<AutomaticSpeechRecognitionOutput> {
 	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "automatic-speech-recognition");
 	const payload = await buildPayload(args);
-	const { data: res } = await innerRequest<AutomaticSpeechRecognitionOutput>(payload, {
+	const { data: res } = await innerRequest<AutomaticSpeechRecognitionOutput>(payload, providerHelper, {
 		...options,
 		task: "automatic-speech-recognition",
 	});

packages/inference/src/tasks/audio/textToSpeech.ts

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ interface OutputUrlTextToSpeechGeneration {
 export async function textToSpeech(args: TextToSpeechArgs, options?: Options): Promise<Blob> {
 	const provider = args.provider ?? "hf-inference";
 	const providerHelper = getProviderHelper(provider, "text-to-speech");
-	const { data: res } = await innerRequest<Blob | OutputUrlTextToSpeechGeneration>(args, {
+	const { data: res } = await innerRequest<Blob | OutputUrlTextToSpeechGeneration>(args, providerHelper, {
 		...options,
 		task: "text-to-speech",
 	});

packages/inference/src/tasks/custom/request.ts

Lines changed: 3 additions & 1 deletion
@@ -1,3 +1,4 @@
+import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { InferenceTask, Options, RequestArgs } from "../../types";
 import { innerRequest } from "../../utils/request";
 
@@ -15,6 +16,7 @@ export async function request<T>(
 	console.warn(
 		"The request method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
 	);
-	const result = await innerRequest<T>(args, options);
+	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", options?.task);
+	const result = await innerRequest<T>(args, providerHelper, options);
 	return result.data;
 }

packages/inference/src/tasks/custom/streamingRequest.ts

Lines changed: 4 additions & 1 deletion
@@ -1,5 +1,7 @@
+import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { InferenceTask, Options, RequestArgs } from "../../types";
 import { innerStreamingRequest } from "../../utils/request";
+
 /**
  * Primitive to make custom inference calls that expect server-sent events, and returns the response through a generator
  * @deprecated Use specific task functions instead. This function will be removed in a future version.
@@ -14,5 +16,6 @@ export async function* streamingRequest<T>(
 	console.warn(
 		"The streamingRequest method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
 	);
-	yield* innerStreamingRequest(args, options);
+	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", options?.task);
+	yield* innerStreamingRequest(args, providerHelper, options);
 }

packages/inference/src/tasks/cv/imageClassification.ts

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ export async function imageClassification(
 ): Promise<ImageClassificationOutput> {
 	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-classification");
 	const payload = preparePayload(args);
-	const { data: res } = await innerRequest<ImageClassificationOutput>(payload, {
+	const { data: res } = await innerRequest<ImageClassificationOutput>(payload, providerHelper, {
 		...options,
 		task: "image-classification",
 	});
