
Commit 50e8d8e

fix also in huggingface/inference

1 parent bb1b64e

4 files changed, +58 −28 lines

packages/inference/src/lib/getInferenceProviderMapping.ts

Lines changed: 50 additions & 20 deletions
```diff
@@ -6,19 +6,48 @@ import type { InferenceProvider, InferenceProviderOrPolicy, ModelId } from "../t
 import { typedInclude } from "../utils/typedInclude.js";
 import { InferenceClientHubApiError, InferenceClientInputError } from "../errors.js";
 
-export const inferenceProviderMappingCache = new Map<ModelId, InferenceProviderMapping>();
+export const inferenceProviderMappingCache = new Map<ModelId, InferenceProviderMappingEntry[]>();
 
-export type InferenceProviderMapping = Partial<
-	Record<InferenceProvider, Omit<InferenceProviderModelMapping, "hfModelId">>
->;
-
-export interface InferenceProviderModelMapping {
+export interface InferenceProviderMappingEntry {
 	adapter?: string;
 	adapterWeightsPath?: string;
 	hfModelId: ModelId;
+	provider: string;
 	providerId: string;
 	status: "live" | "staging";
 	task: WidgetType;
+	type?: "single-model" | "tag-filter";
+}
+
+/**
+ * Normalize inferenceProviderMapping to always return an array format.
+ * This provides backward and forward compatibility for the API changes.
+ *
+ * Vendored from @huggingface/hub to avoid extra dependency.
+ */
+function normalizeInferenceProviderMapping(
+	modelId: ModelId,
+	inferenceProviderMapping?:
+		| InferenceProviderMappingEntry[]
+		| Record<string, { providerId: string; status: "live" | "staging"; task: WidgetType }>
+): InferenceProviderMappingEntry[] {
+	if (!inferenceProviderMapping) {
+		return [];
+	}
+
+	// If it's already an array, return it as is
+	if (Array.isArray(inferenceProviderMapping)) {
+		return inferenceProviderMapping;
+	}
+
+	// Convert mapping to array format
+	return Object.entries(inferenceProviderMapping).map(([provider, mapping]) => ({
+		provider,
+		hfModelId: modelId,
+		providerId: mapping.providerId,
+		status: mapping.status,
+		task: mapping.task,
+	}));
 }
 
 export async function fetchInferenceProviderMappingForModel(
@@ -27,8 +56,8 @@ export async function fetchInferenceProviderMappingForModel(
 	options?: {
 		fetch?: (input: RequestInfo, init?: RequestInit) => Promise<Response>;
 	}
-): Promise<InferenceProviderMapping> {
-	let inferenceProviderMapping: InferenceProviderMapping | null;
+): Promise<InferenceProviderMappingEntry[]> {
+	let inferenceProviderMapping: InferenceProviderMappingEntry[] | null;
 	if (inferenceProviderMappingCache.has(modelId)) {
 		// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
 		inferenceProviderMapping = inferenceProviderMappingCache.get(modelId)!;
@@ -55,7 +84,11 @@ export async function fetchInferenceProviderMappingForModel(
 			);
 		}
-		let payload: { inferenceProviderMapping?: InferenceProviderMapping } | null = null;
+		let payload: {
+			inferenceProviderMapping?:
+				| InferenceProviderMappingEntry[]
+				| Record<string, { providerId: string; status: "live" | "staging"; task: WidgetType }>;
+		} | null = null;
 		try {
 			payload = await resp.json();
 		} catch {
@@ -72,7 +105,8 @@ export async function fetchInferenceProviderMappingForModel(
 				{ requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: await resp.text() }
 			);
 		}
-		inferenceProviderMapping = payload.inferenceProviderMapping;
+		inferenceProviderMapping = normalizeInferenceProviderMapping(modelId, payload.inferenceProviderMapping);
+		inferenceProviderMappingCache.set(modelId, inferenceProviderMapping);
 	}
 	return inferenceProviderMapping;
 }
@@ -87,16 +121,12 @@ export async function getInferenceProviderMapping(
 	options: {
 		fetch?: (input: RequestInfo, init?: RequestInit) => Promise<Response>;
 	}
-): Promise<InferenceProviderModelMapping | null> {
+): Promise<InferenceProviderMappingEntry | null> {
 	if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
 		return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
 	}
-	const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(
-		params.modelId,
-		params.accessToken,
-		options
-	);
-	const providerMapping = inferenceProviderMapping[params.provider];
+	const mappings = await fetchInferenceProviderMappingForModel(params.modelId, params.accessToken, options);
+	const providerMapping = mappings.find((mapping) => mapping.provider === params.provider);
 	if (providerMapping) {
 		const equivalentTasks =
 			params.provider === "hf-inference" && typedInclude(EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS, params.task)
@@ -112,7 +142,7 @@ export async function getInferenceProviderMapping(
 				`Model ${params.modelId} is in staging mode for provider ${params.provider}. Meant for test purposes only.`
 			);
 		}
-		return { ...providerMapping, hfModelId: params.modelId };
+		return providerMapping;
 	}
 	return null;
 }
@@ -139,8 +169,8 @@ export async function resolveProvider(
 		if (!modelId) {
 			throw new InferenceClientInputError("Specifying a model is required when provider is 'auto'");
 		}
-		const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(modelId);
-		provider = Object.keys(inferenceProviderMapping)[0] as InferenceProvider | undefined;
+		const mappings = await fetchInferenceProviderMappingForModel(modelId);
+		provider = mappings[0]?.provider as InferenceProvider | undefined;
 	}
 	if (!provider) {
 		throw new InferenceClientInputError(`No Inference Provider available for model ${modelId}.`);
```
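
The key move in this file is handling the dual-shape payload: the Hub may return the provider mapping either as the legacy record keyed by provider name or as the new array of entries, and `normalizeInferenceProviderMapping` folds both into `InferenceProviderMappingEntry[]`. A minimal standalone sketch of the record-to-array branch (the model and provider IDs below are invented for illustration; the real helper is module-private):

```ts
// Hypothetical legacy payload, keyed by provider name (values invented).
const modelId = "org/some-model";
const legacyPayload: Record<string, { providerId: string; status: "live" | "staging"; task: string }> = {
	"hf-inference": { providerId: "org/some-model", status: "live", task: "text-generation" },
};

// Record -> array conversion, mirroring the vendored helper above: the provider
// name moves from the record key into each entry, and the HF model ID is
// attached so every entry is self-describing.
const entries = Object.entries(legacyPayload).map(([provider, mapping]) => ({
	provider,
	hfModelId: modelId,
	...mapping,
}));
// entries: [{ provider: "hf-inference", hfModelId: "org/some-model",
//             providerId: "org/some-model", status: "live", task: "text-generation" }]
```

This also explains why `getInferenceProviderMapping` no longer re-attaches `hfModelId` on return: each entry now carries it already.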

packages/inference/src/providers/consts.ts

Lines changed: 2 additions & 2 deletions
```diff
@@ -1,4 +1,4 @@
-import type { InferenceProviderModelMapping } from "../lib/getInferenceProviderMapping.js";
+import type { InferenceProviderMappingEntry } from "../lib/getInferenceProviderMapping.js";
 import type { InferenceProvider } from "../types.js";
 import { type ModelId } from "../types.js";
 
@@ -11,7 +11,7 @@ import { type ModelId } from "../types.js";
  */
 export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
 	InferenceProvider,
-	Record<ModelId, InferenceProviderModelMapping>
+	Record<ModelId, InferenceProviderMappingEntry>
 > = {
 	/**
 	 * "HF model ID" => "Model ID on Inference Provider's side"
```

packages/inference/src/snippets/getInferenceSnippets.ts

Lines changed: 4 additions & 4 deletions
```diff
@@ -8,7 +8,7 @@ import {
 } from "@huggingface/tasks";
 import type { PipelineType, WidgetType } from "@huggingface/tasks";
 import type { ChatCompletionInputMessage, GenerationParameters } from "@huggingface/tasks";
-import type { InferenceProviderModelMapping } from "../lib/getInferenceProviderMapping.js";
+import type { InferenceProviderMappingEntry } from "../lib/getInferenceProviderMapping.js";
 import { getProviderHelper } from "../lib/getProviderHelper.js";
 import { makeRequestOptionsFromResolvedModel } from "../lib/makeRequestOptions.js";
 import type { InferenceProviderOrPolicy, InferenceTask, RequestArgs } from "../types.js";
@@ -131,7 +131,7 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
 	return (
 		model: ModelDataMinimal,
 		provider: InferenceProviderOrPolicy,
-		inferenceProviderMapping?: InferenceProviderModelMapping,
+		inferenceProviderMapping?: InferenceProviderMappingEntry,
 		opts?: InferenceSnippetOptions
 	): InferenceSnippet[] => {
 		const providerModelId = inferenceProviderMapping?.providerId ?? model.id;
@@ -311,7 +311,7 @@ const snippets: Partial<
 	(
 		model: ModelDataMinimal,
 		provider: InferenceProviderOrPolicy,
-		inferenceProviderMapping?: InferenceProviderModelMapping,
+		inferenceProviderMapping?: InferenceProviderMappingEntry,
 		opts?: InferenceSnippetOptions
 	) => InferenceSnippet[]
 >
@@ -350,7 +350,7 @@
 export function getInferenceSnippets(
 	model: ModelDataMinimal,
 	provider: InferenceProviderOrPolicy,
-	inferenceProviderMapping?: InferenceProviderModelMapping,
+	inferenceProviderMapping?: InferenceProviderMappingEntry,
 	opts?: Record<string, unknown>
 ): InferenceSnippet[] {
 	return model.pipeline_tag && model.pipeline_tag in snippets
```
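
The snippet-generator change is mechanical: these functions already took a single per-provider mapping object, so only the type name changes, and `providerId` is still read the same way. A hedged sketch of the shape a caller would now pass (values invented for illustration):

```ts
// Hypothetical entry passed to getInferenceSnippets; values are invented.
const entry = {
	provider: "some-provider",
	hfModelId: "org/some-model",
	providerId: "provider-model-id",
	status: "live" as const,
	task: "text-generation" as const,
};
// getInferenceSnippets(model, "some-provider", entry) resolves the provider-side
// model ID from entry.providerId, falling back to model.id when no entry is given.
```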

packages/inference/src/types.ts

Lines changed: 2 additions & 2 deletions
```diff
@@ -1,5 +1,5 @@
 import type { ChatCompletionInput, PipelineType } from "@huggingface/tasks";
-import type { InferenceProviderModelMapping } from "./lib/getInferenceProviderMapping.js";
+import type { InferenceProviderMappingEntry } from "./lib/getInferenceProviderMapping.js";
 
 /**
  * HF model id, like "meta-llama/Llama-3.3-70B-Instruct"
@@ -126,6 +126,6 @@ export interface UrlParams {
 export interface BodyParams<T extends Record<string, unknown> = Record<string, unknown>> {
 	args: T;
 	model: string;
-	mapping?: InferenceProviderModelMapping | undefined;
+	mapping?: InferenceProviderMappingEntry | undefined;
 	task?: InferenceTask;
 }
```
