Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions packages/inference/src/error.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import type { JsonObject } from "./vendor/type-fest/basic.js";

/**
* Base class for all inference-related errors.
*/
/**
 * Base class for all inference-related errors.
 *
 * Lets consumers catch any error thrown by this package with a single
 * `instanceof HfInferenceError` check, while each concrete subclass sets a
 * distinct `name` for quick identification in logs.
 */
export abstract class HfInferenceError extends Error {
	constructor(message: string) {
		super(message);
		this.name = "HfInferenceError";
	}
}

/**
 * Thrown when the caller provided invalid or inconsistent input,
 * e.g. mutually exclusive options or a missing model / task.
 */
export class HfInferenceInputError extends HfInferenceError {
	constructor(message: string) {
		super(message);
		this.name = "InputError";
	}
}

/**
 * Snapshot of the outgoing HTTP request attached to HTTP-related errors.
 */
interface HttpRequest {
	/** Full request URL. */
	url: string;
	/** HTTP verb, e.g. "GET" or "POST". */
	method: string;
	/** Request headers, if any. Authorization values are redacted before storage. */
	headers?: Record<string, string>;
	/** JSON request payload, if any. */
	body?: JsonObject;
}

/**
 * Snapshot of the HTTP response attached to HTTP-related errors.
 */
interface HttpResponse {
	/** Server-assigned request id (`x-request-id` header), useful for support requests. */
	requestId: string;
	/** HTTP status code. */
	status: number;
	/** Parsed JSON body when available, raw response text otherwise. */
	body: JsonObject | string;
}

/**
 * Base class for errors that carry the HTTP request/response that caused them.
 *
 * The stored copy of the request has any Authorization header value replaced
 * with `Bearer [redacted]` so access tokens cannot leak through logs or error
 * reporting tools. HTTP header names are case-insensitive (and `fetch`
 * implementations commonly normalize them to lowercase), so the match is done
 * case-insensitively while preserving the caller's original key casing.
 */
abstract class HfInferenceHttpRequestError extends HfInferenceError {
	httpRequest: HttpRequest;
	httpResponse: HttpResponse;
	constructor(message: string, httpRequest: HttpRequest, httpResponse: HttpResponse) {
		super(message);
		this.httpRequest = {
			...httpRequest,
			...(httpRequest.headers
				? {
						headers: Object.fromEntries(
							Object.entries(httpRequest.headers).map(([key, value]) =>
								// Redact authentication credentials regardless of header-name casing.
								key.toLowerCase() === "authorization" ? [key, `Bearer [redacted]`] : [key, value]
							)
						),
				  }
				: undefined),
		};
		this.httpResponse = httpResponse;
	}
}

/**
 * Thrown when the HTTP request to the provider fails, e.g. due to API issues or server errors.
 */
export class HfInferenceProviderApiError extends HfInferenceHttpRequestError {
	constructor(message: string, httpRequest: HttpRequest, httpResponse: HttpResponse) {
		super(message, httpRequest, httpResponse);
		this.name = "ProviderApiError";
	}
}

/**
 * Thrown when the HTTP request to the hub fails, e.g. due to API issues or server errors.
 */
export class HfInferenceHubApiError extends HfInferenceHttpRequestError {
	constructor(message: string, httpRequest: HttpRequest, httpResponse: HttpResponse) {
		super(message, httpRequest, httpResponse);
		this.name = "HubApiError";
	}
}

/**
 * Thrown when the inference output returned by the provider is invalid / does not match the expectations
 */
export class HfInferenceProviderOutputError extends HfInferenceError {
	constructor(message: string) {
		super(message);
		this.name = "ProviderOutputError";
	}
}

2 changes: 1 addition & 1 deletion packages/inference/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
export { InferenceClient, InferenceClientEndpoint, HfInference } from "./InferenceClient.js";
export { InferenceOutputError } from "./lib/InferenceOutputError.js";
export * from "./error.js";
export * from "./types.js";
export * from "./tasks/index.js";
import * as snippets from "./snippets/index.js";
Expand Down
8 changes: 0 additions & 8 deletions packages/inference/src/lib/InferenceOutputError.ts

This file was deleted.

59 changes: 41 additions & 18 deletions packages/inference/src/lib/getInferenceProviderMapping.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { HARDCODED_MODEL_INFERENCE_MAPPING } from "../providers/consts.js";
import { EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS } from "../providers/hf-inference.js";
import type { InferenceProvider, InferenceProviderOrPolicy, ModelId } from "../types.js";
import { typedInclude } from "../utils/typedInclude.js";
import { HfInferenceHubApiError, HfInferenceInputError } from "../error.js";

export const inferenceProviderMappingCache = new Map<ModelId, InferenceProviderMapping>();

Expand Down Expand Up @@ -32,27 +33,49 @@ export async function fetchInferenceProviderMappingForModel(
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
inferenceProviderMapping = inferenceProviderMappingCache.get(modelId)!;
} else {
const url = `${HF_HUB_URL}/api/models/${modelId}?expand[]=inferenceProviderMapping`;
const resp = await (options?.fetch ?? fetch)(
`${HF_HUB_URL}/api/models/${modelId}?expand[]=inferenceProviderMapping`,
url,
{
headers: accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${accessToken}` } : {},
}
);
if (resp.status === 404) {
throw new Error(`Model ${modelId} does not exist`);
if (!resp.ok) {
if (resp.headers.get("Content-Type")?.startsWith("application/json")) {
const error = await resp.json();
if ("error" in error && typeof error.error === "string") {
throw new HfInferenceHubApiError(
`Failed to fetch inference provider mapping for model ${modelId}: ${error.error}`,
{ url, method: "GET" },
{ requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: error }
);
}
} else {
throw new HfInferenceHubApiError(
`Failed to fetch inference provider mapping for model ${modelId}`,
{ url, method: "GET" },
{ requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: await resp.text() }
);
}
}
inferenceProviderMapping = await resp
.json()
.then((json) => json.inferenceProviderMapping)
.catch(() => null);

if (inferenceProviderMapping) {
inferenceProviderMappingCache.set(modelId, inferenceProviderMapping);
let payload: { inferenceProviderMapping?: InferenceProviderMapping } | null = null;
try {
payload = await resp.json();
} catch {
throw new HfInferenceHubApiError(
`Failed to fetch inference provider mapping for model ${modelId}: malformed API response, invalid JSON`,
{ url, method: "GET" },
{ requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: await resp.text() }
);
}
}

if (!inferenceProviderMapping) {
throw new Error(`We have not been able to find inference provider information for model ${modelId}.`);
if (!payload?.inferenceProviderMapping) {
throw new HfInferenceHubApiError(
`We have not been able to find inference provider information for model ${modelId}.`,
{ url, method: "GET" },
{ requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: await resp.text() }
);
}
inferenceProviderMapping = payload.inferenceProviderMapping;
}
return inferenceProviderMapping;
}
Expand Down Expand Up @@ -83,7 +106,7 @@ export async function getInferenceProviderMapping(
? EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS
: [params.task];
if (!typedInclude(equivalentTasks, providerMapping.task)) {
throw new Error(
throw new HfInferenceInputError(
`Model ${params.modelId} is not supported for task ${params.task} and provider ${params.provider}. Supported task: ${providerMapping.task}.`
);
}
Expand All @@ -104,7 +127,7 @@ export async function resolveProvider(
): Promise<InferenceProvider> {
if (endpointUrl) {
if (provider) {
throw new Error("Specifying both endpointUrl and provider is not supported.");
throw new HfInferenceInputError("Specifying both endpointUrl and provider is not supported.");
}
/// Defaulting to hf-inference helpers / API
return "hf-inference";
Expand All @@ -117,13 +140,13 @@ export async function resolveProvider(
}
if (provider === "auto") {
if (!modelId) {
throw new Error("Specifying a model is required when provider is 'auto'");
throw new HfInferenceInputError("Specifying a model is required when provider is 'auto'");
}
const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(modelId);
provider = Object.keys(inferenceProviderMapping)[0] as InferenceProvider | undefined;
}
if (!provider) {
throw new Error(`No Inference Provider available for model ${modelId}.`);
throw new HfInferenceInputError(`No Inference Provider available for model ${modelId}.`);
}
return provider;
}
7 changes: 4 additions & 3 deletions packages/inference/src/lib/getProviderHelper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ import * as Replicate from "../providers/replicate.js";
import * as Sambanova from "../providers/sambanova.js";
import * as Together from "../providers/together.js";
import type { InferenceProvider, InferenceProviderOrPolicy, InferenceTask } from "../types.js";
import { HfInferenceInputError } from "../error.js";

export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask, TaskProviderHelper>>> = {
"black-forest-labs": {
Expand Down Expand Up @@ -281,14 +282,14 @@ export function getProviderHelper(
return new HFInference.HFInferenceTask();
}
if (!task) {
throw new Error("you need to provide a task name when using an external provider, e.g. 'text-to-image'");
throw new HfInferenceInputError("you need to provide a task name when using an external provider, e.g. 'text-to-image'");
}
if (!(provider in PROVIDERS)) {
throw new Error(`Provider '${provider}' not supported. Available providers: ${Object.keys(PROVIDERS)}`);
throw new HfInferenceInputError(`Provider '${provider}' not supported. Available providers: ${Object.keys(PROVIDERS)}`);
}
const providerTasks = PROVIDERS[provider];
if (!providerTasks || !(task in providerTasks)) {
throw new Error(
throw new HfInferenceInputError(
`Task '${task}' not supported for provider '${provider}'. Available tasks: ${Object.keys(providerTasks ?? {})}`
);
}
Expand Down
59 changes: 32 additions & 27 deletions packages/inference/src/lib/makeRequestOptions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import type { InferenceProviderModelMapping } from "./getInferenceProviderMappin
import { getInferenceProviderMapping } from "./getInferenceProviderMapping.js";
import type { getProviderHelper } from "./getProviderHelper.js";
import { isUrl } from "./isUrl.js";
import { HfInferenceHubApiError, HfInferenceInputError } from "../error.js";

/**
* Lazy-loaded from huggingface.co/api/tasks when needed
Expand Down Expand Up @@ -33,10 +34,10 @@ export async function makeRequestOptions(

// Validate inputs
if (args.endpointUrl && provider !== "hf-inference") {
throw new Error(`Cannot use endpointUrl with a third-party provider.`);
throw new HfInferenceInputError(`Cannot use endpointUrl with a third-party provider.`);
}
if (maybeModel && isUrl(maybeModel)) {
throw new Error(`Model URLs are no longer supported. Use endpointUrl instead.`);
throw new HfInferenceInputError(`Model URLs are no longer supported. Use endpointUrl instead.`);
}

if (args.endpointUrl) {
Expand All @@ -51,38 +52,38 @@ export async function makeRequestOptions(
}

if (!maybeModel && !task) {
throw new Error("No model provided, and no task has been specified.");
throw new HfInferenceInputError("No model provided, and no task has been specified.");
}

// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const hfModel = maybeModel ?? (await loadDefaultModel(task!));

if (providerHelper.clientSideRoutingOnly && !maybeModel) {
throw new Error(`Provider ${provider} requires a model ID to be passed directly.`);
throw new HfInferenceInputError(`Provider ${provider} requires a model ID to be passed directly.`);
}

const inferenceProviderMapping = providerHelper.clientSideRoutingOnly
? ({
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
providerId: removeProviderPrefix(maybeModel!, provider),
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
hfModelId: maybeModel!,
status: "live",
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
providerId: removeProviderPrefix(maybeModel!, provider),
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
hfModelId: maybeModel!,
status: "live",
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
task: task!,
} satisfies InferenceProviderModelMapping)
: await getInferenceProviderMapping(
{
modelId: hfModel,
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
task: task!,
} satisfies InferenceProviderModelMapping)
: await getInferenceProviderMapping(
{
modelId: hfModel,
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
task: task!,
provider,
accessToken: args.accessToken,
},
{ fetch: options?.fetch }
);
provider,
accessToken: args.accessToken,
},
{ fetch: options?.fetch }
);
if (!inferenceProviderMapping) {
throw new Error(`We have not been able to find inference provider information for model ${hfModel}.`);
throw new HfInferenceInputError(`We have not been able to find inference provider information for model ${hfModel}.`);
}

// Use the sync version with the resolved model
Expand Down Expand Up @@ -122,9 +123,8 @@ export function makeRequestOptionsFromResolvedModel(
if (providerHelper.clientSideRoutingOnly) {
// Closed-source providers require an accessToken (cannot be routed).
if (accessToken && accessToken.startsWith("hf_")) {
throw new Error(`Provider ${provider} is closed-source and does not support HF tokens.`);
throw new HfInferenceInputError(`Provider ${provider} is closed-source and does not support HF tokens.`);
}
return "provider-key";
}
if (accessToken) {
return accessToken.startsWith("hf_") ? "hf-token" : "provider-key";
Expand Down Expand Up @@ -197,23 +197,28 @@ async function loadDefaultModel(task: InferenceTask): Promise<string> {
}
const taskInfo = tasks[task];
if ((taskInfo?.models.length ?? 0) <= 0) {
throw new Error(`No default model defined for task ${task}, please define the model explicitly.`);
throw new HfInferenceInputError(`No default model defined for task ${task}, please define the model explicitly.`);
}
return taskInfo.models[0].id;
}

async function loadTaskInfo(): Promise<Record<string, { models: { id: string }[] }>> {
const res = await fetch(`${HF_HUB_URL}/api/tasks`);
const url = `${HF_HUB_URL}/api/tasks`;
const res = await fetch(url);

if (!res.ok) {
throw new Error("Failed to load tasks definitions from Hugging Face Hub.");
throw new HfInferenceHubApiError(
"Failed to load tasks definitions from Hugging Face Hub.",
{ url, method: "GET" },
{ requestId: res.headers.get("x-request-id") ?? "", status: res.status, body: await res.text() },
);
}
return await res.json();
}

function removeProviderPrefix(model: string, provider: string): string {
if (!model.startsWith(`${provider}/`)) {
throw new Error(`Models from ${provider} must be prefixed by "${provider}/". Got "${model}".`);
throw new HfInferenceInputError(`Models from ${provider} must be prefixed by "${provider}/". Got "${model}".`);
}
return model.slice(provider.length + 1);
}
11 changes: 7 additions & 4 deletions packages/inference/src/providers/black-forest-labs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
*
* Thanks!
*/
import { InferenceOutputError } from "../lib/InferenceOutputError.js";
import { HfInferenceInputError, HfInferenceProviderApiError, HfInferenceProviderOutputError } from "../error.js";
import type { BodyParams, HeaderParams, UrlParams } from "../types.js";
import { delay } from "../utils/delay.js";
import { omit } from "../utils/omit.js";
Expand Down Expand Up @@ -52,7 +52,7 @@ export class BlackForestLabsTextToImageTask extends TaskProviderHelper implement

makeRoute(params: UrlParams): string {
if (!params) {
throw new Error("Params are required");
throw new HfInferenceInputError("Params are required");
}
return `/v1/${params.model}`;
}
Expand All @@ -70,7 +70,10 @@ export class BlackForestLabsTextToImageTask extends TaskProviderHelper implement
urlObj.searchParams.set("attempt", step.toString(10));
const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
if (!resp.ok) {
throw new InferenceOutputError("Failed to fetch result from black forest labs API");
throw new HfInferenceProviderApiError("Failed to fetch result from black forest labs API",
{ url: urlObj.toString(), method: "GET", headers: { "Content-Type": "application/json" } },
{ requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: await resp.text() }
);
}
const payload = await resp.json();
if (
Expand All @@ -92,6 +95,6 @@ export class BlackForestLabsTextToImageTask extends TaskProviderHelper implement
return await image.blob();
}
}
throw new InferenceOutputError("Failed to fetch result from black forest labs API");
throw new HfInferenceProviderOutputError(`Timed out while waiting for the result from black forest labs API - aborting after 5 attempts`);
}
}
Loading
Loading