Skip to content

Commit 95cc1f8

Browse files
committed
Merge remote-tracking branch 'origin/main' into pr/nbarr07/1260
2 parents bce9ec0 + 022585b commit 95cc1f8

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+168
-66
lines changed

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import { name as packageName, version as packageVersion } from "../../package.json";
22
import { HF_HEADER_X_BILL_TO, HF_HUB_URL } from "../config";
33
import type { InferenceTask, Options, RequestArgs } from "../types";
4-
import { getProviderHelper } from "./getProviderHelper";
4+
import type { getProviderHelper } from "./getProviderHelper";
55
import { getProviderModelId } from "./getProviderModelId";
66
import { isUrl } from "./isUrl";
77

@@ -20,6 +20,7 @@ export async function makeRequestOptions(
2020
data?: Blob | ArrayBuffer;
2121
stream?: boolean;
2222
},
23+
providerHelper: ReturnType<typeof getProviderHelper>,
2324
options?: Options & {
2425
/** In most cases (unless we pass a endpointUrl) we know the task */
2526
task?: InferenceTask;
@@ -28,20 +29,26 @@ export async function makeRequestOptions(
2829
const { provider: maybeProvider, model: maybeModel } = args;
2930
const provider = maybeProvider ?? "hf-inference";
3031
const { task } = options ?? {};
32+
3133
// Validate inputs
3234
if (args.endpointUrl && provider !== "hf-inference") {
3335
throw new Error(`Cannot use endpointUrl with a third-party provider.`);
3436
}
3537
if (maybeModel && isUrl(maybeModel)) {
3638
throw new Error(`Model URLs are no longer supported. Use endpointUrl instead.`);
3739
}
40+
41+
if (args.endpointUrl) {
42+
// No need to have maybeModel, or to load default model for a task
43+
return makeRequestOptionsFromResolvedModel(maybeModel ?? args.endpointUrl, providerHelper, args, options);
44+
}
45+
3846
if (!maybeModel && !task) {
3947
throw new Error("No model provided, and no task has been specified.");
4048
}
4149

4250
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
4351
const hfModel = maybeModel ?? (await loadDefaultModel(task!));
44-
const providerHelper = getProviderHelper(provider, task);
4552

4653
if (providerHelper.clientSideRoutingOnly && !maybeModel) {
4754
throw new Error(`Provider ${provider} requires a model ID to be passed directly.`);
@@ -56,7 +63,7 @@ export async function makeRequestOptions(
5663
});
5764

5865
// Use the sync version with the resolved model
59-
return makeRequestOptionsFromResolvedModel(resolvedModel, args, options);
66+
return makeRequestOptionsFromResolvedModel(resolvedModel, providerHelper, args, options);
6067
}
6168

6269
/**
@@ -65,6 +72,7 @@ export async function makeRequestOptions(
6572
*/
6673
export function makeRequestOptionsFromResolvedModel(
6774
resolvedModel: string,
75+
providerHelper: ReturnType<typeof getProviderHelper>,
6876
args: RequestArgs & {
6977
data?: Blob | ArrayBuffer;
7078
stream?: boolean;
@@ -79,7 +87,6 @@ export function makeRequestOptionsFromResolvedModel(
7987
const provider = maybeProvider ?? "hf-inference";
8088

8189
const { includeCredentials, task, signal, billTo } = options ?? {};
82-
const providerHelper = getProviderHelper(provider, task);
8390
const authMethod = (() => {
8491
if (providerHelper.clientSideRoutingOnly) {
8592
// Closed-source providers require an accessToken (cannot be routed).

packages/inference/src/snippets/getInferenceSnippets.ts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import type { ChatCompletionInputMessage, GenerationParameters } from "@huggingf
1111
import { makeRequestOptionsFromResolvedModel } from "../lib/makeRequestOptions";
1212
import type { InferenceProvider, InferenceTask, RequestArgs } from "../types";
1313
import { templates } from "./templates.exported";
14+
import { getProviderHelper } from "../lib/getProviderHelper";
1415

1516
const PYTHON_CLIENTS = ["huggingface_hub", "fal_client", "requests", "openai"] as const;
1617
const JS_CLIENTS = ["fetch", "huggingface.js", "openai"] as const;
@@ -130,10 +131,18 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
130131
inputPreparationFn = prepareConversationalInput;
131132
task = "conversational";
132133
}
134+
let providerHelper: ReturnType<typeof getProviderHelper>;
135+
try {
136+
providerHelper = getProviderHelper(provider, task);
137+
} catch (e) {
138+
console.error(`Failed to get provider helper for ${provider} (${task})`, e);
139+
return [];
140+
}
133141
/// Prepare inputs + make request
134142
const inputs = inputPreparationFn ? inputPreparationFn(model, opts) : { inputs: getModelInputSnippet(model) };
135143
const request = makeRequestOptionsFromResolvedModel(
136144
providerModelId ?? model.id,
145+
providerHelper,
137146
{
138147
accessToken: accessToken,
139148
provider: provider,

packages/inference/src/tasks/audio/audioClassification.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ export async function audioClassification(
1717
): Promise<AudioClassificationOutput> {
1818
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-classification");
1919
const payload = preparePayload(args);
20-
const { data: res } = await innerRequest<AudioClassificationOutput>(payload, {
20+
const { data: res } = await innerRequest<AudioClassificationOutput>(payload, providerHelper, {
2121
...options,
2222
task: "audio-classification",
2323
});

packages/inference/src/tasks/audio/audioToAudio.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ export interface AudioToAudioOutput {
3838
export async function audioToAudio(args: AudioToAudioArgs, options?: Options): Promise<AudioToAudioOutput[]> {
3939
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-to-audio");
4040
const payload = preparePayload(args);
41-
const { data: res } = await innerRequest<AudioToAudioOutput>(payload, {
41+
const { data: res } = await innerRequest<AudioToAudioOutput>(payload, providerHelper, {
4242
...options,
4343
task: "audio-to-audio",
4444
});

packages/inference/src/tasks/audio/automaticSpeechRecognition.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ export async function automaticSpeechRecognition(
2020
): Promise<AutomaticSpeechRecognitionOutput> {
2121
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "automatic-speech-recognition");
2222
const payload = await buildPayload(args);
23-
const { data: res } = await innerRequest<AutomaticSpeechRecognitionOutput>(payload, {
23+
const { data: res } = await innerRequest<AutomaticSpeechRecognitionOutput>(payload, providerHelper, {
2424
...options,
2525
task: "automatic-speech-recognition",
2626
});

packages/inference/src/tasks/audio/textToSpeech.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ interface OutputUrlTextToSpeechGeneration {
1414
export async function textToSpeech(args: TextToSpeechArgs, options?: Options): Promise<Blob> {
1515
const provider = args.provider ?? "hf-inference";
1616
const providerHelper = getProviderHelper(provider, "text-to-speech");
17-
const { data: res } = await innerRequest<Blob | OutputUrlTextToSpeechGeneration>(args, {
17+
const { data: res } = await innerRequest<Blob | OutputUrlTextToSpeechGeneration>(args, providerHelper, {
1818
...options,
1919
task: "text-to-speech",
2020
});

packages/inference/src/tasks/custom/request.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import { getProviderHelper } from "../../lib/getProviderHelper";
12
import type { InferenceTask, Options, RequestArgs } from "../../types";
23
import { innerRequest } from "../../utils/request";
34

@@ -15,6 +16,7 @@ export async function request<T>(
1516
console.warn(
1617
"The request method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
1718
);
18-
const result = await innerRequest<T>(args, options);
19+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", options?.task);
20+
const result = await innerRequest<T>(args, providerHelper, options);
1921
return result.data;
2022
}

packages/inference/src/tasks/custom/streamingRequest.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import { getProviderHelper } from "../../lib/getProviderHelper";
12
import type { InferenceTask, Options, RequestArgs } from "../../types";
23
import { innerStreamingRequest } from "../../utils/request";
4+
35
/**
46
* Primitive to make custom inference calls that expect server-sent events, and returns the response through a generator
57
* @deprecated Use specific task functions instead. This function will be removed in a future version.
@@ -14,5 +16,6 @@ export async function* streamingRequest<T>(
1416
console.warn(
1517
"The streamingRequest method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
1618
);
17-
yield* innerStreamingRequest(args, options);
19+
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", options?.task);
20+
yield* innerStreamingRequest(args, providerHelper, options);
1821
}

packages/inference/src/tasks/cv/imageClassification.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ export async function imageClassification(
1616
): Promise<ImageClassificationOutput> {
1717
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-classification");
1818
const payload = preparePayload(args);
19-
const { data: res } = await innerRequest<ImageClassificationOutput>(payload, {
19+
const { data: res } = await innerRequest<ImageClassificationOutput>(payload, providerHelper, {
2020
...options,
2121
task: "image-classification",
2222
});

packages/inference/src/tasks/cv/imageSegmentation.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ export async function imageSegmentation(
1616
): Promise<ImageSegmentationOutput> {
1717
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-segmentation");
1818
const payload = preparePayload(args);
19-
const { data: res } = await innerRequest<ImageSegmentationOutput>(payload, {
19+
const { data: res } = await innerRequest<ImageSegmentationOutput>(payload, providerHelper, {
2020
...options,
2121
task: "image-segmentation",
2222
});

0 commit comments

Comments
 (0)