@@ -4,6 +4,7 @@ import { HARDCODED_MODEL_INFERENCE_MAPPING } from "../providers/consts.js";
 import { EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS } from "../providers/hf-inference.js";
 import type { InferenceProvider, InferenceProviderOrPolicy, ModelId } from "../types.js";
 import { typedInclude } from "../utils/typedInclude.js";
+import { InferenceClientHubApiError, InferenceClientInputError } from "../errors.js";
 
 export const inferenceProviderMappingCache = new Map<ModelId, InferenceProviderMapping>();
 
@@ -32,27 +33,46 @@ export async function fetchInferenceProviderMappingForModel(
 		// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
 		inferenceProviderMapping = inferenceProviderMappingCache.get(modelId)!;
 	} else {
-		const resp = await (options?.fetch ?? fetch)(
-			`${HF_HUB_URL}/api/models/${modelId}?expand[]=inferenceProviderMapping`,
-			{
-				headers: accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${accessToken}` } : {},
+		const url = `${HF_HUB_URL}/api/models/${modelId}?expand[]=inferenceProviderMapping`;
+		const resp = await (options?.fetch ?? fetch)(url, {
+			headers: accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${accessToken}` } : {},
+		});
+		if (!resp.ok) {
+			if (resp.headers.get("Content-Type")?.startsWith("application/json")) {
+				const error = await resp.json();
+				if ("error" in error && typeof error.error === "string") {
+					throw new InferenceClientHubApiError(
+						`Failed to fetch inference provider mapping for model ${modelId}: ${error.error}`,
+						{ url, method: "GET" },
+						{ requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: error }
+					);
+				}
+			} else {
+				throw new InferenceClientHubApiError(
+					`Failed to fetch inference provider mapping for model ${modelId}`,
+					{ url, method: "GET" },
+					{ requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: await resp.text() }
+				);
 			}
-		);
-		if (resp.status === 404) {
-			throw new Error(`Model ${modelId} does not exist`);
 		}
-		inferenceProviderMapping = await resp
-			.json()
-			.then((json) => json.inferenceProviderMapping)
-			.catch(() => null);
-
-		if (inferenceProviderMapping) {
-			inferenceProviderMappingCache.set(modelId, inferenceProviderMapping);
+		let payload: { inferenceProviderMapping?: InferenceProviderMapping } | null = null;
+		try {
+			payload = await resp.json();
+		} catch {
+			throw new InferenceClientHubApiError(
+				`Failed to fetch inference provider mapping for model ${modelId}: malformed API response, invalid JSON`,
+				{ url, method: "GET" },
+				{ requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: await resp.text() }
+			);
 		}
-	}
-
-	if (!inferenceProviderMapping) {
-		throw new Error(`We have not been able to find inference provider information for model ${modelId}.`);
+		if (!payload?.inferenceProviderMapping) {
+			throw new InferenceClientHubApiError(
+				`We have not been able to find inference provider information for model ${modelId}.`,
+				{ url, method: "GET" },
+				{ requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: await resp.text() }
+			);
+		}
+		inferenceProviderMapping = payload.inferenceProviderMapping;
 	}
 	return inferenceProviderMapping;
 }
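
For context, here is how the new error surface looks from the caller's side. A minimal sketch, assuming `fetchInferenceProviderMappingForModel` and the error classes are re-exported from the package root (otherwise any task method that resolves a provider mapping hits the same paths); the model id is illustrative:

```ts
import { fetchInferenceProviderMappingForModel, InferenceClientHubApiError } from "@huggingface/inference";

try {
	const mapping = await fetchInferenceProviderMappingForModel("meta-llama/Llama-3.1-8B-Instruct");
	console.log(Object.keys(mapping)); // providers serving this model
} catch (error) {
	if (error instanceof InferenceClientHubApiError) {
		// Raised on any non-ok Hub response, malformed JSON, or a missing
		// mapping; the constructor calls above also record the request URL,
		// method, x-request-id, status, and body for debugging.
		console.error("Hub API error:", error.message);
	} else {
		throw error;
	}
}
```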
@@ -83,7 +103,7 @@ export async function getInferenceProviderMapping(
 				? EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS
 				: [params.task];
 		if (!typedInclude(equivalentTasks, providerMapping.task)) {
-			throw new Error(
+			throw new InferenceClientInputError(
 				`Model ${params.modelId} is not supported for task ${params.task} and provider ${params.provider}. Supported task: ${providerMapping.task}.`
 			);
 		}
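
This hunk only changes the error type; the equivalence list lets a request for either embedding task match a sentence-transformers mapping. A simplified sketch of the check, assuming `EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS` is `["feature-extraction", "sentence-similarity"]` as defined in hf-inference.js, and leaving out the provider condition:

```ts
const EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS = ["feature-extraction", "sentence-similarity"] as const;

// The two sentence-transformers tasks are interchangeable; any other
// requested/mapped pair must match exactly.
function taskMatches(requestedTask: string, mappedTask: string): boolean {
	const equivalentTasks: readonly string[] = (EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS as readonly string[]).includes(
		requestedTask
	)
		? EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS
		: [requestedTask];
	return equivalentTasks.includes(mappedTask);
}

console.log(taskMatches("feature-extraction", "sentence-similarity")); // true
console.log(taskMatches("text-generation", "sentence-similarity")); // false -> InferenceClientInputError
```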
@@ -104,7 +124,7 @@ export async function resolveProvider(
 ): Promise<InferenceProvider> {
 	if (endpointUrl) {
 		if (provider) {
-			throw new Error("Specifying both endpointUrl and provider is not supported.");
+			throw new InferenceClientInputError("Specifying both endpointUrl and provider is not supported.");
 		}
 		/// Defaulting to hf-inference helpers / API
 		return "hf-inference";
@@ -117,13 +137,13 @@
 	}
 	if (provider === "auto") {
 		if (!modelId) {
-			throw new Error("Specifying a model is required when provider is 'auto'");
+			throw new InferenceClientInputError("Specifying a model is required when provider is 'auto'");
 		}
 		const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(modelId);
 		provider = Object.keys(inferenceProviderMapping)[0] as InferenceProvider | undefined;
 	}
 	if (!provider) {
-		throw new Error(`No Inference Provider available for model ${modelId}.`);
+		throw new InferenceClientInputError(`No Inference Provider available for model ${modelId}.`);
 	}
 	return provider;
 }
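
Finally, the `resolveProvider` behaviour these hunks touch, seen from user code. A sketch assuming the current `InferenceClient` API; model and message contents are illustrative:

```ts
import { InferenceClient, InferenceClientInputError } from "@huggingface/inference";

const client = new InferenceClient(process.env.HF_TOKEN);

try {
	// provider: "auto" resolves to the first provider listed in the model's
	// inferenceProviderMapping on the Hub.
	const out = await client.chatCompletion({
		model: "meta-llama/Llama-3.1-8B-Instruct",
		provider: "auto",
		messages: [{ role: "user", content: "Hello!" }],
	});
	console.log(out.choices[0].message);
} catch (error) {
	if (error instanceof InferenceClientInputError) {
		// Now raised (instead of a bare Error) for caller mistakes:
		// endpointUrl combined with provider, "auto" without a model,
		// or a model with no provider mapping at all.
		console.error(error.message);
	} else {
		throw error;
	}
}
```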