@@ -24,23 +24,10 @@ import type {
2424import { InferenceOutputError } from "../lib/InferenceOutputError" ;
2525import type { BodyParams } from "../types" ;
2626import { omit } from "../utils/omit" ;
27+ import type { TextGenerationInput } from "@huggingface/tasks" ;
2728
2829const OVHCLOUD_API_BASE_URL = "https://oai.endpoints.kepler.ai.cloud.ovh.net" ;
2930
30- function prepareBaseOvhCloudPayload ( params : BodyParams ) : Record < string , unknown > {
31- return {
32- model : params . model ,
33- ...omit ( params . args , [ "inputs" , "parameters" ] ) ,
34- ...( params . args . parameters
35- ? {
36- max_tokens : ( params . args . parameters as Record < string , unknown > ) . max_new_tokens ,
37- ...omit ( params . args . parameters as Record < string , unknown > , "max_new_tokens" ) ,
38- }
39- : undefined ) ,
40- prompt : params . args . inputs ,
41- } ;
42- }
43-
4431interface OvhCloudTextCompletionOutput extends Omit < ChatCompletionOutput , "choices" > {
4532 choices : Array < {
4633 text : string ;
@@ -54,21 +41,25 @@ export class OvhCloudConversationalTask extends BaseConversationalTask {
5441 constructor ( ) {
5542 super ( "ovhcloud" , OVHCLOUD_API_BASE_URL ) ;
5643 }
57-
58- override preparePayload ( params : BodyParams ) : Record < string , unknown > {
59- return prepareBaseOvhCloudPayload ( params ) ;
60- }
6144}
6245
6346export class OvhCloudTextGenerationTask extends BaseTextGenerationTask {
6447 constructor ( ) {
6548 super ( "ovhcloud" , OVHCLOUD_API_BASE_URL ) ;
6649 }
6750
68- override preparePayload ( params : BodyParams ) : Record < string , unknown > {
69- const payload = prepareBaseOvhCloudPayload ( params ) ;
70- payload . prompt = params . args . inputs ;
71- return payload ;
51+ override preparePayload ( params : BodyParams < TextGenerationInput > ) : Record < string , unknown > {
52+ return {
53+ model : params . model ,
54+ ...omit ( params . args , [ "inputs" , "parameters" ] ) ,
55+ ...( params . args . parameters
56+ ? {
57+ max_tokens : ( params . args . parameters as Record < string , unknown > ) . max_new_tokens ,
58+ ...omit ( params . args . parameters as Record < string , unknown > , "max_new_tokens" ) ,
59+ }
60+ : undefined ) ,
61+ prompt : params . args . inputs ,
62+ } ;
7263 }
7364
7465 override async getResponse ( response : OvhCloudTextCompletionOutput ) : Promise < TextGenerationOutput > {
0 commit comments