Skip to content

Commit c2d1490

Browse files
Wauplin and julien-c authored
[Inference] Move provider-specific logic away from makeRequestOptions (1 provider == 1 module) (#1208)
Goal of this PR is to move any provider-specific logic away from `makeRequestOptions.ts`. In theory no tests should be updated, as we only want to modify the internal logic without modifying input/output behavior. Feedback around structure, TS conventions, naming, etc. is welcome. ### How does it work? Each provider must define a `providerConfig` consisting of: ```ts export interface ProviderConfig { baseUrl: string; makeBody: (params: BodyParams) => unknown; makeHeaders: (params: HeaderParams) => Record<string, string>; makeUrl: (params: UrlParams) => string; } ``` (with `HeaderParams`, `UrlParams`, `BodyParams` being the parameters required to build these) --------- Co-authored-by: Julien Chaumond <[email protected]>
1 parent 9abb7f5 commit c2d1490

File tree

15 files changed

+428
-232
lines changed

15 files changed

+428
-232
lines changed

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 66 additions & 180 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
import { HF_HUB_URL, HF_ROUTER_URL } from "../config";
2-
import { FAL_AI_API_BASE_URL } from "../providers/fal-ai";
3-
import { NEBIUS_API_BASE_URL } from "../providers/nebius";
4-
import { REPLICATE_API_BASE_URL } from "../providers/replicate";
5-
import { SAMBANOVA_API_BASE_URL } from "../providers/sambanova";
6-
import { TOGETHER_API_BASE_URL } from "../providers/together";
7-
import { NOVITA_API_BASE_URL } from "../providers/novita";
8-
import { FIREWORKS_AI_API_BASE_URL } from "../providers/fireworks-ai";
9-
import { HYPERBOLIC_API_BASE_URL } from "../providers/hyperbolic";
10-
import { BLACKFORESTLABS_AI_API_BASE_URL } from "../providers/black-forest-labs";
11-
import type { InferenceProvider } from "../types";
12-
import type { InferenceTask, Options, RequestArgs } from "../types";
2+
import { BLACK_FOREST_LABS_CONFIG } from "../providers/black-forest-labs";
3+
import { FAL_AI_CONFIG } from "../providers/fal-ai";
4+
import { FIREWORKS_AI_CONFIG } from "../providers/fireworks-ai";
5+
import { HF_INFERENCE_CONFIG } from "../providers/hf-inference";
6+
import { HYPERBOLIC_CONFIG } from "../providers/hyperbolic";
7+
import { NEBIUS_CONFIG } from "../providers/nebius";
8+
import { NOVITA_CONFIG } from "../providers/novita";
9+
import { REPLICATE_CONFIG } from "../providers/replicate";
10+
import { SAMBANOVA_CONFIG } from "../providers/sambanova";
11+
import { TOGETHER_CONFIG } from "../providers/together";
12+
import type { InferenceProvider, InferenceTask, Options, ProviderConfig, RequestArgs } from "../types";
1313
import { isUrl } from "./isUrl";
1414
import { version as packageVersion, name as packageName } from "../../package.json";
1515
import { getProviderModelId } from "./getProviderModelId";
@@ -22,6 +22,22 @@ const HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_ROUTER_URL}/{{PROVIDER}}`;
2222
*/
2323
let tasks: Record<string, { models: { id: string }[] }> | null = null;
2424

25+
/**
26+
* Config to define how to serialize requests for each provider
27+
*/
28+
const providerConfigs: Record<InferenceProvider, ProviderConfig> = {
29+
"black-forest-labs": BLACK_FOREST_LABS_CONFIG,
30+
"fal-ai": FAL_AI_CONFIG,
31+
"fireworks-ai": FIREWORKS_AI_CONFIG,
32+
"hf-inference": HF_INFERENCE_CONFIG,
33+
hyperbolic: HYPERBOLIC_CONFIG,
34+
nebius: NEBIUS_CONFIG,
35+
novita: NOVITA_CONFIG,
36+
replicate: REPLICATE_CONFIG,
37+
sambanova: SAMBANOVA_CONFIG,
38+
together: TOGETHER_CONFIG,
39+
};
40+
2541
/**
2642
* Helper that prepares request arguments
2743
*/
@@ -37,10 +53,10 @@ export async function makeRequestOptions(
3753
}
3854
): Promise<{ url: string; info: RequestInit }> {
3955
const { accessToken, endpointUrl, provider: maybeProvider, model: maybeModel, ...remainingArgs } = args;
40-
let otherArgs = remainingArgs;
4156
const provider = maybeProvider ?? "hf-inference";
57+
const providerConfig = providerConfigs[provider];
4258

43-
const { includeCredentials, task, chatCompletion } = options ?? {};
59+
const { includeCredentials, task, chatCompletion, signal } = options ?? {};
4460

4561
if (endpointUrl && provider !== "hf-inference") {
4662
throw new Error(`Cannot use endpointUrl with a third-party provider.`);
@@ -51,6 +67,9 @@ export async function makeRequestOptions(
5167
if (!maybeModel && !task) {
5268
throw new Error("No model provided, and no task has been specified.");
5369
}
70+
if (!providerConfig) {
71+
throw new Error(`No provider config found for provider ${provider}`);
72+
}
5473
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
5574
const hfModel = maybeModel ?? (await loadDefaultModel(task!));
5675
const model = await getProviderModelId({ model: hfModel, provider }, args, {
@@ -68,44 +87,52 @@ export async function makeRequestOptions(
6887
? "credentials-include"
6988
: "none";
7089

90+
// Make URL
7191
const url = endpointUrl
7292
? chatCompletion
7393
? endpointUrl + `/v1/chat/completions`
7494
: endpointUrl
75-
: makeUrl({
76-
authMethod,
77-
chatCompletion: chatCompletion ?? false,
95+
: providerConfig.makeUrl({
96+
baseUrl:
97+
authMethod !== "provider-key"
98+
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider)
99+
: providerConfig.baseUrl,
78100
model,
79-
provider: provider ?? "hf-inference",
101+
chatCompletion,
80102
task,
81103
});
82104

83-
const headers: Record<string, string> = {};
84-
if (accessToken) {
85-
if (provider === "fal-ai" && authMethod === "provider-key") {
86-
headers["Authorization"] = `Key ${accessToken}`;
87-
} else if (provider === "black-forest-labs" && authMethod === "provider-key") {
88-
headers["X-Key"] = accessToken;
89-
} else {
90-
headers["Authorization"] = `Bearer ${accessToken}`;
91-
}
92-
}
93-
94-
// e.g. @huggingface/inference/3.1.3
95-
const ownUserAgent = `${packageName}/${packageVersion}`;
96-
headers["User-Agent"] = [ownUserAgent, typeof navigator !== "undefined" ? navigator.userAgent : undefined]
97-
.filter((x) => x !== undefined)
98-
.join(" ");
99-
105+
// Make headers
100106
const binary = "data" in args && !!args.data;
107+
const headers = providerConfig.makeHeaders({
108+
accessToken,
109+
authMethod,
110+
});
101111

112+
// Add content-type to headers
102113
if (!binary) {
103114
headers["Content-Type"] = "application/json";
104115
}
105116

106-
if (provider === "replicate") {
107-
headers["Prefer"] = "wait";
108-
}
117+
// Add user-agent to headers
118+
// e.g. @huggingface/inference/3.1.3
119+
const ownUserAgent = `${packageName}/${packageVersion}`;
120+
const userAgent = [ownUserAgent, typeof navigator !== "undefined" ? navigator.userAgent : undefined]
121+
.filter((x) => x !== undefined)
122+
.join(" ");
123+
headers["User-Agent"] = userAgent;
124+
125+
// Make body
126+
const body = binary
127+
? args.data
128+
: JSON.stringify(
129+
providerConfig.makeBody({
130+
args: remainingArgs as Record<string, unknown>,
131+
model,
132+
task,
133+
chatCompletion,
134+
})
135+
);
109136

110137
/**
111138
* For edge runtimes, leave 'credentials' undefined, otherwise cloudflare workers will error
@@ -117,158 +144,17 @@ export async function makeRequestOptions(
117144
credentials = "include";
118145
}
119146

120-
/**
121-
* Replicate models wrap all inputs inside { input: ... }
122-
* Versioned Replicate models in the format `owner/model:version` expect the version in the body
123-
*/
124-
if (provider === "replicate") {
125-
const version = model.includes(":") ? model.split(":")[1] : undefined;
126-
(otherArgs as unknown) = { input: otherArgs, version };
127-
}
128-
129147
const info: RequestInit = {
130148
headers,
131149
method: "POST",
132-
body: binary
133-
? args.data
134-
: JSON.stringify({
135-
...otherArgs,
136-
...(task === "text-to-image" && provider === "hyperbolic"
137-
? { model_name: model }
138-
: chatCompletion || provider === "together" || provider === "nebius" || provider === "hyperbolic"
139-
? { model }
140-
: undefined),
141-
}),
150+
body,
142151
...(credentials ? { credentials } : undefined),
143-
signal: options?.signal,
152+
signal,
144153
};
145154

146155
return { url, info };
147156
}
148157

149-
function makeUrl(params: {
150-
authMethod: "none" | "hf-token" | "credentials-include" | "provider-key";
151-
chatCompletion: boolean;
152-
model: string;
153-
provider: InferenceProvider;
154-
task: InferenceTask | undefined;
155-
}): string {
156-
if (params.authMethod === "none" && params.provider !== "hf-inference") {
157-
throw new Error("Authentication is required when requesting a third-party provider. Please provide accessToken");
158-
}
159-
160-
const shouldProxy = params.provider !== "hf-inference" && params.authMethod !== "provider-key";
161-
switch (params.provider) {
162-
case "black-forest-labs": {
163-
const baseUrl = shouldProxy
164-
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
165-
: BLACKFORESTLABS_AI_API_BASE_URL;
166-
return `${baseUrl}/${params.model}`;
167-
}
168-
case "fal-ai": {
169-
const baseUrl = shouldProxy
170-
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
171-
: FAL_AI_API_BASE_URL;
172-
return `${baseUrl}/${params.model}`;
173-
}
174-
case "nebius": {
175-
const baseUrl = shouldProxy
176-
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
177-
: NEBIUS_API_BASE_URL;
178-
179-
if (params.task === "text-to-image") {
180-
return `${baseUrl}/v1/images/generations`;
181-
}
182-
if (params.task === "text-generation") {
183-
if (params.chatCompletion) {
184-
return `${baseUrl}/v1/chat/completions`;
185-
}
186-
return `${baseUrl}/v1/completions`;
187-
}
188-
return baseUrl;
189-
}
190-
case "replicate": {
191-
const baseUrl = shouldProxy
192-
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
193-
: REPLICATE_API_BASE_URL;
194-
if (params.model.includes(":")) {
195-
/// Versioned model
196-
return `${baseUrl}/v1/predictions`;
197-
}
198-
/// Evergreen / Canonical model
199-
return `${baseUrl}/v1/models/${params.model}/predictions`;
200-
}
201-
case "sambanova": {
202-
const baseUrl = shouldProxy
203-
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
204-
: SAMBANOVA_API_BASE_URL;
205-
/// Sambanova API matches OpenAI-like APIs: model is defined in the request body
206-
if (params.task === "text-generation" && params.chatCompletion) {
207-
return `${baseUrl}/v1/chat/completions`;
208-
}
209-
return baseUrl;
210-
}
211-
case "together": {
212-
const baseUrl = shouldProxy
213-
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
214-
: TOGETHER_API_BASE_URL;
215-
/// Together API matches OpenAI-like APIs: model is defined in the request body
216-
if (params.task === "text-to-image") {
217-
return `${baseUrl}/v1/images/generations`;
218-
}
219-
if (params.task === "text-generation") {
220-
if (params.chatCompletion) {
221-
return `${baseUrl}/v1/chat/completions`;
222-
}
223-
return `${baseUrl}/v1/completions`;
224-
}
225-
return baseUrl;
226-
}
227-
228-
case "fireworks-ai": {
229-
const baseUrl = shouldProxy
230-
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
231-
: FIREWORKS_AI_API_BASE_URL;
232-
if (params.task === "text-generation" && params.chatCompletion) {
233-
return `${baseUrl}/v1/chat/completions`;
234-
}
235-
return baseUrl;
236-
}
237-
case "hyperbolic": {
238-
const baseUrl = shouldProxy
239-
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
240-
: HYPERBOLIC_API_BASE_URL;
241-
242-
if (params.task === "text-to-image") {
243-
return `${baseUrl}/v1/images/generations`;
244-
}
245-
return `${baseUrl}/v1/chat/completions`;
246-
}
247-
case "novita": {
248-
const baseUrl = shouldProxy
249-
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
250-
: NOVITA_API_BASE_URL;
251-
if (params.task === "text-generation") {
252-
if (params.chatCompletion) {
253-
return `${baseUrl}/chat/completions`;
254-
}
255-
return `${baseUrl}/completions`;
256-
}
257-
return baseUrl;
258-
}
259-
default: {
260-
const baseUrl = HF_HUB_INFERENCE_PROXY_TEMPLATE.replaceAll("{{PROVIDER}}", "hf-inference");
261-
if (params.task && ["feature-extraction", "sentence-similarity"].includes(params.task)) {
262-
/// when deployed on hf-inference, those two tasks are automatically compatible with one another.
263-
return `${baseUrl}/pipeline/${params.task}/${params.model}`;
264-
}
265-
if (params.task === "text-generation" && params.chatCompletion) {
266-
return `${baseUrl}/models/${params.model}/v1/chat/completions`;
267-
}
268-
return `${baseUrl}/models/${params.model}`;
269-
}
270-
}
271-
}
272158
async function loadDefaultModel(task: InferenceTask): Promise<string> {
273159
if (!tasks) {
274160
tasks = await loadTaskInfo();

packages/inference/src/providers/black-forest-labs.ts

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
export const BLACKFORESTLABS_AI_API_BASE_URL = "https://api.us1.bfl.ai/v1";
2-
31
/**
42
* See the registered mapping of HF model ID => Black Forest Labs model ID here:
53
*
@@ -16,3 +14,29 @@ export const BLACKFORESTLABS_AI_API_BASE_URL = "https://api.us1.bfl.ai/v1";
1614
*
1715
* Thanks!
1816
*/
17+
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
18+
19+
const BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai/v1";
20+
21+
const makeBody = (params: BodyParams): Record<string, unknown> => {
22+
return params.args;
23+
};
24+
25+
const makeHeaders = (params: HeaderParams): Record<string, string> => {
26+
if (params.authMethod === "provider-key") {
27+
return { "X-Key": `${params.accessToken}` };
28+
} else {
29+
return { Authorization: `Bearer ${params.accessToken}` };
30+
}
31+
};
32+
33+
const makeUrl = (params: UrlParams): string => {
34+
return `${params.baseUrl}/${params.model}`;
35+
};
36+
37+
export const BLACK_FOREST_LABS_CONFIG: ProviderConfig = {
38+
baseUrl: BLACK_FOREST_LABS_AI_API_BASE_URL,
39+
makeBody,
40+
makeHeaders,
41+
makeUrl,
42+
};

packages/inference/src/providers/consts.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ export const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelI
2222
"hf-inference": {},
2323
hyperbolic: {},
2424
nebius: {},
25+
novita: {},
2526
replicate: {},
2627
sambanova: {},
2728
together: {},
28-
novita: {},
2929
};

packages/inference/src/providers/fal-ai.ts

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
export const FAL_AI_API_BASE_URL = "https://fal.run";
2-
31
/**
42
* See the registered mapping of HF model ID => Fal model ID here:
53
*
@@ -16,3 +14,27 @@ export const FAL_AI_API_BASE_URL = "https://fal.run";
1614
*
1715
* Thanks!
1816
*/
17+
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
18+
19+
const FAL_AI_API_BASE_URL = "https://fal.run";
20+
21+
const makeBody = (params: BodyParams): Record<string, unknown> => {
22+
return params.args;
23+
};
24+
25+
const makeHeaders = (params: HeaderParams): Record<string, string> => {
26+
return {
27+
Authorization: params.authMethod === "provider-key" ? `Key ${params.accessToken}` : `Bearer ${params.accessToken}`,
28+
};
29+
};
30+
31+
const makeUrl = (params: UrlParams): string => {
32+
return `${params.baseUrl}/${params.model}`;
33+
};
34+
35+
export const FAL_AI_CONFIG: ProviderConfig = {
36+
baseUrl: FAL_AI_API_BASE_URL,
37+
makeBody,
38+
makeHeaders,
39+
makeUrl,
40+
};

0 commit comments

Comments
 (0)