@@ -6,19 +6,48 @@ import type { InferenceProvider, InferenceProviderOrPolicy, ModelId } from "../t
66import { typedInclude } from "../utils/typedInclude.js" ;
77import { InferenceClientHubApiError , InferenceClientInputError } from "../errors.js" ;
88
9- export const inferenceProviderMappingCache = new Map < ModelId , InferenceProviderMapping > ( ) ;
9+ export const inferenceProviderMappingCache = new Map < ModelId , InferenceProviderMappingEntry [ ] > ( ) ;
1010
11- export type InferenceProviderMapping = Partial <
12- Record < InferenceProvider , Omit < InferenceProviderModelMapping , "hfModelId" > >
13- > ;
14-
15- export interface InferenceProviderModelMapping {
11+ export interface InferenceProviderMappingEntry {
1612 adapter ?: string ;
1713 adapterWeightsPath ?: string ;
1814 hfModelId : ModelId ;
15+ provider : string ;
1916 providerId : string ;
2017 status : "live" | "staging" ;
2118 task : WidgetType ;
19+ type ?: "single-model" | "tag-filter" ;
20+ }
21+
22+ /**
23+ * Normalize inferenceProviderMapping to always return an array format.
24+ * This provides backward and forward compatibility for the API changes.
25+ *
26+ * Vendored from @huggingface/hub to avoid extra dependency.
27+ */
28+ function normalizeInferenceProviderMapping (
29+ modelId : ModelId ,
30+ inferenceProviderMapping ?:
31+ | InferenceProviderMappingEntry [ ]
32+ | Record < string , { providerId : string ; status : "live" | "staging" ; task : WidgetType } >
33+ ) : InferenceProviderMappingEntry [ ] {
34+ if ( ! inferenceProviderMapping ) {
35+ return [ ] ;
36+ }
37+
38+ // If it's already an array, return it as is
39+ if ( Array . isArray ( inferenceProviderMapping ) ) {
40+ return inferenceProviderMapping ;
41+ }
42+
43+ // Convert mapping to array format
44+ return Object . entries ( inferenceProviderMapping ) . map ( ( [ provider , mapping ] ) => ( {
45+ provider,
46+ hfModelId : modelId ,
47+ providerId : mapping . providerId ,
48+ status : mapping . status ,
49+ task : mapping . task ,
50+ } ) ) ;
2251}
2352
// Fetches the inference-provider mapping for a model, serving it from
// `inferenceProviderMappingCache` when present.
// NOTE(review): this is a partial diff view — context lines are elided at the
// `@@` markers, so the full body is not visible here.
2453export async function fetchInferenceProviderMappingForModel (
@@ -27,8 +56,8 @@ export async function fetchInferenceProviderMappingForModel(
2756 options ?: {
2857 fetch ?: ( input : RequestInfo , init ?: RequestInit ) => Promise < Response > ;
2958 }
// Return type migrates from the provider-keyed record to the new entry-array format.
30- ) : Promise < InferenceProviderMapping > {
31- let inferenceProviderMapping : InferenceProviderMapping | null ;
59+ ) : Promise < InferenceProviderMappingEntry [ ] > {
60+ let inferenceProviderMapping : InferenceProviderMappingEntry [ ] | null ;
3261 if ( inferenceProviderMappingCache . has ( modelId ) ) {
3362 // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
3463 inferenceProviderMapping = inferenceProviderMappingCache . get ( modelId ) ! ;
@@ -55,7 +84,11 @@ export async function fetchInferenceProviderMappingForModel(
5584 ) ;
5685 }
5786 }
// The payload type now accepts both the legacy record shape and the new array
// shape, so either API response parses.
58- let payload : { inferenceProviderMapping ?: InferenceProviderMapping } | null = null ;
87+ let payload : {
88+ inferenceProviderMapping ?:
89+ | InferenceProviderMappingEntry [ ]
90+ | Record < string , { providerId : string ; status : "live" | "staging" ; task : WidgetType } > ;
91+ } | null = null ;
5992 try {
6093 payload = await resp . json ( ) ;
6194 } catch {
@@ -72,7 +105,8 @@ export async function fetchInferenceProviderMappingForModel(
72105 { requestId : resp . headers . get ( "x-request-id" ) ?? "" , status : resp . status , body : await resp . text ( ) }
73106 ) ;
74107 }
// Normalize whichever shape arrived, then store it in the cache — the old code
// read the cache above but never wrote fetched results back into it.
75- inferenceProviderMapping = payload . inferenceProviderMapping ;
108+ inferenceProviderMapping = normalizeInferenceProviderMapping ( modelId , payload . inferenceProviderMapping ) ;
109+ inferenceProviderMappingCache . set ( modelId , inferenceProviderMapping ) ;
76110 }
77111 return inferenceProviderMapping ;
78112}
// Resolves the mapping entry for a specific provider/model pair, honoring the
// hardcoded override table first, and returns null when the provider does not
// serve the model. Partial diff view — context elided at the `@@` markers.
@@ -87,16 +121,12 @@ export async function getInferenceProviderMapping(
87121 options : {
88122 fetch ?: ( input : RequestInfo , init ?: RequestInit ) => Promise < Response > ;
89123 }
// Return type switches from the record's value type to the full entry type.
90- ) : Promise < InferenceProviderModelMapping | null > {
124+ ) : Promise < InferenceProviderMappingEntry | null > {
91125 if ( HARDCODED_MODEL_INFERENCE_MAPPING [ params . provider ] [ params . modelId ] ) {
92126 return HARDCODED_MODEL_INFERENCE_MAPPING [ params . provider ] [ params . modelId ] ;
93127 }
// Record lookup by provider key becomes a linear `find` over the entry array.
94- const inferenceProviderMapping = await fetchInferenceProviderMappingForModel (
95- params . modelId ,
96- params . accessToken ,
97- options
98- ) ;
99- const providerMapping = inferenceProviderMapping [ params . provider ] ;
128+ const mappings = await fetchInferenceProviderMappingForModel ( params . modelId , params . accessToken , options ) ;
129+ const providerMapping = mappings . find ( ( mapping ) => mapping . provider === params . provider ) ;
100130 if ( providerMapping ) {
101131 const equivalentTasks =
102132 params . provider === "hf-inference" && typedInclude ( EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS , params . task )
@@ -112,7 +142,7 @@ export async function getInferenceProviderMapping(
112142 `Model ${ params . modelId } is in staging mode for provider ${ params . provider } . Meant for test purposes only.`
113143 ) ;
114144 }
// Entries already carry `hfModelId`, so the old spread-and-augment is unnecessary.
115- return { ... providerMapping , hfModelId : params . modelId } ;
145+ return providerMapping ;
116146 }
117147 return null ;
118148}
// When the provider is "auto", picks the first mapping entry for the model.
// Partial diff view — only this hunk of the function body is visible.
@@ -139,8 +169,8 @@ export async function resolveProvider(
139169 if ( ! modelId ) {
140170 throw new InferenceClientInputError ( "Specifying a model is required when provider is 'auto'" ) ;
141171 }
// First array entry replaces Object.keys(...)[0]; `?.` keeps the old
// behavior of yielding undefined when the model has no mappings.
142- const inferenceProviderMapping = await fetchInferenceProviderMappingForModel ( modelId ) ;
143- provider = Object . keys ( inferenceProviderMapping ) [ 0 ] as InferenceProvider | undefined ;
172+ const mappings = await fetchInferenceProviderMappingForModel ( modelId ) ;
173+ provider = mappings [ 0 ] ?. provider as InferenceProvider | undefined ;
144174 }
145175 if ( ! provider ) {
146176 throw new InferenceClientInputError ( `No Inference Provider available for model ${ modelId } .` ) ;
0 commit comments