|
1 | 1 | import type { PipelineType } from "../pipelines"; |
| 2 | +import type { ChatCompletionInputMessage } from "../tasks"; |
2 | 3 | import type { ModelDataMinimal } from "./types"; |
3 | 4 |
|
4 | 5 | const inputsZeroShotClassification = () => |
@@ -40,7 +41,30 @@ const inputsTextClassification = () => `"I like you. I love you"`; |
40 | 41 |
|
41 | 42 | const inputsTokenClassification = () => `"My name is Sarah Jessica Parker but you can call me Jessica"`; |
42 | 43 |
|
43 | | -const inputsTextGeneration = () => `"Can you please let us know more details about your "`; |
| 44 | +const inputsTextGeneration = (model: ModelDataMinimal): string | ChatCompletionInputMessage[] => { |
| 45 | + if (model.tags.includes("conversational")) { |
| 46 | + return model.pipeline_tag === "text-generation" |
| 47 | + ? [{ role: "user", content: "What is the capital of France?" }] |
| 48 | + : [ |
| 49 | + { |
| 50 | + role: "user", |
| 51 | + content: [ |
| 52 | + { |
| 53 | + type: "text", |
| 54 | + text: "Describe this image in one sentence.", |
| 55 | + }, |
| 56 | + { |
| 57 | + type: "image_url", |
| 58 | + image_url: { |
| 59 | + url: "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg", |
| 60 | + }, |
| 61 | + }, |
| 62 | + ], |
| 63 | + }, |
| 64 | + ]; |
| 65 | + } |
| 66 | + return `"Can you please let us know more details about your "`; |
| 67 | +}; |
44 | 68 |
|
45 | 69 | const inputsText2TextGeneration = () => `"The answer to the universe is"`; |
46 | 70 |
|
@@ -84,7 +108,7 @@ const inputsTabularPrediction = () => |
84 | 108 | const inputsZeroShotImageClassification = () => `"cats.jpg"`; |
85 | 109 |
|
86 | 110 | const modelInputSnippets: { |
87 | | - [key in PipelineType]?: (model: ModelDataMinimal) => string; |
| 111 | + [key in PipelineType]?: (model: ModelDataMinimal) => string | ChatCompletionInputMessage[]; |
88 | 112 | } = { |
89 | 113 | "audio-to-audio": inputsAudioToAudio, |
90 | 114 | "audio-classification": inputsAudioClassification, |
@@ -116,18 +140,24 @@ const modelInputSnippets: { |
116 | 140 |
|
117 | 141 | // Use noWrap to put the whole snippet on a single line (removing new lines and tabulations) |
118 | 142 | // Use noQuotes to strip quotes from start & end (example: "abc" -> abc) |
119 | | -export function getModelInputSnippet(model: ModelDataMinimal, noWrap = false, noQuotes = false): string { |
| 143 | +export function getModelInputSnippet( |
| 144 | + model: ModelDataMinimal, |
| 145 | + noWrap = false, |
| 146 | + noQuotes = false |
| 147 | +): string | ChatCompletionInputMessage[] { |
120 | 148 | if (model.pipeline_tag) { |
121 | 149 | const inputs = modelInputSnippets[model.pipeline_tag]; |
122 | 150 | if (inputs) { |
123 | 151 | let result = inputs(model); |
124 | | - if (noWrap) { |
125 | | - result = result.replace(/(?:(?:\r?\n|\r)\t*)|\t+/g, " "); |
126 | | - } |
127 | | - if (noQuotes) { |
128 | | - const REGEX_QUOTES = /^"(.+)"$/s; |
129 | | - const match = result.match(REGEX_QUOTES); |
130 | | - result = match ? match[1] : result; |
| 152 | + if (typeof result === "string") { |
| 153 | + if (noWrap) { |
| 154 | + result = result.replace(/(?:(?:\r?\n|\r)\t*)|\t+/g, " "); |
| 155 | + } |
| 156 | + if (noQuotes) { |
| 157 | + const REGEX_QUOTES = /^"(.+)"$/s; |
| 158 | + const match = result.match(REGEX_QUOTES); |
| 159 | + result = match ? match[1] : result; |
| 160 | + } |
131 | 161 | } |
132 | 162 | return result; |
133 | 163 | } |
|
0 commit comments