Skip to content

Commit 66d7b0d

Browse files
Merge branch 'main' into v2ark/add_centml
2 parents d73d604 + eaa1b9c commit 66d7b0d

File tree

95 files changed

+216
-116
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

95 files changed

+216
-116
lines changed

packages/inference/src/lib/getInferenceProviderMapping.ts

Lines changed: 57 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import type { WidgetType } from "@huggingface/tasks";
2-
import type { InferenceProvider, ModelId } from "../types";
32
import { HF_HUB_URL } from "../config";
43
import { HARDCODED_MODEL_INFERENCE_MAPPING } from "../providers/consts";
54
import { EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS } from "../providers/hf-inference";
5+
import type { InferenceProvider, InferenceProviderOrPolicy, ModelId } from "../types";
66
import { typedInclude } from "../utils/typedInclude";
77

88
export const inferenceProviderMappingCache = new Map<ModelId, InferenceProviderMapping>();
@@ -20,44 +20,62 @@ export interface InferenceProviderModelMapping {
2020
task: WidgetType;
2121
}
2222

23-
export async function getInferenceProviderMapping(
24-
params: {
25-
accessToken?: string;
26-
modelId: ModelId;
27-
provider: InferenceProvider;
28-
task: WidgetType;
29-
},
30-
options: {
23+
export async function fetchInferenceProviderMappingForModel(
24+
modelId: ModelId,
25+
accessToken?: string,
26+
options?: {
3127
fetch?: (input: RequestInfo, init?: RequestInit) => Promise<Response>;
3228
}
33-
): Promise<InferenceProviderModelMapping | null> {
34-
if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
35-
return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
36-
}
29+
): Promise<InferenceProviderMapping> {
3730
let inferenceProviderMapping: InferenceProviderMapping | null;
38-
if (inferenceProviderMappingCache.has(params.modelId)) {
31+
if (inferenceProviderMappingCache.has(modelId)) {
3932
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
40-
inferenceProviderMapping = inferenceProviderMappingCache.get(params.modelId)!;
33+
inferenceProviderMapping = inferenceProviderMappingCache.get(modelId)!;
4134
} else {
4235
const resp = await (options?.fetch ?? fetch)(
43-
`${HF_HUB_URL}/api/models/${params.modelId}?expand[]=inferenceProviderMapping`,
36+
`${HF_HUB_URL}/api/models/${modelId}?expand[]=inferenceProviderMapping`,
4437
{
45-
headers: params.accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${params.accessToken}` } : {},
38+
headers: accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${accessToken}` } : {},
4639
}
4740
);
4841
if (resp.status === 404) {
49-
throw new Error(`Model ${params.modelId} does not exist`);
42+
throw new Error(`Model ${modelId} does not exist`);
5043
}
5144
inferenceProviderMapping = await resp
5245
.json()
5346
.then((json) => json.inferenceProviderMapping)
5447
.catch(() => null);
48+
49+
if (inferenceProviderMapping) {
50+
inferenceProviderMappingCache.set(modelId, inferenceProviderMapping);
51+
}
5552
}
5653

5754
if (!inferenceProviderMapping) {
58-
throw new Error(`We have not been able to find inference provider information for model ${params.modelId}.`);
55+
throw new Error(`We have not been able to find inference provider information for model ${modelId}.`);
5956
}
57+
return inferenceProviderMapping;
58+
}
6059

60+
export async function getInferenceProviderMapping(
61+
params: {
62+
accessToken?: string;
63+
modelId: ModelId;
64+
provider: InferenceProvider;
65+
task: WidgetType;
66+
},
67+
options: {
68+
fetch?: (input: RequestInfo, init?: RequestInit) => Promise<Response>;
69+
}
70+
): Promise<InferenceProviderModelMapping | null> {
71+
if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
72+
return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
73+
}
74+
const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(
75+
params.modelId,
76+
params.accessToken,
77+
options
78+
);
6179
const providerMapping = inferenceProviderMapping[params.provider];
6280
if (providerMapping) {
6381
const equivalentTasks =
@@ -78,3 +96,23 @@ export async function getInferenceProviderMapping(
7896
}
7997
return null;
8098
}
99+
100+
export async function resolveProvider(
101+
provider?: InferenceProviderOrPolicy,
102+
modelId?: string
103+
): Promise<InferenceProvider> {
104+
if (!provider) {
105+
console.log(
106+
"Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers."
107+
);
108+
provider = "auto";
109+
}
110+
if (provider === "auto") {
111+
if (!modelId) {
112+
throw new Error("Specifying a model is required when provider is 'auto'");
113+
}
114+
const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(modelId);
115+
provider = Object.keys(inferenceProviderMapping)[0] as InferenceProvider;
116+
}
117+
return provider;
118+
}

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ export async function makeRequestOptions(
2727
task?: InferenceTask;
2828
}
2929
): Promise<{ url: string; info: RequestInit }> {
30-
const { provider: maybeProvider, model: maybeModel } = args;
31-
const provider = maybeProvider ?? "hf-inference";
30+
const { model: maybeModel } = args;
31+
const provider = providerHelper.provider;
3232
const { task } = options ?? {};
3333

3434
// Validate inputs
@@ -113,8 +113,9 @@ export function makeRequestOptionsFromResolvedModel(
113113
): { url: string; info: RequestInit } {
114114
const { accessToken, endpointUrl, provider: maybeProvider, model, ...remainingArgs } = args;
115115
void model;
116+
void maybeProvider;
116117

117-
const provider = maybeProvider ?? "hf-inference";
118+
const provider = providerHelper.provider;
118119

119120
const { includeCredentials, task, signal, billTo } = options ?? {};
120121
const authMethod = (() => {

packages/inference/src/providers/providerHelper.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ import { toArray } from "../utils/toArray";
5656
*/
5757
export abstract class TaskProviderHelper {
5858
constructor(
59-
private provider: InferenceProvider,
59+
readonly provider: InferenceProvider,
6060
private baseUrl: string,
6161
readonly clientSideRoutingOnly: boolean = false
6262
) {}

packages/inference/src/snippets/getInferenceSnippets.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ const prepareConversationalInput = (
272272
return {
273273
messages: opts?.messages ?? getModelInputSnippet(model),
274274
...(opts?.temperature ? { temperature: opts?.temperature } : undefined),
275-
max_tokens: opts?.max_tokens ?? 512,
275+
...(opts?.max_tokens ? { max_tokens: opts?.max_tokens } : undefined),
276276
...(opts?.top_p ? { top_p: opts?.top_p } : undefined),
277277
};
278278
};

packages/inference/src/tasks/audio/audioClassification.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import type { AudioClassificationInput, AudioClassificationOutput } from "@huggingface/tasks";
2+
import { resolveProvider } from "../../lib/getInferenceProviderMapping";
23
import { getProviderHelper } from "../../lib/getProviderHelper";
34
import type { BaseArgs, Options } from "../../types";
45
import { innerRequest } from "../../utils/request";
@@ -15,7 +16,8 @@ export async function audioClassification(
1516
args: AudioClassificationArgs,
1617
options?: Options
1718
): Promise<AudioClassificationOutput> {
18-
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-classification");
19+
const provider = await resolveProvider(args.provider, args.model);
20+
const providerHelper = getProviderHelper(provider, "audio-classification");
1921
const payload = preparePayload(args);
2022
const { data: res } = await innerRequest<AudioClassificationOutput>(payload, providerHelper, {
2123
...options,

packages/inference/src/tasks/audio/audioToAudio.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import { resolveProvider } from "../../lib/getInferenceProviderMapping";
12
import { getProviderHelper } from "../../lib/getProviderHelper";
23
import type { BaseArgs, Options } from "../../types";
34
import { innerRequest } from "../../utils/request";
@@ -36,7 +37,9 @@ export interface AudioToAudioOutput {
3637
* Example model: speechbrain/sepformer-wham does audio source separation.
3738
*/
3839
export async function audioToAudio(args: AudioToAudioArgs, options?: Options): Promise<AudioToAudioOutput[]> {
39-
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-to-audio");
40+
const model = "inputs" in args ? args.model : undefined;
41+
const provider = await resolveProvider(args.provider, model);
42+
const providerHelper = getProviderHelper(provider, "audio-to-audio");
4043
const payload = preparePayload(args);
4144
const { data: res } = await innerRequest<AudioToAudioOutput>(payload, providerHelper, {
4245
...options,

packages/inference/src/tasks/audio/automaticSpeechRecognition.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import type { AutomaticSpeechRecognitionInput, AutomaticSpeechRecognitionOutput } from "@huggingface/tasks";
2+
import { resolveProvider } from "../../lib/getInferenceProviderMapping";
23
import { getProviderHelper } from "../../lib/getProviderHelper";
34
import { InferenceOutputError } from "../../lib/InferenceOutputError";
45
import { FAL_AI_SUPPORTED_BLOB_TYPES } from "../../providers/fal-ai";
@@ -18,7 +19,8 @@ export async function automaticSpeechRecognition(
1819
args: AutomaticSpeechRecognitionArgs,
1920
options?: Options
2021
): Promise<AutomaticSpeechRecognitionOutput> {
21-
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "automatic-speech-recognition");
22+
const provider = await resolveProvider(args.provider, args.model);
23+
const providerHelper = getProviderHelper(provider, "automatic-speech-recognition");
2224
const payload = await buildPayload(args);
2325
const { data: res } = await innerRequest<AutomaticSpeechRecognitionOutput>(payload, providerHelper, {
2426
...options,

packages/inference/src/tasks/audio/textToSpeech.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import type { TextToSpeechInput } from "@huggingface/tasks";
2+
import { resolveProvider } from "../../lib/getInferenceProviderMapping";
23
import { getProviderHelper } from "../../lib/getProviderHelper";
34
import type { BaseArgs, Options } from "../../types";
45
import { innerRequest } from "../../utils/request";
@@ -12,7 +13,7 @@ interface OutputUrlTextToSpeechGeneration {
1213
* Recommended model: espnet/kan-bayashi_ljspeech_vits
1314
*/
1415
export async function textToSpeech(args: TextToSpeechArgs, options?: Options): Promise<Blob> {
15-
const provider = args.provider ?? "hf-inference";
16+
const provider = await resolveProvider(args.provider, args.model);
1617
const providerHelper = getProviderHelper(provider, "text-to-speech");
1718
const { data: res } = await innerRequest<Blob | OutputUrlTextToSpeechGeneration>(args, providerHelper, {
1819
...options,

packages/inference/src/tasks/custom/request.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import { resolveProvider } from "../../lib/getInferenceProviderMapping";
12
import { getProviderHelper } from "../../lib/getProviderHelper";
23
import type { InferenceTask, Options, RequestArgs } from "../../types";
34
import { innerRequest } from "../../utils/request";
@@ -16,7 +17,8 @@ export async function request<T>(
1617
console.warn(
1718
"The request method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
1819
);
19-
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", options?.task);
20+
const provider = await resolveProvider(args.provider, args.model);
21+
const providerHelper = getProviderHelper(provider, options?.task);
2022
const result = await innerRequest<T>(args, providerHelper, options);
2123
return result.data;
2224
}

packages/inference/src/tasks/custom/streamingRequest.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import { resolveProvider } from "../../lib/getInferenceProviderMapping";
12
import { getProviderHelper } from "../../lib/getProviderHelper";
23
import type { InferenceTask, Options, RequestArgs } from "../../types";
34
import { innerStreamingRequest } from "../../utils/request";
@@ -16,6 +17,7 @@ export async function* streamingRequest<T>(
1617
console.warn(
1718
"The streamingRequest method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
1819
);
19-
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", options?.task);
20+
const provider = await resolveProvider(args.provider, args.model);
21+
const providerHelper = getProviderHelper(provider, options?.task);
2022
yield* innerStreamingRequest(args, providerHelper, options);
2123
}

0 commit comments

Comments
 (0)