|
1 | | -import type { WidgetType } from "@huggingface/tasks"; |
2 | 1 | import { HF_HUB_URL } from "../config"; |
3 | | -import { FAL_AI_API_BASE_URL, FAL_AI_SUPPORTED_MODEL_IDS } from "../providers/fal-ai"; |
4 | | -import { REPLICATE_API_BASE_URL, REPLICATE_SUPPORTED_MODEL_IDS } from "../providers/replicate"; |
5 | | -import { SAMBANOVA_API_BASE_URL, SAMBANOVA_SUPPORTED_MODEL_IDS } from "../providers/sambanova"; |
6 | | -import { TOGETHER_API_BASE_URL, TOGETHER_SUPPORTED_MODEL_IDS } from "../providers/together"; |
| 2 | +import { FAL_AI_API_BASE_URL } from "../providers/fal-ai"; |
| 3 | +import { REPLICATE_API_BASE_URL } from "../providers/replicate"; |
| 4 | +import { SAMBANOVA_API_BASE_URL } from "../providers/sambanova"; |
| 5 | +import { TOGETHER_API_BASE_URL } from "../providers/together"; |
7 | 6 | import type { InferenceProvider } from "../types"; |
8 | 7 | import type { InferenceTask, Options, RequestArgs } from "../types"; |
9 | 8 | import { isUrl } from "./isUrl"; |
10 | 9 | import { version as packageVersion, name as packageName } from "../../package.json"; |
| 10 | +import { getProviderModelId } from "./getProviderModelId"; |
11 | 11 |
|
12 | 12 | const HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_HUB_URL}/api/inference-proxy/{{PROVIDER}}`; |
13 | 13 |
|
@@ -49,18 +49,16 @@ export async function makeRequestOptions( |
49 | 49 | if (maybeModel && isUrl(maybeModel)) { |
50 | 50 | throw new Error(`Model URLs are no longer supported. Use endpointUrl instead.`); |
51 | 51 | } |
52 | | - |
53 | | - let model: string; |
54 | | - if (!maybeModel) { |
55 | | - if (taskHint) { |
56 | | - model = mapModel({ model: await loadDefaultModel(taskHint), provider, taskHint, chatCompletion }); |
57 | | - } else { |
58 | | - throw new Error("No model provided, and no default model found for this task"); |
59 | | - /// TODO : change error message ^ |
60 | | - } |
61 | | - } else { |
62 | | - model = mapModel({ model: maybeModel, provider, taskHint, chatCompletion }); |
| 52 | + if (!maybeModel && !taskHint) { |
| 53 | + throw new Error("No model provided, and no task has been specified."); |
63 | 54 | } |
| 55 | + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion |
| 56 | + const hfModel = maybeModel ?? (await loadDefaultModel(taskHint!)); |
| 57 | + const model = await getProviderModelId({ model: hfModel, provider }, args, { |
| 58 | + taskHint, |
| 59 | + chatCompletion, |
| 60 | + fetch: options?.fetch, |
| 61 | + }); |
64 | 62 |
|
65 | 63 | /// If accessToken is passed, it should take precedence over includeCredentials |
66 | 64 | const authMethod = accessToken |
@@ -153,39 +151,6 @@ export async function makeRequestOptions( |
153 | 151 | return { url, info }; |
154 | 152 | } |
155 | 153 |
|
156 | | -function mapModel(params: { |
157 | | - model: string; |
158 | | - provider: InferenceProvider; |
159 | | - taskHint: InferenceTask | undefined; |
160 | | - chatCompletion: boolean | undefined; |
161 | | -}): string { |
162 | | - if (params.provider === "hf-inference") { |
163 | | - return params.model; |
164 | | - } |
165 | | - if (!params.taskHint) { |
166 | | - throw new Error("taskHint must be specified when using a third-party provider"); |
167 | | - } |
168 | | - const task: WidgetType = |
169 | | - params.taskHint === "text-generation" && params.chatCompletion ? "conversational" : params.taskHint; |
170 | | - const model = (() => { |
171 | | - switch (params.provider) { |
172 | | - case "fal-ai": |
173 | | - return FAL_AI_SUPPORTED_MODEL_IDS[task]?.[params.model]; |
174 | | - case "replicate": |
175 | | - return REPLICATE_SUPPORTED_MODEL_IDS[task]?.[params.model]; |
176 | | - case "sambanova": |
177 | | - return SAMBANOVA_SUPPORTED_MODEL_IDS[task]?.[params.model]; |
178 | | - case "together": |
179 | | - return TOGETHER_SUPPORTED_MODEL_IDS[task]?.[params.model]; |
180 | | - } |
181 | | - })(); |
182 | | - |
183 | | - if (!model) { |
184 | | - throw new Error(`Model ${params.model} is not supported for task ${task} and provider ${params.provider}`); |
185 | | - } |
186 | | - return model; |
187 | | -} |
188 | | - |
189 | 154 | function makeUrl(params: { |
190 | 155 | authMethod: "none" | "hf-token" | "credentials-include" | "provider-key"; |
191 | 156 | chatCompletion: boolean; |
|
0 commit comments