Skip to content

Commit befb723

Browse files
authored
[inference] Proposal: rename taskHint param to task (#1204)
so we merge `task` and `taskHint`
1 parent 23ffa83 commit befb723

34 files changed

+60
-64
lines changed

packages/inference/src/lib/getProviderModelId.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,19 @@ export async function getProviderModelId(
1515
},
1616
args: RequestArgs,
1717
options: {
18-
taskHint?: InferenceTask;
18+
task?: InferenceTask;
1919
chatCompletion?: boolean;
2020
fetch?: Options["fetch"];
2121
} = {}
2222
): Promise<string> {
2323
if (params.provider === "hf-inference") {
2424
return params.model;
2525
}
26-
if (!options.taskHint) {
27-
throw new Error("taskHint must be specified when using a third-party provider");
26+
if (!options.task) {
27+
throw new Error("task must be specified when using a third-party provider");
2828
}
2929
const task: WidgetType =
30-
options.taskHint === "text-generation" && options.chatCompletion ? "conversational" : options.taskHint;
30+
options.task === "text-generation" && options.chatCompletion ? "conversational" : options.task;
3131

3232
// A dict called HARDCODED_MODEL_ID_MAPPING takes precedence in all cases (useful for dev purposes)
3333
if (HARDCODED_MODEL_ID_MAPPING[params.provider]?.[params.model]) {

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -31,30 +31,30 @@ export async function makeRequestOptions(
3131
stream?: boolean;
3232
},
3333
options?: Options & {
34-
/** To load default model if needed */
35-
taskHint?: InferenceTask;
34+
/** In most cases (unless we pass an endpointUrl) we know the task */
35+
task?: InferenceTask;
3636
chatCompletion?: boolean;
3737
}
3838
): Promise<{ url: string; info: RequestInit }> {
3939
const { accessToken, endpointUrl, provider: maybeProvider, model: maybeModel, ...remainingArgs } = args;
4040
let otherArgs = remainingArgs;
4141
const provider = maybeProvider ?? "hf-inference";
4242

43-
const { includeCredentials, taskHint, chatCompletion } = options ?? {};
43+
const { includeCredentials, task, chatCompletion } = options ?? {};
4444

4545
if (endpointUrl && provider !== "hf-inference") {
4646
throw new Error(`Cannot use endpointUrl with a third-party provider.`);
4747
}
4848
if (maybeModel && isUrl(maybeModel)) {
4949
throw new Error(`Model URLs are no longer supported. Use endpointUrl instead.`);
5050
}
51-
if (!maybeModel && !taskHint) {
51+
if (!maybeModel && !task) {
5252
throw new Error("No model provided, and no task has been specified.");
5353
}
5454
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
55-
const hfModel = maybeModel ?? (await loadDefaultModel(taskHint!));
55+
const hfModel = maybeModel ?? (await loadDefaultModel(task!));
5656
const model = await getProviderModelId({ model: hfModel, provider }, args, {
57-
taskHint,
57+
task,
5858
chatCompletion,
5959
fetch: options?.fetch,
6060
});
@@ -77,7 +77,7 @@ export async function makeRequestOptions(
7777
chatCompletion: chatCompletion ?? false,
7878
model,
7979
provider: provider ?? "hf-inference",
80-
taskHint,
80+
task,
8181
});
8282

8383
const headers: Record<string, string> = {};
@@ -133,7 +133,7 @@ export async function makeRequestOptions(
133133
? args.data
134134
: JSON.stringify({
135135
...otherArgs,
136-
...(taskHint === "text-to-image" && provider === "hyperbolic"
136+
...(task === "text-to-image" && provider === "hyperbolic"
137137
? { model_name: model }
138138
: chatCompletion || provider === "together" || provider === "nebius" || provider === "hyperbolic"
139139
? { model }
@@ -151,7 +151,7 @@ function makeUrl(params: {
151151
chatCompletion: boolean;
152152
model: string;
153153
provider: InferenceProvider;
154-
taskHint: InferenceTask | undefined;
154+
task: InferenceTask | undefined;
155155
}): string {
156156
if (params.authMethod === "none" && params.provider !== "hf-inference") {
157157
throw new Error("Authentication is required when requesting a third-party provider. Please provide accessToken");
@@ -176,10 +176,10 @@ function makeUrl(params: {
176176
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
177177
: NEBIUS_API_BASE_URL;
178178

179-
if (params.taskHint === "text-to-image") {
179+
if (params.task === "text-to-image") {
180180
return `${baseUrl}/v1/images/generations`;
181181
}
182-
if (params.taskHint === "text-generation") {
182+
if (params.task === "text-generation") {
183183
if (params.chatCompletion) {
184184
return `${baseUrl}/v1/chat/completions`;
185185
}
@@ -203,7 +203,7 @@ function makeUrl(params: {
203203
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
204204
: SAMBANOVA_API_BASE_URL;
205205
/// Sambanova API matches OpenAI-like APIs: model is defined in the request body
206-
if (params.taskHint === "text-generation" && params.chatCompletion) {
206+
if (params.task === "text-generation" && params.chatCompletion) {
207207
return `${baseUrl}/v1/chat/completions`;
208208
}
209209
return baseUrl;
@@ -213,10 +213,10 @@ function makeUrl(params: {
213213
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
214214
: TOGETHER_API_BASE_URL;
215215
/// Together API matches OpenAI-like APIs: model is defined in the request body
216-
if (params.taskHint === "text-to-image") {
216+
if (params.task === "text-to-image") {
217217
return `${baseUrl}/v1/images/generations`;
218218
}
219-
if (params.taskHint === "text-generation") {
219+
if (params.task === "text-generation") {
220220
if (params.chatCompletion) {
221221
return `${baseUrl}/v1/chat/completions`;
222222
}
@@ -229,7 +229,7 @@ function makeUrl(params: {
229229
const baseUrl = shouldProxy
230230
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
231231
: FIREWORKS_AI_API_BASE_URL;
232-
if (params.taskHint === "text-generation" && params.chatCompletion) {
232+
if (params.task === "text-generation" && params.chatCompletion) {
233233
return `${baseUrl}/v1/chat/completions`;
234234
}
235235
return baseUrl;
@@ -239,7 +239,7 @@ function makeUrl(params: {
239239
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
240240
: HYPERBOLIC_API_BASE_URL;
241241

242-
if (params.taskHint === "text-to-image") {
242+
if (params.task === "text-to-image") {
243243
return `${baseUrl}/v1/images/generations`;
244244
}
245245
return `${baseUrl}/v1/chat/completions`;
@@ -248,7 +248,7 @@ function makeUrl(params: {
248248
const baseUrl = shouldProxy
249249
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
250250
: NOVITA_API_BASE_URL;
251-
if (params.taskHint === "text-generation") {
251+
if (params.task === "text-generation") {
252252
if (params.chatCompletion) {
253253
return `${baseUrl}/chat/completions`;
254254
}
@@ -258,11 +258,11 @@ function makeUrl(params: {
258258
}
259259
default: {
260260
const baseUrl = HF_HUB_INFERENCE_PROXY_TEMPLATE.replaceAll("{{PROVIDER}}", "hf-inference");
261-
if (params.taskHint && ["feature-extraction", "sentence-similarity"].includes(params.taskHint)) {
261+
if (params.task && ["feature-extraction", "sentence-similarity"].includes(params.task)) {
262262
/// when deployed on hf-inference, those two tasks are automatically compatible with one another.
263-
return `${baseUrl}/pipeline/${params.taskHint}/${params.model}`;
263+
return `${baseUrl}/pipeline/${params.task}/${params.model}`;
264264
}
265-
if (params.taskHint === "text-generation" && params.chatCompletion) {
265+
if (params.task === "text-generation" && params.chatCompletion) {
266266
return `${baseUrl}/models/${params.model}/v1/chat/completions`;
267267
}
268268
return `${baseUrl}/models/${params.model}`;

packages/inference/src/tasks/audio/audioClassification.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ export async function audioClassification(
1818
const payload = preparePayload(args);
1919
const res = await request<AudioClassificationOutput>(payload, {
2020
...options,
21-
taskHint: "audio-classification",
21+
task: "audio-classification",
2222
});
2323
const isValidOutput =
2424
Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");

packages/inference/src/tasks/audio/audioToAudio.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ export async function audioToAudio(args: AudioToAudioArgs, options?: Options): P
3939
const payload = preparePayload(args);
4040
const res = await request<AudioToAudioOutput>(payload, {
4141
...options,
42-
taskHint: "audio-to-audio",
42+
task: "audio-to-audio",
4343
});
4444

4545
return validateOutput(res);

packages/inference/src/tasks/audio/automaticSpeechRecognition.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ export async function automaticSpeechRecognition(
1919
const payload = await buildPayload(args);
2020
const res = await request<AutomaticSpeechRecognitionOutput>(payload, {
2121
...options,
22-
taskHint: "automatic-speech-recognition",
22+
task: "automatic-speech-recognition",
2323
});
2424
const isValidOutput = typeof res?.text === "string";
2525
if (!isValidOutput) {

packages/inference/src/tasks/audio/textToSpeech.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ export async function textToSpeech(args: TextToSpeechArgs, options?: Options): P
2424
: args;
2525
const res = await request<Blob | OutputUrlTextToSpeechGeneration>(payload, {
2626
...options,
27-
taskHint: "text-to-speech",
27+
task: "text-to-speech",
2828
});
2929
if (res instanceof Blob) {
3030
return res;

packages/inference/src/tasks/custom/request.ts

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,8 @@ import { makeRequestOptions } from "../../lib/makeRequestOptions";
77
export async function request<T>(
88
args: RequestArgs,
99
options?: Options & {
10-
/** When a model can be used for multiple tasks, and we want to run a non-default task */
11-
task?: string | InferenceTask;
12-
/** To load default model if needed */
13-
taskHint?: InferenceTask;
10+
/** In most cases (unless we pass an endpointUrl) we know the task */
11+
task?: InferenceTask;
1412
/** Is chat completion compatible */
1513
chatCompletion?: boolean;
1614
}

packages/inference/src/tasks/custom/streamingRequest.ts

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,8 @@ import { getLines, getMessages } from "../../vendor/fetch-event-source/parse";
99
export async function* streamingRequest<T>(
1010
args: RequestArgs,
1111
options?: Options & {
12-
/** When a model can be used for multiple tasks, and we want to run a non-default task */
13-
task?: string | InferenceTask;
14-
/** To load default model if needed */
15-
taskHint?: InferenceTask;
12+
/** In most cases (unless we pass an endpointUrl) we know the task */
13+
task?: InferenceTask;
1614
/** Is chat completion compatible */
1715
chatCompletion?: boolean;
1816
}

packages/inference/src/tasks/cv/imageClassification.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ export async function imageClassification(
1717
const payload = preparePayload(args);
1818
const res = await request<ImageClassificationOutput>(payload, {
1919
...options,
20-
taskHint: "image-classification",
20+
task: "image-classification",
2121
});
2222
const isValidOutput =
2323
Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");

packages/inference/src/tasks/cv/imageSegmentation.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ export async function imageSegmentation(
1717
const payload = preparePayload(args);
1818
const res = await request<ImageSegmentationOutput>(payload, {
1919
...options,
20-
taskHint: "image-segmentation",
20+
task: "image-segmentation",
2121
});
2222
const isValidOutput =
2323
Array.isArray(res) &&

0 commit comments

Comments
 (0)