11import type { WidgetType } from "@huggingface/tasks" ;
2- import type { InferenceProvider , ModelId } from "../types" ;
32import { HF_HUB_URL } from "../config" ;
43import { HARDCODED_MODEL_INFERENCE_MAPPING } from "../providers/consts" ;
54import { EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS } from "../providers/hf-inference" ;
5+ import type { InferenceProvider , InferenceProviderOrPolicy , ModelId } from "../types" ;
66import { typedInclude } from "../utils/typedInclude" ;
77
/**
 * In-memory cache of the Hub's `inferenceProviderMapping` responses, keyed by model id.
 * Populated after the first successful fetch so repeated lookups for the same model
 * skip the network round-trip for the lifetime of the process.
 */
export const inferenceProviderMappingCache = new Map < ModelId , InferenceProviderMapping > ( ) ;
@@ -20,44 +20,62 @@ export interface InferenceProviderModelMapping {
2020 task : WidgetType ;
2121}
2222
23- export async function getInferenceProviderMapping (
24- params : {
25- accessToken ?: string ;
26- modelId : ModelId ;
27- provider : InferenceProvider ;
28- task : WidgetType ;
29- } ,
30- options : {
23+ export async function fetchInferenceProviderMappingForModel (
24+ modelId : ModelId ,
25+ accessToken ?: string ,
26+ options ?: {
3127 fetch ?: ( input : RequestInfo , init ?: RequestInit ) => Promise < Response > ;
3228 }
33- ) : Promise < InferenceProviderModelMapping | null > {
34- if ( HARDCODED_MODEL_INFERENCE_MAPPING [ params . provider ] [ params . modelId ] ) {
35- return HARDCODED_MODEL_INFERENCE_MAPPING [ params . provider ] [ params . modelId ] ;
36- }
29+ ) : Promise < InferenceProviderMapping > {
3730 let inferenceProviderMapping : InferenceProviderMapping | null ;
38- if ( inferenceProviderMappingCache . has ( params . modelId ) ) {
31+ if ( inferenceProviderMappingCache . has ( modelId ) ) {
3932 // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
40- inferenceProviderMapping = inferenceProviderMappingCache . get ( params . modelId ) ! ;
33+ inferenceProviderMapping = inferenceProviderMappingCache . get ( modelId ) ! ;
4134 } else {
4235 const resp = await ( options ?. fetch ?? fetch ) (
43- `${ HF_HUB_URL } /api/models/${ params . modelId } ?expand[]=inferenceProviderMapping` ,
36+ `${ HF_HUB_URL } /api/models/${ modelId } ?expand[]=inferenceProviderMapping` ,
4437 {
45- headers : params . accessToken ?. startsWith ( "hf_" ) ? { Authorization : `Bearer ${ params . accessToken } ` } : { } ,
38+ headers : accessToken ?. startsWith ( "hf_" ) ? { Authorization : `Bearer ${ accessToken } ` } : { } ,
4639 }
4740 ) ;
4841 if ( resp . status === 404 ) {
49- throw new Error ( `Model ${ params . modelId } does not exist` ) ;
42+ throw new Error ( `Model ${ modelId } does not exist` ) ;
5043 }
5144 inferenceProviderMapping = await resp
5245 . json ( )
5346 . then ( ( json ) => json . inferenceProviderMapping )
5447 . catch ( ( ) => null ) ;
48+
49+ if ( inferenceProviderMapping ) {
50+ inferenceProviderMappingCache . set ( modelId , inferenceProviderMapping ) ;
51+ }
5552 }
5653
5754 if ( ! inferenceProviderMapping ) {
58- throw new Error ( `We have not been able to find inference provider information for model ${ params . modelId } .` ) ;
55+ throw new Error ( `We have not been able to find inference provider information for model ${ modelId } .` ) ;
5956 }
57+ return inferenceProviderMapping ;
58+ }
6059
60+ export async function getInferenceProviderMapping (
61+ params : {
62+ accessToken ?: string ;
63+ modelId : ModelId ;
64+ provider : InferenceProvider ;
65+ task : WidgetType ;
66+ } ,
67+ options : {
68+ fetch ?: ( input : RequestInfo , init ?: RequestInit ) => Promise < Response > ;
69+ }
70+ ) : Promise < InferenceProviderModelMapping | null > {
71+ if ( HARDCODED_MODEL_INFERENCE_MAPPING [ params . provider ] [ params . modelId ] ) {
72+ return HARDCODED_MODEL_INFERENCE_MAPPING [ params . provider ] [ params . modelId ] ;
73+ }
74+ const inferenceProviderMapping = await fetchInferenceProviderMappingForModel (
75+ params . modelId ,
76+ params . accessToken ,
77+ options
78+ ) ;
6179 const providerMapping = inferenceProviderMapping [ params . provider ] ;
6280 if ( providerMapping ) {
6381 const equivalentTasks =
@@ -78,3 +96,23 @@ export async function getInferenceProviderMapping(
7896 }
7997 return null ;
8098}
99+
100+ export async function resolveProvider (
101+ provider ?: InferenceProviderOrPolicy ,
102+ modelId ?: string
103+ ) : Promise < InferenceProvider > {
104+ if ( ! provider ) {
105+ console . log (
106+ "Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers."
107+ ) ;
108+ provider = "auto" ;
109+ }
110+ if ( provider === "auto" ) {
111+ if ( ! modelId ) {
112+ throw new Error ( "Specifying a model is required when provider is 'auto'" ) ;
113+ }
114+ const inferenceProviderMapping = await fetchInferenceProviderMappingForModel ( modelId ) ;
115+ provider = Object . keys ( inferenceProviderMapping ) [ 0 ] as InferenceProvider ;
116+ }
117+ return provider ;
118+ }
0 commit comments