Skip to content

Commit 453d528

Browse files
committed
Refactor conversational input to getModelInputSnippet
1 parent 850a70e commit 453d528

File tree

4 files changed

+40
-61
lines changed

4 files changed

+40
-61
lines changed

packages/tasks/src/snippets/curl.ts

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,23 +26,7 @@ export const snippetTextGeneration = (
2626
if (model.tags.includes("conversational")) {
2727
// Conversational model detected, so we display a code snippet that features the Messages API
2828
const streaming = opts?.streaming ?? true;
29-
const exampleMessages: ChatCompletionInputMessage[] =
30-
model.pipeline_tag === "text-generation"
31-
? [{ role: "user", content: "What is the capital of France?" }]
32-
: [
33-
{
34-
role: "user",
35-
content: [
36-
{
37-
type: "image_url",
38-
image_url: {
39-
url: "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
40-
},
41-
},
42-
{ type: "text", text: "Describe this image in one sentence." },
43-
],
44-
},
45-
];
29+
const exampleMessages = getModelInputSnippet(model) as ChatCompletionInputMessage[];
4630
const messages = opts?.messages ?? exampleMessages;
4731

4832
const config = {

packages/tasks/src/snippets/inputs.ts

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import type { PipelineType } from "../pipelines";
2+
import type { ChatCompletionInputMessage } from "../tasks";
23
import type { ModelDataMinimal } from "./types";
34

45
const inputsZeroShotClassification = () =>
@@ -40,7 +41,27 @@ const inputsTextClassification = () => `"I like you. I love you"`;
4041

4142
const inputsTokenClassification = () => `"My name is Sarah Jessica Parker but you can call me Jessica"`;
4243

43-
const inputsTextGeneration = () => `"Can you please let us know more details about your "`;
44+
const inputsTextGeneration = (model: ModelDataMinimal): string | ChatCompletionInputMessage[] => {
45+
if (model.tags.includes("conversational")) {
46+
return model.pipeline_tag === "text-generation"
47+
? [{ role: "user", content: "What is the capital of France?" }]
48+
: [
49+
{
50+
role: "user",
51+
content: [
52+
{
53+
type: "image_url",
54+
image_url: {
55+
url: "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
56+
},
57+
},
58+
{ type: "text", text: "Describe this image in one sentence." },
59+
],
60+
},
61+
];
62+
}
63+
return `"Can you please let us know more details about your "`;
64+
};
4465

4566
const inputsText2TextGeneration = () => `"The answer to the universe is"`;
4667

@@ -84,7 +105,7 @@ const inputsTabularPrediction = () =>
84105
const inputsZeroShotImageClassification = () => `"cats.jpg"`;
85106

86107
const modelInputSnippets: {
87-
[key in PipelineType]?: (model: ModelDataMinimal) => string;
108+
[key in PipelineType]?: (model: ModelDataMinimal) => string | ChatCompletionInputMessage[];
88109
} = {
89110
"audio-to-audio": inputsAudioToAudio,
90111
"audio-classification": inputsAudioClassification,
@@ -116,18 +137,24 @@ const modelInputSnippets: {
116137

117138
// Use noWrap to put the whole snippet on a single line (removing new lines and tabulations)
118139
// Use noQuotes to strip quotes from start & end (example: "abc" -> abc)
119-
export function getModelInputSnippet(model: ModelDataMinimal, noWrap = false, noQuotes = false): string {
140+
export function getModelInputSnippet(
141+
model: ModelDataMinimal,
142+
noWrap = false,
143+
noQuotes = false
144+
): string | ChatCompletionInputMessage[] {
120145
if (model.pipeline_tag) {
121146
const inputs = modelInputSnippets[model.pipeline_tag];
122147
if (inputs) {
123148
let result = inputs(model);
124-
if (noWrap) {
125-
result = result.replace(/(?:(?:\r?\n|\r)\t*)|\t+/g, " ");
126-
}
127-
if (noQuotes) {
128-
const REGEX_QUOTES = /^"(.+)"$/s;
129-
const match = result.match(REGEX_QUOTES);
130-
result = match ? match[1] : result;
149+
if (typeof result === "string") {
150+
if (noWrap) {
151+
result = result.replace(/(?:(?:\r?\n|\r)\t*)|\t+/g, " ");
152+
}
153+
if (noQuotes) {
154+
const REGEX_QUOTES = /^"(.+)"$/s;
155+
const match = result.match(REGEX_QUOTES);
156+
result = match ? match[1] : result;
157+
}
131158
}
132159
return result;
133160
}

packages/tasks/src/snippets/js.ts

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -40,23 +40,7 @@ export const snippetTextGeneration = (
4040
if (model.tags.includes("conversational")) {
4141
// Conversational model detected, so we display a code snippet that features the Messages API
4242
const streaming = opts?.streaming ?? true;
43-
const exampleMessages: ChatCompletionInputMessage[] =
44-
model.pipeline_tag === "text-generation"
45-
? [{ role: "user", content: "What is the capital of France?" }]
46-
: [
47-
{
48-
role: "user",
49-
content: [
50-
{
51-
type: "image_url",
52-
image_url: {
53-
url: "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
54-
},
55-
},
56-
{ type: "text", text: "Describe this image in one sentence." },
57-
],
58-
},
59-
];
43+
const exampleMessages = getModelInputSnippet(model) as ChatCompletionInputMessage[];
6044
const messages = opts?.messages ?? exampleMessages;
6145
const messagesStr = stringifyMessages(messages, { sep: ",\n\t\t", start: "[\n\t\t", end: "\n\t]" });
6246

packages/tasks/src/snippets/python.ts

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,23 +16,7 @@ export const snippetConversational = (
1616
}
1717
): InferenceSnippet[] => {
1818
const streaming = opts?.streaming ?? true;
19-
const exampleMessages: ChatCompletionInputMessage[] =
20-
model.pipeline_tag === "text-generation"
21-
? [{ role: "user", content: "What is the capital of France?" }]
22-
: [
23-
{
24-
role: "user",
25-
content: [
26-
{
27-
type: "image_url",
28-
image_url: {
29-
url: "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
30-
},
31-
},
32-
{ type: "text", text: "Describe this image in one sentence." },
33-
],
34-
},
35-
];
19+
const exampleMessages = getModelInputSnippet(model) as ChatCompletionInputMessage[];
3620
const messages = opts?.messages ?? exampleMessages;
3721
const messagesStr = stringifyMessages(messages, {
3822
sep: ",\n\t",

0 commit comments

Comments
 (0)