Skip to content

Commit b788fdc

Browse files
committed
Merge branch 'main' into benank/main
2 parents f82e4b4 + 022585b commit b788fdc

36 files changed

+107
-65
lines changed

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import { name as packageName, version as packageVersion } from "../../package.json";
22
import { HF_HEADER_X_BILL_TO, HF_HUB_URL } from "../config";
33
import type { InferenceTask, Options, RequestArgs } from "../types";
4-
import { getProviderHelper } from "./getProviderHelper";
4+
import type { getProviderHelper } from "./getProviderHelper";
55
import { getProviderModelId } from "./getProviderModelId";
66
import { isUrl } from "./isUrl";
77

@@ -20,6 +20,7 @@ export async function makeRequestOptions(
2020
data?: Blob | ArrayBuffer;
2121
stream?: boolean;
2222
},
23+
providerHelper: ReturnType<typeof getProviderHelper>,
2324
options?: Options & {
2425
/** In most cases (unless we pass a endpointUrl) we know the task */
2526
task?: InferenceTask;
@@ -28,6 +29,7 @@ export async function makeRequestOptions(
2829
const { provider: maybeProvider, model: maybeModel } = args;
2930
const provider = maybeProvider ?? "hf-inference";
3031
const { task } = options ?? {};
32+
3133
// Validate inputs
3234
if (args.endpointUrl && provider !== "hf-inference") {
3335
throw new Error(`Cannot use endpointUrl with a third-party provider.`);
@@ -38,7 +40,7 @@ export async function makeRequestOptions(
3840

3941
if (args.endpointUrl) {
4042
// No need to have maybeModel, or to load default model for a task
41-
return makeRequestOptionsFromResolvedModel(maybeModel ?? args.endpointUrl, args, options);
43+
return makeRequestOptionsFromResolvedModel(maybeModel ?? args.endpointUrl, providerHelper, args, options);
4244
}
4345

4446
if (!maybeModel && !task) {
@@ -47,7 +49,6 @@ export async function makeRequestOptions(
4749

4850
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
4951
const hfModel = maybeModel ?? (await loadDefaultModel(task!));
50-
const providerHelper = getProviderHelper(provider, task);
5152

5253
if (providerHelper.clientSideRoutingOnly && !maybeModel) {
5354
throw new Error(`Provider ${provider} requires a model ID to be passed directly.`);
@@ -62,7 +63,7 @@ export async function makeRequestOptions(
6263
});
6364

6465
// Use the sync version with the resolved model
65-
return makeRequestOptionsFromResolvedModel(resolvedModel, args, options);
66+
return makeRequestOptionsFromResolvedModel(resolvedModel, providerHelper, args, options);
6667
}
6768

6869
/**
@@ -71,6 +72,7 @@ export async function makeRequestOptions(
7172
*/
7273
export function makeRequestOptionsFromResolvedModel(
7374
resolvedModel: string,
75+
providerHelper: ReturnType<typeof getProviderHelper>,
7476
args: RequestArgs & {
7577
data?: Blob | ArrayBuffer;
7678
stream?: boolean;
@@ -85,7 +87,6 @@ export function makeRequestOptionsFromResolvedModel(
8587
const provider = maybeProvider ?? "hf-inference";
8688

8789
const { includeCredentials, task, signal, billTo } = options ?? {};
88-
const providerHelper = getProviderHelper(provider, task);
8990
const authMethod = (() => {
9091
if (providerHelper.clientSideRoutingOnly) {
9192
// Closed-source providers require an accessToken (cannot be routed).

packages/inference/src/providers/hf-inference.ts

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -385,13 +385,13 @@ export class HFInferenceQuestionAnsweringTask extends HFInferenceTask implements
385385
typeof elem.end === "number" &&
386386
typeof elem.score === "number" &&
387387
typeof elem.start === "number"
388-
)
388+
)
389389
: typeof response === "object" &&
390-
!!response &&
391-
typeof response.answer === "string" &&
392-
typeof response.end === "number" &&
393-
typeof response.score === "number" &&
394-
typeof response.start === "number"
390+
!!response &&
391+
typeof response.answer === "string" &&
392+
typeof response.end === "number" &&
393+
typeof response.score === "number" &&
394+
typeof response.start === "number"
395395
) {
396396
return Array.isArray(response) ? response[0] : response;
397397
}

packages/inference/src/snippets/getInferenceSnippets.ts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import type { ChatCompletionInputMessage, GenerationParameters } from "@huggingf
1111
import { makeRequestOptionsFromResolvedModel } from "../lib/makeRequestOptions";
1212
import type { InferenceProvider, InferenceTask, RequestArgs } from "../types";
1313
import { templates } from "./templates.exported";
14+
import { getProviderHelper } from "../lib/getProviderHelper";
1415

1516
const PYTHON_CLIENTS = ["huggingface_hub", "fal_client", "requests", "openai"] as const;
1617
const JS_CLIENTS = ["fetch", "huggingface.js", "openai"] as const;
@@ -130,10 +131,18 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
130131
inputPreparationFn = prepareConversationalInput;
131132
task = "conversational";
132133
}
134+
let providerHelper: ReturnType<typeof getProviderHelper>;
135+
try {
136+
providerHelper = getProviderHelper(provider, task);
137+
} catch (e) {
138+
console.error(`Failed to get provider helper for ${provider} (${task})`, e);
139+
return [];
140+
}
133141
/// Prepare inputs + make request
134142
const inputs = inputPreparationFn ? inputPreparationFn(model, opts) : { inputs: getModelInputSnippet(model) };
135143
const request = makeRequestOptionsFromResolvedModel(
136144
providerModelId ?? model.id,
145+
providerHelper,
137146
{
138147
accessToken: accessToken,
139148
provider: provider,

packages/inference/src/tasks/audio/audioClassification.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ export async function audioClassification(
1717
): Promise<AudioClassificationOutput> {
1818
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-classification");
1919
const payload = preparePayload(args);
20-
const { data: res } = await innerRequest<AudioClassificationOutput>(payload, {
20+
const { data: res } = await innerRequest<AudioClassificationOutput>(payload, providerHelper, {
2121
...options,
2222
task: "audio-classification",
2323
});

packages/inference/src/tasks/audio/audioToAudio.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ export interface AudioToAudioOutput {
3838
export async function audioToAudio(args: AudioToAudioArgs, options?: Options): Promise<AudioToAudioOutput[]> {
3939
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-to-audio");
4040
const payload = preparePayload(args);
41-
const { data: res } = await innerRequest<AudioToAudioOutput>(payload, {
41+
const { data: res } = await innerRequest<AudioToAudioOutput>(payload, providerHelper, {
4242
...options,
4343
task: "audio-to-audio",
4444
});

packages/inference/src/tasks/audio/automaticSpeechRecognition.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ export async function automaticSpeechRecognition(
2020
): Promise<AutomaticSpeechRecognitionOutput> {
2121
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "automatic-speech-recognition");
2222
const payload = await buildPayload(args);
23-
const { data: res } = await innerRequest<AutomaticSpeechRecognitionOutput>(payload, {
23+
const { data: res } = await innerRequest<AutomaticSpeechRecognitionOutput>(payload, providerHelper, {
2424
...options,
2525
task: "automatic-speech-recognition",
2626
});

packages/inference/src/tasks/audio/textToSpeech.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ interface OutputUrlTextToSpeechGeneration {
1414
export async function textToSpeech(args: TextToSpeechArgs, options?: Options): Promise<Blob> {
1515
const provider = args.provider ?? "hf-inference";
1616
const providerHelper = getProviderHelper(provider, "text-to-speech");
17-
const { data: res } = await innerRequest<Blob | OutputUrlTextToSpeechGeneration>(args, {
17+
const { data: res } = await innerRequest<Blob | OutputUrlTextToSpeechGeneration>(args, providerHelper, {
1818
...options,
1919
task: "text-to-speech",
2020
});

packages/inference/src/tasks/custom/request.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import { getProviderHelper } from "../../lib/getProviderHelper";
12
import type { InferenceTask, Options, RequestArgs } from "../../types";
23
import { innerRequest } from "../../utils/request";
34

@@ -15,6 +16,7 @@ export async function request<T>(
1516
console.warn(
1617
"The request method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
1718
);
18-
const result = await innerRequest<T>(args, options);
19+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", options?.task);
20+
const result = await innerRequest<T>(args, providerHelper, options);
1921
return result.data;
2022
}

packages/inference/src/tasks/custom/streamingRequest.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import { getProviderHelper } from "../../lib/getProviderHelper";
12
import type { InferenceTask, Options, RequestArgs } from "../../types";
23
import { innerStreamingRequest } from "../../utils/request";
4+
35
/**
46
* Primitive to make custom inference calls that expect server-sent events, and returns the response through a generator
57
* @deprecated Use specific task functions instead. This function will be removed in a future version.
@@ -14,5 +16,6 @@ export async function* streamingRequest<T>(
1416
console.warn(
1517
"The streamingRequest method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
1618
);
17-
yield* innerStreamingRequest(args, options);
19+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", options?.task);
20+
yield* innerStreamingRequest(args, providerHelper, options);
1821
}

packages/inference/src/tasks/cv/imageClassification.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ export async function imageClassification(
1616
): Promise<ImageClassificationOutput> {
1717
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-classification");
1818
const payload = preparePayload(args);
19-
const { data: res } = await innerRequest<ImageClassificationOutput>(payload, {
19+
const { data: res } = await innerRequest<ImageClassificationOutput>(payload, providerHelper, {
2020
...options,
2121
task: "image-classification",
2222
});

0 commit comments

Comments
 (0)