|
| 1 | +import { HF_HUB_URL, HF_INFERENCE_API_URL } from "../config"; |
1 | 2 | import { FAL_AI_API_BASE_URL, FAL_AI_MODEL_IDS } from "../providers/fal-ai"; |
2 | 3 | import { REPLICATE_API_BASE_URL, REPLICATE_MODEL_IDS } from "../providers/replicate"; |
3 | 4 | import { SAMBANOVA_API_BASE_URL, SAMBANOVA_MODEL_IDS } from "../providers/sambanova"; |
4 | 5 | import { TOGETHER_API_BASE_URL, TOGETHER_MODEL_IDS } from "../providers/together"; |
5 | 6 | import type { InferenceProvider } from "../types"; |
6 | 7 | import type { InferenceTask, Options, RequestArgs } from "../types"; |
7 | | -import { HF_HUB_URL } from "./getDefaultTask"; |
8 | 8 | import { isUrl } from "./isUrl"; |
9 | 9 |
|
10 | | -const HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co"; |
| 10 | +const HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_HUB_URL}/api/inference-proxy/{{PROVIDER}}`; |
11 | 11 |
|
12 | 12 | /** |
13 | 13 | * Lazy-loaded from huggingface.co/api/tasks when needed |
@@ -59,21 +59,32 @@ export async function makeRequestOptions( |
59 | 59 | model = mapModel({ model: maybeModel, provider }); |
60 | 60 | } |
61 | 61 |
|
| 62 | + /// If accessToken is passed, it should take precedence over includeCredentials |
| 63 | + const authMethod = accessToken |
| 64 | + ? accessToken.startsWith("hf_") |
| 65 | + ? "hf-token" |
| 66 | + : "provider-key" |
| 67 | + : includeCredentials === "include" |
| 68 | + ? "credentials-include" |
| 69 | + : "none"; |
| 70 | + |
62 | 71 | const url = endpointUrl |
63 | 72 | ? chatCompletion |
64 | 73 | ? endpointUrl + `/v1/chat/completions` |
65 | 74 | : endpointUrl |
66 | 75 | : makeUrl({ |
| 76 | + authMethod, |
| 77 | + chatCompletion: chatCompletion ?? false, |
| 78 | + forceTask, |
67 | 79 | model, |
68 | 80 | provider: provider ?? "hf-inference", |
69 | 81 | taskHint, |
70 | | - chatCompletion: chatCompletion ?? false, |
71 | | - forceTask, |
72 | 82 | }); |
73 | 83 |
|
74 | 84 | const headers: Record<string, string> = {}; |
75 | 85 | if (accessToken) { |
76 | | - headers["Authorization"] = provider === "fal-ai" ? `Key ${accessToken}` : `Bearer ${accessToken}`; |
| 86 | + headers["Authorization"] = |
| 87 | + provider === "fal-ai" && authMethod === "provider-key" ? `Key ${accessToken}` : `Bearer ${accessToken}`; |
77 | 88 | } |
78 | 89 |
|
79 | 90 | const binary = "data" in args && !!args.data; |
@@ -155,46 +166,66 @@ function mapModel(params: { model: string; provider: InferenceProvider }): strin |
155 | 166 | } |
156 | 167 |
|
157 | 168 | function makeUrl(params: { |
| 169 | + authMethod: "none" | "hf-token" | "credentials-include" | "provider-key"; |
| 170 | + chatCompletion: boolean; |
158 | 171 | model: string; |
159 | 172 | provider: InferenceProvider; |
160 | 173 | taskHint: InferenceTask | undefined; |
161 | | - chatCompletion: boolean; |
162 | 174 | forceTask?: string | InferenceTask; |
163 | 175 | }): string { |
| 176 | + if (params.authMethod === "none" && params.provider !== "hf-inference") { |
| 177 | + throw new Error("Authentication is required when requesting a third-party provider. Please provide accessToken"); |
| 178 | + } |
| 179 | + |
| 180 | + const shouldProxy = params.provider !== "hf-inference" && params.authMethod !== "provider-key"; |
164 | 181 | switch (params.provider) { |
165 | | - case "fal-ai": |
166 | | - return `${FAL_AI_API_BASE_URL}/${params.model}`; |
| 182 | + case "fal-ai": { |
| 183 | + const baseUrl = shouldProxy |
| 184 | + ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) |
| 185 | + : FAL_AI_API_BASE_URL; |
| 186 | + return `${baseUrl}/${params.model}`; |
| 187 | + } |
167 | 188 | case "replicate": { |
| 189 | + const baseUrl = shouldProxy |
| 190 | + ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) |
| 191 | + : REPLICATE_API_BASE_URL; |
168 | 192 | if (params.model.includes(":")) { |
169 | 193 | /// Versioned model |
170 | | - return `${REPLICATE_API_BASE_URL}/v1/predictions`; |
| 194 | + return `${baseUrl}/v1/predictions`; |
171 | 195 | } |
172 | 196 | /// Evergreen / Canonical model |
173 | | - return `${REPLICATE_API_BASE_URL}/v1/models/${params.model}/predictions`; |
| 197 | + return `${baseUrl}/v1/models/${params.model}/predictions`; |
174 | 198 | } |
175 | | - case "sambanova": |
| 199 | + case "sambanova": { |
| 200 | + const baseUrl = shouldProxy |
| 201 | + ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) |
| 202 | + : SAMBANOVA_API_BASE_URL; |
176 | 203 | /// Sambanova API matches OpenAI-like APIs: model is defined in the request body |
177 | 204 | if (params.taskHint === "text-generation" && params.chatCompletion) { |
178 | | - return `${SAMBANOVA_API_BASE_URL}/v1/chat/completions`; |
| 205 | + return `${baseUrl}/v1/chat/completions`; |
179 | 206 | } |
180 | | - return SAMBANOVA_API_BASE_URL; |
| 207 | + return baseUrl; |
| 208 | + } |
181 | 209 | case "together": { |
| 210 | + const baseUrl = shouldProxy |
| 211 | + ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) |
| 212 | + : TOGETHER_API_BASE_URL; |
182 | 213 | /// Together API matches OpenAI-like APIs: model is defined in the request body |
183 | 214 | if (params.taskHint === "text-to-image") { |
184 | | - return `${TOGETHER_API_BASE_URL}/v1/images/generations`; |
| 215 | + return `${baseUrl}/v1/images/generations`; |
185 | 216 | } |
186 | 217 | if (params.taskHint === "text-generation") { |
187 | 218 | if (params.chatCompletion) { |
188 | | - return `${TOGETHER_API_BASE_URL}/v1/chat/completions`; |
| 219 | + return `${baseUrl}/v1/chat/completions`; |
189 | 220 | } |
190 | | - return `${TOGETHER_API_BASE_URL}/v1/completions`; |
| 221 | + return `${baseUrl}/v1/completions`; |
191 | 222 | } |
192 | | - return TOGETHER_API_BASE_URL; |
| 223 | + return baseUrl; |
193 | 224 | } |
194 | 225 | default: { |
195 | 226 | const url = params.forceTask |
196 | | - ? `${HF_INFERENCE_API_BASE_URL}/pipeline/${params.forceTask}/${params.model}` |
197 | | - : `${HF_INFERENCE_API_BASE_URL}/models/${params.model}`; |
| 227 | + ? `${HF_INFERENCE_API_URL}/pipeline/${params.forceTask}/${params.model}` |
| 228 | + : `${HF_INFERENCE_API_URL}/models/${params.model}`; |
198 | 229 | if (params.taskHint === "text-generation" && params.chatCompletion) { |
199 | 230 | return url + `/v1/chat/completions`; |
200 | 231 | } |
|
0 commit comments