Skip to content

Commit c3fc83d

Browse files
committed
narrower error types
1 parent 615c348 commit c3fc83d

File tree

8 files changed

+397
-69
lines changed

8 files changed

+397
-69
lines changed

packages/inference/src/error.ts

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import type { JsonObject } from "./vendor/type-fest/basic.js";
2+
3+
/**
4+
* Base class for all inference-related errors.
5+
*/
6+
export abstract class HfInferenceError extends Error {
7+
constructor(message: string) {
8+
super(message);
9+
this.name = "HfInferenceError";
10+
}
11+
}
12+
13+
export class HfInferenceInputError extends HfInferenceError {
14+
constructor(message: string) {
15+
super(message);
16+
this.name = "InputError";
17+
}
18+
}
19+
20+
interface HttpRequest {
21+
url: string;
22+
method: string;
23+
headers?: Record<string, string>;
24+
body?: JsonObject;
25+
}
26+
27+
interface HttpResponse {
28+
requestId: string;
29+
status: number;
30+
body: JsonObject | string;
31+
}
32+
33+
abstract class HfInferenceHttpRequestError extends HfInferenceError {
34+
httpRequest: HttpRequest;
35+
httpResponse: HttpResponse;
36+
constructor(message: string, httpRequest: HttpRequest, httpResponse: HttpResponse) {
37+
super(message);
38+
this.httpRequest = {
39+
...httpRequest,
40+
...(httpRequest.headers ? {
41+
headers: {
42+
...httpRequest.headers,
43+
...("Authorization" in httpRequest.headers ? { Authorization: `Bearer [redacted]` } : undefined),
44+
/// redact authentication in the request headers
45+
},
46+
} : undefined)
47+
};
48+
this.httpResponse = httpResponse;
49+
}
50+
}
51+
52+
/**
53+
* Thrown when the HTTP request to the provider fails, e.g. due to API issues or server errors.
54+
*/
55+
export class HfInferenceProviderApiError extends HfInferenceHttpRequestError {
56+
constructor(message: string, httpRequest: HttpRequest, httpResponse: HttpResponse) {
57+
super(message, httpRequest, httpResponse);
58+
this.name = "ProviderApiError";
59+
}
60+
}
61+
62+
/**
63+
* Thrown when the HTTP request to the hub fails, e.g. due to API issues or server errors.
64+
*/
65+
export class HfInferenceHubApiError extends HfInferenceHttpRequestError {
66+
constructor(message: string, httpRequest: HttpRequest, httpResponse: HttpResponse) {
67+
super(message, httpRequest, httpResponse);
68+
this.name = "HubApiError";
69+
}
70+
}
71+
72+
/**
73+
* Thrown when the inference output returned by the provider is invalid / does not match the expectations
74+
*/
75+
export class HfInferenceProviderOutputError extends HfInferenceError {
76+
httpRequest: HttpRequest;
77+
httpResponse: HttpResponse;
78+
error: string | JsonObject;
79+
constructor(message: string, httpRequest: { url: string; method: string; headers: Record<string, string>; body: JsonObject }, httpResponse: { requestId: string; status: number; headers: Record<string, string>; body: JsonObject | string }, error: string | JsonObject) {
80+
super(message);
81+
this.name = "ProviderOutputError";
82+
this.httpRequest = httpRequest;
83+
this.httpResponse = httpResponse;
84+
this.error = error;
85+
}
86+
}
87+

packages/inference/src/lib/getInferenceProviderMapping.ts

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import { HARDCODED_MODEL_INFERENCE_MAPPING } from "../providers/consts.js";
44
import { EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS } from "../providers/hf-inference.js";
55
import type { InferenceProvider, InferenceProviderOrPolicy, ModelId } from "../types.js";
66
import { typedInclude } from "../utils/typedInclude.js";
7+
import { HfInferenceHubApiError, HfInferenceInputError } from "../error.js";
78

89
export const inferenceProviderMappingCache = new Map<ModelId, InferenceProviderMapping>();
910

@@ -32,27 +33,49 @@ export async function fetchInferenceProviderMappingForModel(
3233
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
3334
inferenceProviderMapping = inferenceProviderMappingCache.get(modelId)!;
3435
} else {
36+
const url = `${HF_HUB_URL}/api/models/${modelId}?expand[]=inferenceProviderMapping`;
3537
const resp = await (options?.fetch ?? fetch)(
36-
`${HF_HUB_URL}/api/models/${modelId}?expand[]=inferenceProviderMapping`,
38+
url,
3739
{
3840
headers: accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${accessToken}` } : {},
3941
}
4042
);
41-
if (resp.status === 404) {
42-
throw new Error(`Model ${modelId} does not exist`);
43+
if (!resp.ok) {
44+
if (resp.headers.get("Content-Type")?.startsWith("application/json")) {
45+
const error = await resp.json();
46+
if ("error" in error && typeof error.error === "string") {
47+
throw new HfInferenceHubApiError(
48+
`Failed to fetch inference provider mapping for model ${modelId}: ${error.error}`,
49+
{ url, method: "GET" },
50+
{ requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: error }
51+
);
52+
}
53+
} else {
54+
throw new HfInferenceHubApiError(
55+
`Failed to fetch inference provider mapping for model ${modelId}`,
56+
{ url, method: "GET" },
57+
{ requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: await resp.text() }
58+
);
59+
}
4360
}
44-
inferenceProviderMapping = await resp
45-
.json()
46-
.then((json) => json.inferenceProviderMapping)
47-
.catch(() => null);
48-
49-
if (inferenceProviderMapping) {
50-
inferenceProviderMappingCache.set(modelId, inferenceProviderMapping);
61+
let payload: { inferenceProviderMapping?: InferenceProviderMapping } | null = null;
62+
try {
63+
payload = await resp.json();
64+
} catch {
65+
throw new HfInferenceHubApiError(
66+
`Failed to fetch inference provider mapping for model ${modelId}: malformed API response, invalid JSON`,
67+
{ url, method: "GET" },
68+
{ requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: await resp.text() }
69+
);
5170
}
52-
}
53-
54-
if (!inferenceProviderMapping) {
55-
throw new Error(`We have not been able to find inference provider information for model ${modelId}.`);
71+
if (!payload?.inferenceProviderMapping) {
72+
throw new HfInferenceHubApiError(
73+
`We have not been able to find inference provider information for model ${modelId}.`,
74+
{ url, method: "GET" },
75+
{ requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: await resp.text() }
76+
);
77+
}
78+
inferenceProviderMapping = payload.inferenceProviderMapping;
5679
}
5780
return inferenceProviderMapping;
5881
}
@@ -83,7 +106,7 @@ export async function getInferenceProviderMapping(
83106
? EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS
84107
: [params.task];
85108
if (!typedInclude(equivalentTasks, providerMapping.task)) {
86-
throw new Error(
109+
throw new HfInferenceInputError(
87110
`Model ${params.modelId} is not supported for task ${params.task} and provider ${params.provider}. Supported task: ${providerMapping.task}.`
88111
);
89112
}
@@ -104,7 +127,7 @@ export async function resolveProvider(
104127
): Promise<InferenceProvider> {
105128
if (endpointUrl) {
106129
if (provider) {
107-
throw new Error("Specifying both endpointUrl and provider is not supported.");
130+
throw new HfInferenceInputError("Specifying both endpointUrl and provider is not supported.");
108131
}
109132
/// Defaulting to hf-inference helpers / API
110133
return "hf-inference";
@@ -117,13 +140,13 @@ export async function resolveProvider(
117140
}
118141
if (provider === "auto") {
119142
if (!modelId) {
120-
throw new Error("Specifying a model is required when provider is 'auto'");
143+
throw new HfInferenceInputError("Specifying a model is required when provider is 'auto'");
121144
}
122145
const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(modelId);
123146
provider = Object.keys(inferenceProviderMapping)[0] as InferenceProvider | undefined;
124147
}
125148
if (!provider) {
126-
throw new Error(`No Inference Provider available for model ${modelId}.`);
149+
throw new HfInferenceInputError(`No Inference Provider available for model ${modelId}.`);
127150
}
128151
return provider;
129152
}

packages/inference/src/lib/getProviderHelper.ts

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ import * as Replicate from "../providers/replicate.js";
4848
import * as Sambanova from "../providers/sambanova.js";
4949
import * as Together from "../providers/together.js";
5050
import type { InferenceProvider, InferenceProviderOrPolicy, InferenceTask } from "../types.js";
51+
import { HfInferenceInputError } from "../error.js";
5152

5253
export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask, TaskProviderHelper>>> = {
5354
"black-forest-labs": {
@@ -281,14 +282,14 @@ export function getProviderHelper(
281282
return new HFInference.HFInferenceTask();
282283
}
283284
if (!task) {
284-
throw new Error("you need to provide a task name when using an external provider, e.g. 'text-to-image'");
285+
throw new HfInferenceInputError("you need to provide a task name when using an external provider, e.g. 'text-to-image'");
285286
}
286287
if (!(provider in PROVIDERS)) {
287-
throw new Error(`Provider '${provider}' not supported. Available providers: ${Object.keys(PROVIDERS)}`);
288+
throw new HfInferenceInputError(`Provider '${provider}' not supported. Available providers: ${Object.keys(PROVIDERS)}`);
288289
}
289290
const providerTasks = PROVIDERS[provider];
290291
if (!providerTasks || !(task in providerTasks)) {
291-
throw new Error(
292+
throw new HfInferenceInputError(
292293
`Task '${task}' not supported for provider '${provider}'. Available tasks: ${Object.keys(providerTasks ?? {})}`
293294
);
294295
}

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import type { InferenceProviderModelMapping } from "./getInferenceProviderMappin
55
import { getInferenceProviderMapping } from "./getInferenceProviderMapping.js";
66
import type { getProviderHelper } from "./getProviderHelper.js";
77
import { isUrl } from "./isUrl.js";
8+
import { HfInferenceHubApiError, HfInferenceInputError } from "../error.js";
89

910
/**
1011
* Lazy-loaded from huggingface.co/api/tasks when needed
@@ -33,10 +34,10 @@ export async function makeRequestOptions(
3334

3435
// Validate inputs
3536
if (args.endpointUrl && provider !== "hf-inference") {
36-
throw new Error(`Cannot use endpointUrl with a third-party provider.`);
37+
throw new HfInferenceInputError(`Cannot use endpointUrl with a third-party provider.`);
3738
}
3839
if (maybeModel && isUrl(maybeModel)) {
39-
throw new Error(`Model URLs are no longer supported. Use endpointUrl instead.`);
40+
throw new HfInferenceInputError(`Model URLs are no longer supported. Use endpointUrl instead.`);
4041
}
4142

4243
if (args.endpointUrl) {
@@ -51,38 +52,38 @@ export async function makeRequestOptions(
5152
}
5253

5354
if (!maybeModel && !task) {
54-
throw new Error("No model provided, and no task has been specified.");
55+
throw new HfInferenceInputError("No model provided, and no task has been specified.");
5556
}
5657

5758
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
5859
const hfModel = maybeModel ?? (await loadDefaultModel(task!));
5960

6061
if (providerHelper.clientSideRoutingOnly && !maybeModel) {
61-
throw new Error(`Provider ${provider} requires a model ID to be passed directly.`);
62+
throw new HfInferenceInputError(`Provider ${provider} requires a model ID to be passed directly.`);
6263
}
6364

6465
const inferenceProviderMapping = providerHelper.clientSideRoutingOnly
6566
? ({
66-
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
67-
providerId: removeProviderPrefix(maybeModel!, provider),
68-
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
69-
hfModelId: maybeModel!,
70-
status: "live",
67+
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
68+
providerId: removeProviderPrefix(maybeModel!, provider),
69+
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
70+
hfModelId: maybeModel!,
71+
status: "live",
72+
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
73+
task: task!,
74+
} satisfies InferenceProviderModelMapping)
75+
: await getInferenceProviderMapping(
76+
{
77+
modelId: hfModel,
7178
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
7279
task: task!,
73-
} satisfies InferenceProviderModelMapping)
74-
: await getInferenceProviderMapping(
75-
{
76-
modelId: hfModel,
77-
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
78-
task: task!,
79-
provider,
80-
accessToken: args.accessToken,
81-
},
82-
{ fetch: options?.fetch }
83-
);
80+
provider,
81+
accessToken: args.accessToken,
82+
},
83+
{ fetch: options?.fetch }
84+
);
8485
if (!inferenceProviderMapping) {
85-
throw new Error(`We have not been able to find inference provider information for model ${hfModel}.`);
86+
throw new HfInferenceInputError(`We have not been able to find inference provider information for model ${hfModel}.`);
8687
}
8788

8889
// Use the sync version with the resolved model
@@ -122,9 +123,8 @@ export function makeRequestOptionsFromResolvedModel(
122123
if (providerHelper.clientSideRoutingOnly) {
123124
// Closed-source providers require an accessToken (cannot be routed).
124125
if (accessToken && accessToken.startsWith("hf_")) {
125-
throw new Error(`Provider ${provider} is closed-source and does not support HF tokens.`);
126+
throw new HfInferenceInputError(`Provider ${provider} is closed-source and does not support HF tokens.`);
126127
}
127-
return "provider-key";
128128
}
129129
if (accessToken) {
130130
return accessToken.startsWith("hf_") ? "hf-token" : "provider-key";
@@ -197,23 +197,28 @@ async function loadDefaultModel(task: InferenceTask): Promise<string> {
197197
}
198198
const taskInfo = tasks[task];
199199
if ((taskInfo?.models.length ?? 0) <= 0) {
200-
throw new Error(`No default model defined for task ${task}, please define the model explicitly.`);
200+
throw new HfInferenceInputError(`No default model defined for task ${task}, please define the model explicitly.`);
201201
}
202202
return taskInfo.models[0].id;
203203
}
204204

205205
async function loadTaskInfo(): Promise<Record<string, { models: { id: string }[] }>> {
206-
const res = await fetch(`${HF_HUB_URL}/api/tasks`);
206+
const url = `${HF_HUB_URL}/api/tasks`;
207+
const res = await fetch(url);
207208

208209
if (!res.ok) {
209-
throw new Error("Failed to load tasks definitions from Hugging Face Hub.");
210+
throw new HfInferenceHubApiError(
211+
"Failed to load tasks definitions from Hugging Face Hub.",
212+
{ url, method: "GET" },
213+
{ requestId: res.headers.get("x-request-id") ?? "", status: res.status, body: await res.text() },
214+
);
210215
}
211216
return await res.json();
212217
}
213218

214219
function removeProviderPrefix(model: string, provider: string): string {
215220
if (!model.startsWith(`${provider}/`)) {
216-
throw new Error(`Models from ${provider} must be prefixed by "${provider}/". Got "${model}".`);
221+
throw new HfInferenceInputError(`Models from ${provider} must be prefixed by "${provider}/". Got "${model}".`);
217222
}
218223
return model.slice(provider.length + 1);
219224
}

0 commit comments

Comments
 (0)