Skip to content

Commit 351580f

Browse files
✨ Support text generation task for Nebius AI Studio provider (#1561)
This PR adds text-generation task support for the `nebius` provider (Nebius AI Studio). This is needed so that completions work for the model [mistralai/Devstral-Small-2505](https://huggingface.co/mistralai/Devstral-Small-2505). Please let me know if these changes are inconsistent with how you would prefer to support such models, and I will make the necessary changes. Thanks! Co-authored-by: célina <[email protected]>
1 parent cf8e494 commit 351580f

File tree

2 files changed

+50
-1
lines changed

2 files changed

+50
-1
lines changed

packages/inference/src/providers/nebius.ts

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
*
1515
* Thanks!
1616
*/
17-
import type { FeatureExtractionOutput } from "@huggingface/tasks";
17+
import type { FeatureExtractionOutput, TextGenerationOutput } from "@huggingface/tasks";
1818
import type { BodyParams } from "../types.js";
1919
import { omit } from "../utils/omit.js";
2020
import {
@@ -40,6 +40,12 @@ interface NebiusEmbeddingsResponse {
4040
}>;
4141
}
4242

43+
interface NebiusTextGenerationOutput extends Omit<TextGenerationOutput, "choices"> {
44+
choices: Array<{
45+
text: string;
46+
}>;
47+
}
48+
4349
export class NebiusConversationalTask extends BaseConversationalTask {
4450
constructor() {
4551
super("nebius", NEBIUS_API_BASE_URL);
@@ -50,6 +56,29 @@ export class NebiusTextGenerationTask extends BaseTextGenerationTask {
5056
constructor() {
5157
super("nebius", NEBIUS_API_BASE_URL);
5258
}
59+
60+
override preparePayload(params: BodyParams): Record<string, unknown> {
61+
return {
62+
...params.args,
63+
model: params.model,
64+
prompt: params.args.inputs,
65+
};
66+
}
67+
68+
override async getResponse(response: NebiusTextGenerationOutput): Promise<TextGenerationOutput> {
69+
if (
70+
typeof response === "object" &&
71+
"choices" in response &&
72+
Array.isArray(response?.choices) &&
73+
response.choices.length > 0 &&
74+
typeof response.choices[0]?.text === "string"
75+
) {
76+
return {
77+
generated_text: response.choices[0].text,
78+
};
79+
}
80+
throw new InferenceClientProviderOutputError("Received malformed response from Nebius text generation API");
81+
}
5382
}
5483

5584
export class NebiusTextToImageTask extends TaskProviderHelper implements TextToImageTaskHelper {

packages/inference/test/InferenceClient.spec.ts

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1424,6 +1424,13 @@ describe.skip("InferenceClient", () => {
14241424
status: "live",
14251425
task: "feature-extraction",
14261426
},
1427+
"mistralai/Devstral-Small-2505": {
1428+
provider: "nebius",
1429+
providerId: "mistralai/Devstral-Small-2505",
1430+
hfModelId: "mistralai/Devstral-Small-2505",
1431+
status: "live",
1432+
task: "text2text-generation",
1433+
},
14271434
};
14281435

14291436
it("chatCompletion", async () => {
@@ -1471,6 +1478,19 @@ describe.skip("InferenceClient", () => {
14711478
expect(res).toBeInstanceOf(Array);
14721479
expect(res[0]).toEqual(expect.arrayContaining([expect.any(Number)]));
14731480
});
1481+
1482+
it("text2textGeneration", async () => {
1483+
const res = await client.textGeneration({
1484+
model: "mistralai/Devstral-Small-2505",
1485+
provider: "nebius",
1486+
inputs: "Once upon a time,",
1487+
temperature: 0,
1488+
max_tokens: 19,
1489+
});
1490+
expect(res).toMatchObject({
1491+
generated_text: " in a land far, far away, there lived a king who was very fond of flowers.",
1492+
});
1493+
});
14741494
},
14751495
TIMEOUT
14761496
);

0 commit comments

Comments
 (0)