@@ -305,16 +305,44 @@ const prepareConversationalInput = (
305305 temperature ?: GenerationParameters [ "temperature" ] ;
306306 max_tokens ?: GenerationParameters [ "max_new_tokens" ] ;
307307 top_p ?: GenerationParameters [ "top_p" ] ;
308+ response_format ?: Record < string , unknown > ;
308309 }
309310) : object => {
310311 return {
311312 messages : opts ?. messages ?? getModelInputSnippet ( model ) ,
312313 ...( opts ?. temperature ? { temperature : opts ?. temperature } : undefined ) ,
313314 ...( opts ?. max_tokens ? { max_tokens : opts ?. max_tokens } : undefined ) ,
314315 ...( opts ?. top_p ? { top_p : opts ?. top_p } : undefined ) ,
316+ ...( opts ?. response_format ? { response_format : opts ?. response_format } : undefined ) ,
315317 } ;
316318} ;
317319
320+ const prepareTextGenerationInput = (
321+ model : ModelDataMinimal ,
322+ opts ?: {
323+ streaming ?: boolean ;
324+ temperature ?: GenerationParameters [ "temperature" ] ;
325+ max_tokens ?: GenerationParameters [ "max_new_tokens" ] ;
326+ top_p ?: GenerationParameters [ "top_p" ] ;
327+ response_format ?: Record < string , unknown > ;
328+ }
329+ ) : object => {
330+ const base = { inputs : getModelInputSnippet ( model ) } ;
331+ const parameters : Record < string , unknown > = { } ;
332+
333+ if ( opts ?. temperature !== undefined ) parameters . temperature = opts . temperature ;
334+ if ( opts ?. max_tokens !== undefined ) parameters . max_new_tokens = opts . max_tokens ;
335+ if ( opts ?. top_p !== undefined ) parameters . top_p = opts . top_p ;
336+ if ( opts ?. response_format !== undefined ) parameters . response_format = opts . response_format ;
337+
338+ // Only add parameters if there are any
339+ if ( Object . keys ( parameters ) . length > 0 ) {
340+ return { ...base , parameters } ;
341+ }
342+
343+ return base ;
344+ } ;
345+
318346const prepareQuestionAnsweringInput = ( model : ModelDataMinimal ) : object => {
319347 const data = JSON . parse ( getModelInputSnippet ( model ) as string ) ;
320348 return { question : data . question , context : data . context } ;
@@ -355,7 +383,7 @@ const snippets: Partial<
355383 "tabular-regression" : snippetGenerator ( "tabular" ) ,
356384 "table-question-answering" : snippetGenerator ( "tableQuestionAnswering" , prepareTableQuestionAnsweringInput ) ,
357385 "text-classification" : snippetGenerator ( "basic" ) ,
358- "text-generation" : snippetGenerator ( "basic" ) ,
386+ "text-generation" : snippetGenerator ( "basic" , prepareTextGenerationInput ) ,
359387 "text-to-audio" : snippetGenerator ( "textToAudio" ) ,
360388 "text-to-image" : snippetGenerator ( "textToImage" ) ,
361389 "text-to-speech" : snippetGenerator ( "textToSpeech" ) ,
@@ -393,7 +421,7 @@ function formatBody(obj: object, format: "curl" | "json" | "python" | "ts"): str
393421 return indentString (
394422 Object . entries ( obj )
395423 . map ( ( [ key , value ] ) => {
396- const formattedValue = JSON . stringify ( value , null , 4 ) . replace ( / " / g , '"' ) ;
424+ const formattedValue = formatPythonValue ( value , 1 ) ;
397425 return `${ key } =${ formattedValue } ,` ;
398426 } )
399427 . join ( "\n" )
@@ -408,6 +436,46 @@ function formatBody(obj: object, format: "curl" | "json" | "python" | "ts"): str
408436 }
409437}
410438
439+ function formatPythonValue ( obj : unknown , depth ?: number ) : string {
440+ depth = depth ?? 0 ;
441+
442+ /// Case boolean - convert to Python format
443+ if ( typeof obj === "boolean" ) {
444+ return obj ? "True" : "False" ;
445+ }
446+
447+ /// Case null - convert to Python format
448+ if ( obj === null ) {
449+ return "None" ;
450+ }
451+
452+ /// Case number or string
453+ if ( typeof obj !== "object" ) {
454+ return JSON . stringify ( obj ) ;
455+ }
456+
457+ /// Case array
458+ if ( Array . isArray ( obj ) ) {
459+ const items = obj
460+ . map ( ( item ) => {
461+ const formatted = formatPythonValue ( item , depth + 1 ) ;
462+ return `${ " " . repeat ( 4 * ( depth + 1 ) ) } ${ formatted } ,` ;
463+ } )
464+ . join ( "\n" ) ;
465+ return `[\n${ items } \n${ " " . repeat ( 4 * depth ) } ]` ;
466+ }
467+
468+ /// Case mapping (object)
469+ const entries = Object . entries ( obj ) ;
470+ const lines = entries
471+ . map ( ( [ key , value ] ) => {
472+ const formattedValue = formatPythonValue ( value , depth + 1 ) ;
473+ return `${ " " . repeat ( 4 * ( depth + 1 ) ) } "${ key } ": ${ formattedValue } ,` ;
474+ } )
475+ . join ( "\n" ) ;
476+ return `{\n${ lines } \n${ " " . repeat ( 4 * depth ) } }` ;
477+ }
478+
411479function formatTsObject ( obj : unknown , depth ?: number ) : string {
412480 depth = depth ?? 0 ;
413481
0 commit comments