Skip to content

Commit 31c666f

Browse files
author
Fabien Ric
committed
fix chatcompletion payload
1 parent 1212dc8 commit 31c666f

File tree

2 files changed

+21
-34
lines changed

2 files changed

+21
-34
lines changed

packages/inference/src/providers/ovhcloud.ts

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -24,23 +24,10 @@ import type {
2424
import { InferenceOutputError } from "../lib/InferenceOutputError";
2525
import type { BodyParams } from "../types";
2626
import { omit } from "../utils/omit";
27+
import type { TextGenerationInput } from "@huggingface/tasks";
2728

2829
const OVHCLOUD_API_BASE_URL = "https://oai.endpoints.kepler.ai.cloud.ovh.net";
2930

30-
function prepareBaseOvhCloudPayload(params: BodyParams): Record<string, unknown> {
31-
return {
32-
model: params.model,
33-
...omit(params.args, ["inputs", "parameters"]),
34-
...(params.args.parameters
35-
? {
36-
max_tokens: (params.args.parameters as Record<string, unknown>).max_new_tokens,
37-
...omit(params.args.parameters as Record<string, unknown>, "max_new_tokens"),
38-
}
39-
: undefined),
40-
prompt: params.args.inputs,
41-
};
42-
}
43-
4431
interface OvhCloudTextCompletionOutput extends Omit<ChatCompletionOutput, "choices"> {
4532
choices: Array<{
4633
text: string;
@@ -54,21 +41,25 @@ export class OvhCloudConversationalTask extends BaseConversationalTask {
5441
constructor() {
5542
super("ovhcloud", OVHCLOUD_API_BASE_URL);
5643
}
57-
58-
override preparePayload(params: BodyParams): Record<string, unknown> {
59-
return prepareBaseOvhCloudPayload(params);
60-
}
6144
}
6245

6346
export class OvhCloudTextGenerationTask extends BaseTextGenerationTask {
6447
constructor() {
6548
super("ovhcloud", OVHCLOUD_API_BASE_URL);
6649
}
6750

68-
override preparePayload(params: BodyParams): Record<string, unknown> {
69-
const payload = prepareBaseOvhCloudPayload(params);
70-
payload.prompt = params.args.inputs;
71-
return payload;
51+
override preparePayload(params: BodyParams<TextGenerationInput>): Record<string, unknown> {
52+
return {
53+
model: params.model,
54+
...omit(params.args, ["inputs", "parameters"]),
55+
...(params.args.parameters
56+
? {
57+
max_tokens: (params.args.parameters as Record<string, unknown>).max_new_tokens,
58+
...omit(params.args.parameters as Record<string, unknown>, "max_new_tokens"),
59+
}
60+
: undefined),
61+
prompt: params.args.inputs,
62+
};
7263
}
7364

7465
override async getResponse(response: OvhCloudTextCompletionOutput): Promise<TextGenerationOutput> {

packages/inference/test/InferenceClient.spec.ts

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1710,12 +1710,10 @@ describe.concurrent("InferenceClient", () => {
17101710
model: "meta-llama/llama-3.1-8b-instruct",
17111711
provider: "ovhcloud",
17121712
messages: [{ role: "user", content: "A, B, C, " }],
1713-
parameters: {
1714-
seed: 42,
1715-
temperature: 0,
1716-
top_p: 0.01,
1717-
max_new_tokens: 1,
1718-
},
1713+
seed: 42,
1714+
temperature: 0,
1715+
top_p: 0.01,
1716+
max_tokens: 1,
17191717
});
17201718
expect(res.choices && res.choices.length > 0);
17211719
const completion = res.choices[0].message?.content;
@@ -1728,12 +1726,10 @@ describe.concurrent("InferenceClient", () => {
17281726
provider: "ovhcloud",
17291727
messages: [{ role: "user", content: "A, B, C, " }],
17301728
stream: true,
1731-
parameters: {
1732-
seed: 42,
1733-
temperature: 0,
1734-
top_p: 0.01,
1735-
max_new_tokens: 1,
1736-
},
1729+
seed: 42,
1730+
temperature: 0,
1731+
top_p: 0.01,
1732+
max_tokens: 1,
17371733
}) as AsyncGenerator<ChatCompletionStreamOutput>;
17381734

17391735
let fullResponse = "";

0 commit comments

Comments (0)