Skip to content

Commit dd504e3

Browse files
committed
add response_format support
1 parent 6fb1214 commit dd504e3

File tree

3 files changed

+270
-50
lines changed

3 files changed

+270
-50
lines changed

packages/inference/src/snippets/getInferenceSnippets.ts

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,16 +305,44 @@ const prepareConversationalInput = (
305305
temperature?: GenerationParameters["temperature"];
306306
max_tokens?: GenerationParameters["max_new_tokens"];
307307
top_p?: GenerationParameters["top_p"];
308+
response_format?: Record<string, unknown>;
308309
}
309310
): object => {
310311
return {
311312
messages: opts?.messages ?? getModelInputSnippet(model),
312313
...(opts?.temperature ? { temperature: opts?.temperature } : undefined),
313314
...(opts?.max_tokens ? { max_tokens: opts?.max_tokens } : undefined),
314315
...(opts?.top_p ? { top_p: opts?.top_p } : undefined),
316+
...(opts?.response_format ? { response_format: opts?.response_format } : undefined),
315317
};
316318
};
317319

320+
const prepareTextGenerationInput = (
321+
model: ModelDataMinimal,
322+
opts?: {
323+
streaming?: boolean;
324+
temperature?: GenerationParameters["temperature"];
325+
max_tokens?: GenerationParameters["max_new_tokens"];
326+
top_p?: GenerationParameters["top_p"];
327+
response_format?: Record<string, unknown>;
328+
}
329+
): object => {
330+
const base = { inputs: getModelInputSnippet(model) };
331+
const parameters: Record<string, unknown> = {};
332+
333+
if (opts?.temperature !== undefined) parameters.temperature = opts.temperature;
334+
if (opts?.max_tokens !== undefined) parameters.max_new_tokens = opts.max_tokens;
335+
if (opts?.top_p !== undefined) parameters.top_p = opts.top_p;
336+
if (opts?.response_format !== undefined) parameters.response_format = opts.response_format;
337+
338+
// Only add parameters if there are any
339+
if (Object.keys(parameters).length > 0) {
340+
return { ...base, parameters };
341+
}
342+
343+
return base;
344+
};
345+
318346
const prepareQuestionAnsweringInput = (model: ModelDataMinimal): object => {
319347
const data = JSON.parse(getModelInputSnippet(model) as string);
320348
return { question: data.question, context: data.context };
@@ -355,7 +383,7 @@ const snippets: Partial<
355383
"tabular-regression": snippetGenerator("tabular"),
356384
"table-question-answering": snippetGenerator("tableQuestionAnswering", prepareTableQuestionAnsweringInput),
357385
"text-classification": snippetGenerator("basic"),
358-
"text-generation": snippetGenerator("basic"),
386+
"text-generation": snippetGenerator("basic", prepareTextGenerationInput),
359387
"text-to-audio": snippetGenerator("textToAudio"),
360388
"text-to-image": snippetGenerator("textToImage"),
361389
"text-to-speech": snippetGenerator("textToSpeech"),
@@ -393,7 +421,7 @@ function formatBody(obj: object, format: "curl" | "json" | "python" | "ts"): str
393421
return indentString(
394422
Object.entries(obj)
395423
.map(([key, value]) => {
396-
const formattedValue = JSON.stringify(value, null, 4).replace(/"/g, '"');
424+
const formattedValue = formatPythonValue(value, 1);
397425
return `${key}=${formattedValue},`;
398426
})
399427
.join("\n")
@@ -408,6 +436,46 @@ function formatBody(obj: object, format: "curl" | "json" | "python" | "ts"): str
408436
}
409437
}
410438

439+
function formatPythonValue(obj: unknown, depth?: number): string {
440+
depth = depth ?? 0;
441+
442+
/// Case boolean - convert to Python format
443+
if (typeof obj === "boolean") {
444+
return obj ? "True" : "False";
445+
}
446+
447+
/// Case null - convert to Python format
448+
if (obj === null) {
449+
return "None";
450+
}
451+
452+
/// Case number or string
453+
if (typeof obj !== "object") {
454+
return JSON.stringify(obj);
455+
}
456+
457+
/// Case array
458+
if (Array.isArray(obj)) {
459+
const items = obj
460+
.map((item) => {
461+
const formatted = formatPythonValue(item, depth + 1);
462+
return `${" ".repeat(4 * (depth + 1))}${formatted},`;
463+
})
464+
.join("\n");
465+
return `[\n${items}\n${" ".repeat(4 * depth)}]`;
466+
}
467+
468+
/// Case mapping (object)
469+
const entries = Object.entries(obj);
470+
const lines = entries
471+
.map(([key, value]) => {
472+
const formattedValue = formatPythonValue(value, depth + 1);
473+
return `${" ".repeat(4 * (depth + 1))}"${key}": ${formattedValue},`;
474+
})
475+
.join("\n");
476+
return `{\n${lines}\n${" ".repeat(4 * depth)}}`;
477+
}
478+
411479
function formatTsObject(obj: unknown, depth?: number): string {
412480
depth = depth ?? 0;
413481

0 commit comments

Comments
 (0)