Skip to content

Commit 455f12c

Browse files
authored
[Inference] Proxy calls to 3rd party providers (#1108)
Companion to huggingface-internal/moon-landing#12072 (internal)
1 parent 4b2fbb6 commit 455f12c

File tree

4 files changed

+54
-21
lines changed

4 files changed

+54
-21
lines changed

packages/inference/src/config.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
export const HF_HUB_URL = "https://huggingface.co";
2+
export const HF_INFERENCE_API_URL = "https://api-inference.huggingface.co";

packages/inference/src/lib/getDefaultTask.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import { HF_HUB_URL } from "../config";
12
import { isUrl } from "./isUrl";
23

34
/**
@@ -8,7 +9,6 @@ import { isUrl } from "./isUrl";
89
const taskCache = new Map<string, { task: string; date: Date }>();
910
const CACHE_DURATION = 10 * 60 * 1000;
1011
const MAX_CACHE_ITEMS = 1000;
11-
export const HF_HUB_URL = "https://huggingface.co";
1212

1313
export interface DefaultTaskOptions {
1414
fetch?: typeof fetch;

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 50 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1+
import { HF_HUB_URL, HF_INFERENCE_API_URL } from "../config";
12
import { FAL_AI_API_BASE_URL, FAL_AI_MODEL_IDS } from "../providers/fal-ai";
23
import { REPLICATE_API_BASE_URL, REPLICATE_MODEL_IDS } from "../providers/replicate";
34
import { SAMBANOVA_API_BASE_URL, SAMBANOVA_MODEL_IDS } from "../providers/sambanova";
45
import { TOGETHER_API_BASE_URL, TOGETHER_MODEL_IDS } from "../providers/together";
56
import type { InferenceProvider } from "../types";
67
import type { InferenceTask, Options, RequestArgs } from "../types";
7-
import { HF_HUB_URL } from "./getDefaultTask";
88
import { isUrl } from "./isUrl";
99

10-
const HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co";
10+
const HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_HUB_URL}/api/inference-proxy/{{PROVIDER}}`;
1111

1212
/**
1313
* Lazy-loaded from huggingface.co/api/tasks when needed
@@ -59,21 +59,32 @@ export async function makeRequestOptions(
5959
model = mapModel({ model: maybeModel, provider });
6060
}
6161

62+
/// If accessToken is passed, it should take precedence over includeCredentials
63+
const authMethod = accessToken
64+
? accessToken.startsWith("hf_")
65+
? "hf-token"
66+
: "provider-key"
67+
: includeCredentials === "include"
68+
? "credentials-include"
69+
: "none";
70+
6271
const url = endpointUrl
6372
? chatCompletion
6473
? endpointUrl + `/v1/chat/completions`
6574
: endpointUrl
6675
: makeUrl({
76+
authMethod,
77+
chatCompletion: chatCompletion ?? false,
78+
forceTask,
6779
model,
6880
provider: provider ?? "hf-inference",
6981
taskHint,
70-
chatCompletion: chatCompletion ?? false,
71-
forceTask,
7282
});
7383

7484
const headers: Record<string, string> = {};
7585
if (accessToken) {
76-
headers["Authorization"] = provider === "fal-ai" ? `Key ${accessToken}` : `Bearer ${accessToken}`;
86+
headers["Authorization"] =
87+
provider === "fal-ai" && authMethod === "provider-key" ? `Key ${accessToken}` : `Bearer ${accessToken}`;
7788
}
7889

7990
const binary = "data" in args && !!args.data;
@@ -155,46 +166,66 @@ function mapModel(params: { model: string; provider: InferenceProvider }): strin
155166
}
156167

157168
function makeUrl(params: {
169+
authMethod: "none" | "hf-token" | "credentials-include" | "provider-key";
170+
chatCompletion: boolean;
158171
model: string;
159172
provider: InferenceProvider;
160173
taskHint: InferenceTask | undefined;
161-
chatCompletion: boolean;
162174
forceTask?: string | InferenceTask;
163175
}): string {
176+
if (params.authMethod === "none" && params.provider !== "hf-inference") {
177+
throw new Error("Authentication is required when requesting a third-party provider. Please provide accessToken");
178+
}
179+
180+
const shouldProxy = params.provider !== "hf-inference" && params.authMethod !== "provider-key";
164181
switch (params.provider) {
165-
case "fal-ai":
166-
return `${FAL_AI_API_BASE_URL}/${params.model}`;
182+
case "fal-ai": {
183+
const baseUrl = shouldProxy
184+
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
185+
: FAL_AI_API_BASE_URL;
186+
return `${baseUrl}/${params.model}`;
187+
}
167188
case "replicate": {
189+
const baseUrl = shouldProxy
190+
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
191+
: REPLICATE_API_BASE_URL;
168192
if (params.model.includes(":")) {
169193
/// Versioned model
170-
return `${REPLICATE_API_BASE_URL}/v1/predictions`;
194+
return `${baseUrl}/v1/predictions`;
171195
}
172196
/// Evergreen / Canonical model
173-
return `${REPLICATE_API_BASE_URL}/v1/models/${params.model}/predictions`;
197+
return `${baseUrl}/v1/models/${params.model}/predictions`;
174198
}
175-
case "sambanova":
199+
case "sambanova": {
200+
const baseUrl = shouldProxy
201+
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
202+
: SAMBANOVA_API_BASE_URL;
176203
/// Sambanova API matches OpenAI-like APIs: model is defined in the request body
177204
if (params.taskHint === "text-generation" && params.chatCompletion) {
178-
return `${SAMBANOVA_API_BASE_URL}/v1/chat/completions`;
205+
return `${baseUrl}/v1/chat/completions`;
179206
}
180-
return SAMBANOVA_API_BASE_URL;
207+
return baseUrl;
208+
}
181209
case "together": {
210+
const baseUrl = shouldProxy
211+
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
212+
: TOGETHER_API_BASE_URL;
182213
/// Together API matches OpenAI-like APIs: model is defined in the request body
183214
if (params.taskHint === "text-to-image") {
184-
return `${TOGETHER_API_BASE_URL}/v1/images/generations`;
215+
return `${baseUrl}/v1/images/generations`;
185216
}
186217
if (params.taskHint === "text-generation") {
187218
if (params.chatCompletion) {
188-
return `${TOGETHER_API_BASE_URL}/v1/chat/completions`;
219+
return `${baseUrl}/v1/chat/completions`;
189220
}
190-
return `${TOGETHER_API_BASE_URL}/v1/completions`;
221+
return `${baseUrl}/v1/completions`;
191222
}
192-
return TOGETHER_API_BASE_URL;
223+
return baseUrl;
193224
}
194225
default: {
195226
const url = params.forceTask
196-
? `${HF_INFERENCE_API_BASE_URL}/pipeline/${params.forceTask}/${params.model}`
197-
: `${HF_INFERENCE_API_BASE_URL}/models/${params.model}`;
227+
? `${HF_INFERENCE_API_URL}/pipeline/${params.forceTask}/${params.model}`
228+
: `${HF_INFERENCE_API_URL}/models/${params.model}`;
198229
if (params.taskHint === "text-generation" && params.chatCompletion) {
199230
return url + `/v1/chat/completions`;
200231
}

packages/inference/test/vcr.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { omit } from "../src/utils/omit";
2-
import { HF_HUB_URL } from "../src/lib/getDefaultTask";
2+
import { HF_HUB_URL } from "../src/config";
33
import { isBackend } from "../src/utils/isBackend";
44
import { isFrontend } from "../src/utils/isFrontend";
55

0 commit comments

Comments (0)