249 changes: 134 additions & 115 deletions packages/inference/src/lib/makeRequestOptions.ts
@@ -2,10 +2,9 @@
 import { FAL_AI_API_BASE_URL, FAL_AI_MODEL_IDS } from "../providers/fal-ai";
 import { REPLICATE_API_BASE_URL, REPLICATE_MODEL_IDS } from "../providers/replicate";
 import { SAMBANOVA_API_BASE_URL, SAMBANOVA_MODEL_IDS } from "../providers/sambanova";
 import { TOGETHER_API_BASE_URL, TOGETHER_MODEL_IDS } from "../providers/together";
-import { INFERENCE_PROVIDERS, type InferenceTask, type Options, type RequestArgs } from "../types";
-import { omit } from "../utils/omit";
+import type { InferenceProvider } from "../types";
+import type { InferenceTask, Options, RequestArgs } from "../types";
 import { HF_HUB_URL } from "./getDefaultTask";
-import { isUrl } from "./isUrl";
 
 const HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co";
@@ -31,62 +30,46 @@ export async function makeRequestOptions(
 		chatCompletion?: boolean;
 	}
 ): Promise<{ url: string; info: RequestInit }> {
-	const { accessToken, endpointUrl, provider, ...otherArgs } = args;
-	let { model } = args;
+	const { accessToken, endpointUrl, provider: maybeProvider, model: maybeModel, ...otherArgs } = args;
+	const provider = maybeProvider ?? "hf-inference";
 
 	const { forceTask, includeCredentials, taskHint, wait_for_model, use_cache, dont_load_model, chatCompletion } =
 		options ?? {};
 
-	const headers: Record<string, string> = {};
-	if (accessToken) {
-		headers["Authorization"] = provider === "fal-ai" ? `Key ${accessToken}` : `Bearer ${accessToken}`;
-	}
-
-	if (!model && !tasks && taskHint) {
-		const res = await fetch(`${HF_HUB_URL}/api/tasks`);
-
-		if (res.ok) {
-			tasks = await res.json();
-		}
-	}
-
-	if (!model && tasks && taskHint) {
-		const taskInfo = tasks[taskHint];
-		if (taskInfo) {
-			model = taskInfo.models[0].id;
-		}
-	}
+	if (endpointUrl && provider !== "hf-inference") {
+		throw new Error(`Cannot use endpointUrl with a third-party provider.`);
+	}
+	if (forceTask && provider !== "hf-inference") {
+		throw new Error(`Cannot use forceTask with a third-party provider.`);
+	}
 
+	let model: string;
+	if (!maybeModel) {
Review comment from @coyotte508 (Member), Jan 15, 2025:

Maybe add something like

	if (maybeModel && URL.parse(maybeModel)) {
	  throw new TypeError("model URLs are no longer supported, use endpointUrl instead")
	}

(or a regex check since URL.parse may not be supported on some safari devices)
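A minimal sketch of the regex fallback the reviewer mentions, assuming a hypothetical isUrlLike helper that is not part of this PR (URL.parse() is a relatively recent addition, so a pattern test sidesteps the Safari support question):

	// Hypothetical fallback: treat anything starting with an http(s) scheme as a URL,
	// without relying on URL.parse() browser support.
	function isUrlLike(value: string): boolean {
		return /^https?:\/\//.test(value);
	}

	if (maybeModel && isUrlLike(maybeModel)) {
		throw new TypeError("model URLs are no longer supported, use endpointUrl instead");
	}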

+		if (taskHint) {
+			model = mapModel({ model: await loadDefaultModel(taskHint), provider });
+		} else {
+			throw new Error("No model provided, and no default model found for this task");
+			/// TODO : change error message ^
+		}
+	} else {
+		model = mapModel({ model: maybeModel, provider });
+	}
 
-	if (!model) {
-		throw new Error("No model provided, and no default model found for this task");
-	}
-	if (provider) {
-		if (!INFERENCE_PROVIDERS.includes(provider)) {
-			throw new Error("Unknown Inference provider");
-		}
-		if (!accessToken) {
-			throw new Error("Specifying an Inference provider requires an accessToken");
-		}
-
-		const modelId = (() => {
-			switch (provider) {
-				case "replicate":
-					return REPLICATE_MODEL_IDS[model];
-				case "sambanova":
-					return SAMBANOVA_MODEL_IDS[model];
-				case "together":
-					return TOGETHER_MODEL_IDS[model]?.id;
-				case "fal-ai":
-					return FAL_AI_MODEL_IDS[model];
-				default:
-					return model;
-			}
-		})();
-
-		if (!modelId) {
-			throw new Error(`Model ${model} is not supported for provider ${provider}`);
-		}
-
-		model = modelId;
-	}
+	const url = endpointUrl
+		? chatCompletion
+			? endpointUrl + `/v1/chat/completions`
+			: endpointUrl
+		: makeUrl({
+				model,
+				provider: provider ?? "hf-inference",
+				taskHint,
+				chatCompletion: chatCompletion ?? false,
+				forceTask,
+			});
+
+	const headers: Record<string, string> = {};
+	if (accessToken) {
+		headers["Authorization"] = provider === "fal-ai" ? `Key ${accessToken}` : `Bearer ${accessToken}`;
+	}
 
 	const binary = "data" in args && !!args.data;
@@ -95,73 +78,20 @@ export async function makeRequestOptions(
 		headers["Content-Type"] = "application/json";
 	}
 
-	if (wait_for_model) {
-		headers["X-Wait-For-Model"] = "true";
-	}
-	if (use_cache === false) {
-		headers["X-Use-Cache"] = "false";
-	}
-	if (dont_load_model) {
-		headers["X-Load-Model"] = "0";
-	}
-	if (provider === "replicate") {
-		headers["Prefer"] = "wait";
-	}
-
-	let url = (() => {
-		if (endpointUrl && isUrl(model)) {
-			throw new TypeError("Both model and endpointUrl cannot be URLs");
-		}
-		if (isUrl(model)) {
-			console.warn("Using a model URL is deprecated, please use the `endpointUrl` parameter instead");
-			return model;
-		}
-		if (endpointUrl) {
-			return endpointUrl;
-		}
-		if (forceTask) {
-			return `${HF_INFERENCE_API_BASE_URL}/pipeline/${forceTask}/${model}`;
-		}
-		if (provider) {
-			if (!accessToken) {
-				throw new Error("Specifying an Inference provider requires an accessToken");
-			}
-			if (accessToken.startsWith("hf_")) {
-				/// TODO we wil proxy the request server-side (using our own keys) and handle billing for it on the user's HF account.
-				throw new Error("Inference proxying is not implemented yet");
-			} else {
-				switch (provider) {
-					case "fal-ai":
-						return `${FAL_AI_API_BASE_URL}/${model}`;
-					case "replicate":
-						if (model.includes(":")) {
-							// Versioned models are in the form of `owner/model:version`
-							return `${REPLICATE_API_BASE_URL}/v1/predictions`;
-						} else {
-							// Unversioned models are in the form of `owner/model`
-							return `${REPLICATE_API_BASE_URL}/v1/models/${model}/predictions`;
-						}
-					case "sambanova":
-						return SAMBANOVA_API_BASE_URL;
-					case "together":
-						if (taskHint === "text-to-image") {
-							return `${TOGETHER_API_BASE_URL}/v1/images/generations`;
-						}
-						return TOGETHER_API_BASE_URL;
-					default:
-						break;
-				}
-			}
-		}
-
-		return `${HF_INFERENCE_API_BASE_URL}/models/${model}`;
-	})();
-
-	if (chatCompletion && !url.endsWith("/chat/completions")) {
-		url += "/v1/chat/completions";
-	}
-	if (provider === "together" && taskHint === "text-generation" && !chatCompletion) {
-		url += "/v1/completions";
-	}
+	if (provider === "hf-inference") {
+		if (wait_for_model) {
+			headers["X-Wait-For-Model"] = "true";
+		}
+		if (use_cache === false) {
+			headers["X-Use-Cache"] = "false";
+		}
+		if (dont_load_model) {
+			headers["X-Load-Model"] = "0";
+		}
+	}
 
+	if (provider === "replicate") {
+		headers["Prefer"] = "wait";
+	}
 
 	/**
@@ -188,13 +118,102 @@
 		body: binary
 			? args.data
 			: JSON.stringify({
-					...((otherArgs.model && isUrl(otherArgs.model)) || provider === "replicate" || provider === "fal-ai"
-						? omit(otherArgs, "model")
-						: { ...otherArgs, model }),
+					...otherArgs,
+					...(chatCompletion || provider === "together" ? { model } : undefined),
 				}),
 		...(credentials ? { credentials } : undefined),
 		signal: options?.signal,
 	};
 
 	return { url, info };
 }

+function mapModel(params: { model: string; provider: InferenceProvider }): string {
+	const model = (() => {
+		switch (params.provider) {
+			case "fal-ai":
+				return FAL_AI_MODEL_IDS[params.model];
+			case "replicate":
+				return REPLICATE_MODEL_IDS[params.model];
+			case "sambanova":
+				return SAMBANOVA_MODEL_IDS[params.model];
+			case "together":
+				return TOGETHER_MODEL_IDS[params.model]?.id;
+			case "hf-inference":
+				return params.model;
+		}
+	})();
+
+	if (!model) {
+		throw new Error(`Model ${params.model} is not supported for provider ${params.provider}`);
+	}
+	return model;
+}
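A usage sketch (not part of the diff), using the provider mappings shown further down this page; the unmapped model id is made up:

	mapModel({ model: "black-forest-labs/FLUX.1-schnell", provider: "fal-ai" });
	// => "fal-ai/flux/schnell", from FAL_AI_MODEL_IDS

	mapModel({ model: "black-forest-labs/FLUX.1-schnell", provider: "hf-inference" });
	// => "black-forest-labs/FLUX.1-schnell", hf-inference passes model ids through unchanged

	mapModel({ model: "some-org/unmapped-model", provider: "replicate" });
	// => throws Error: Model some-org/unmapped-model is not supported for provider replicate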

+function makeUrl(params: {
+	model: string;
+	provider: InferenceProvider;
+	taskHint: InferenceTask | undefined;
+	chatCompletion: boolean;
+	forceTask?: string | InferenceTask;
+}): string {
+	switch (params.provider) {
+		case "fal-ai":
+			return `${FAL_AI_API_BASE_URL}/${params.model}`;
+		case "replicate": {
+			if (params.model.includes(":")) {
+				/// Versioned model
+				return `${REPLICATE_API_BASE_URL}/v1/predictions`;
+			}
+			/// Evergreen / Canonical model
+			return `${REPLICATE_API_BASE_URL}/v1/models/${params.model}/predictions`;
+		}
+		case "sambanova":
+			/// Sambanova API matches OpenAI-like APIs: model is defined in the request body
+			if (params.taskHint === "text-generation" && params.chatCompletion) {
+				return `${SAMBANOVA_API_BASE_URL}/v1/chat/completions`;
+			}
+			return SAMBANOVA_API_BASE_URL;
+		case "together": {
+			/// Together API matches OpenAI-like APIs: model is defined in the request body
+			if (params.taskHint === "text-to-image") {
+				return `${TOGETHER_API_BASE_URL}/v1/images/generations`;
+			}
+			if (params.taskHint === "text-generation") {
+				if (params.chatCompletion) {
+					return `${TOGETHER_API_BASE_URL}/v1/chat/completions`;
+				}
+				return `${TOGETHER_API_BASE_URL}/v1/completions`;
+			}
+			return TOGETHER_API_BASE_URL;
+		}
+		default: {
+			const url = params.forceTask
+				? `${HF_INFERENCE_API_BASE_URL}/pipeline/${params.forceTask}/${params.model}`
+				: `${HF_INFERENCE_API_BASE_URL}/models/${params.model}`;
+			if (params.taskHint === "text-generation" && params.chatCompletion) {
+				return url + `/v1/chat/completions`;
+			}
+			return url;
+		}
+	}
+}
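A sketch of the URLs this produces; the model ids are illustrative and the base-URL constants come from the provider modules:

	makeUrl({ model: "fal-ai/flux/schnell", provider: "fal-ai", taskHint: "text-to-image", chatCompletion: false });
	// => `${FAL_AI_API_BASE_URL}/fal-ai/flux/schnell`

	makeUrl({ model: "owner/model:version123", provider: "replicate", taskHint: "text-to-image", chatCompletion: false });
	// => `${REPLICATE_API_BASE_URL}/v1/predictions` (versioned model)

	makeUrl({ model: "meta-llama/Llama-3.1-8B-Instruct", provider: "hf-inference", taskHint: "text-generation", chatCompletion: true });
	// => `${HF_INFERENCE_API_BASE_URL}/models/meta-llama/Llama-3.1-8B-Instruct/v1/chat/completions`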
+
+async function loadDefaultModel(task: InferenceTask): Promise<string> {
+	if (!tasks) {
+		tasks = await loadTaskInfo();
+	}
+	const taskInfo = tasks[task];
+	if ((taskInfo?.models.length ?? 0) <= 0) {
+		throw new Error(`No default model defined for task ${task}, please define the model explicitly.`);
+	}
+	return taskInfo.models[0].id;
+}

+async function loadTaskInfo(): Promise<Record<string, { models: { id: string }[] }>> {
+	const res = await fetch(`${HF_HUB_URL}/api/tasks`);
+
+	if (!res.ok) {
+		throw new Error("Failed to load tasks definitions from Hugging Face Hub.");
+	}
+	return await res.json();
+}
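Putting the pieces together, a hypothetical end-to-end call (the token and inputs are made up; the Key authorization scheme and the body contents follow from the code above):

	const { url, info } = await makeRequestOptions(
		{
			model: "black-forest-labs/FLUX.1-schnell",
			provider: "fal-ai",
			accessToken: "fal-secret-token",
			inputs: "An astronaut riding a horse",
		},
		{ taskHint: "text-to-image" }
	);
	// url:          `${FAL_AI_API_BASE_URL}/fal-ai/flux/schnell`
	// info.headers: { Authorization: "Key fal-secret-token", "Content-Type": "application/json" }
	// info.body:    JSON.stringify({ inputs: "An astronaut riding a horse" })
	// (model is omitted from the body: it is only included for chat completion or the together provider)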
2 changes: 1 addition & 1 deletion packages/inference/src/providers/fal-ai.ts
@@ -7,7 +7,7 @@ type FalAiId = string;
 /**
  * Mapping from HF model ID -> fal.ai app id
  */
-export const FAL_AI_MODEL_IDS: Record<ModelId, FalAiId> = {
+export const FAL_AI_MODEL_IDS: Partial<Record<ModelId, FalAiId>> = {
 	/** text-to-image */
 	"black-forest-labs/FLUX.1-schnell": "fal-ai/flux/schnell",
 	"black-forest-labs/FLUX.1-dev": "fal-ai/flux/dev",
2 changes: 1 addition & 1 deletion packages/inference/src/providers/replicate.ts
@@ -14,7 +14,7 @@ type ReplicateId = string;
  * 'https://api.replicate.com/v1/models'
  * ```
  */
-export const REPLICATE_MODEL_IDS: Record<ModelId, ReplicateId> = {
+export const REPLICATE_MODEL_IDS: Partial<Record<ModelId, ReplicateId>> = {
 	/** text-to-image */
 	"black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell",
 	"ByteDance/SDXL-Lightning":
2 changes: 1 addition & 1 deletion packages/inference/src/providers/sambanova.ts
@@ -15,7 +15,7 @@ type SambanovaId = string;
 /**
  * https://community.sambanova.ai/t/supported-models/193
  */
-export const SAMBANOVA_MODEL_IDS: Record<ModelId, SambanovaId> = {
+export const SAMBANOVA_MODEL_IDS: Partial<Record<ModelId, SambanovaId>> = {
 	/** Chat completion / conversational */
 	"Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
 	"Qwen/Qwen2.5-72B-Instruct": "Qwen2.5-72B-Instruct",
5 changes: 2 additions & 3 deletions packages/inference/src/providers/together.ts
@@ -10,9 +10,8 @@ type TogetherId = string;
 /**
  * https://docs.together.ai/reference/models-1
  */
-export const TOGETHER_MODEL_IDS: Record<
-	ModelId,
-	{ id: TogetherId; type: "chat" | "embedding" | "image" | "language" | "moderation" }
+export const TOGETHER_MODEL_IDS: Partial<
+	Record<ModelId, { id: TogetherId; type: "chat" | "embedding" | "image" | "language" | "moderation" }>
 > = {
 	/** text-to-image */
 	"black-forest-labs/FLUX.1-Canny-dev": { id: "black-forest-labs/FLUX.1-canny", type: "image" },
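The same Record -> Partial<Record<...>> change is applied to all four provider modules, for the same reason: a mapping keyed by open-ended HF model ids can never be exhaustive, so lookups should type-check as possibly undefined. A sketch of the effect:

	// Under Record<ModelId, FalAiId>, this lookup was typed FalAiId even for unmapped ids.
	// Under Partial<Record<ModelId, FalAiId>>, the compiler surfaces the missing case:
	const appId: string | undefined = FAL_AI_MODEL_IDS["some-org/unmapped-model"];
	if (appId === undefined) {
		// mapModel() in makeRequestOptions.ts handles this by throwing "not supported"
	}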
6 changes: 5 additions & 1 deletion packages/inference/src/tasks/custom/streamingRequest.ts
@@ -32,9 +32,13 @@ export async function* streamingRequest<T>(
 		if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
 			throw new Error(`Server ${args.model} does not seem to support chat completion. Error: ${output.error}`);
 		}
-		if (output.error) {
+		if (typeof output.error === "string") {
 			throw new Error(output.error);
 		}
+		if (output.error && "message" in output.error && typeof output.error.message === "string") {
+			/// OpenAI errors
+			throw new Error(output.error.message);
+		}
 	}
 
 	throw new Error(`Server response contains error: ${response.status}`);
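For reference, a sketch of the two error payload shapes the new checks distinguish (field values are illustrative):

	// HF Inference API style: `error` is a plain string
	const hfStyle = { error: "Model black-forest-labs/FLUX.1-schnell is currently loading" };

	// OpenAI-compatible style: `error` is an object carrying a `message`
	const openAiStyle = { error: { message: "Invalid API key" } };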