 import type { PipelineType } from "../pipelines.js";
+import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks/index.js";
 import { getModelInputSnippet } from "./inputs.js";
-import type { InferenceSnippet, ModelDataMinimal } from "./types.js";
+import type {
+	GenerationConfigFormatter,
+	GenerationMessagesFormatter,
+	InferenceSnippet,
+	ModelDataMinimal,
+} from "./types.js";

 export const snippetBasic = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => ({
 	content: `async function query(data) {
@@ -24,22 +30,128 @@ query({"inputs": ${getModelInputSnippet(model)}}).then((response) => {
 });`,
 });

-export const snippetTextGeneration = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => {
+const formatGenerationMessages: GenerationMessagesFormatter = ({ messages, sep, start, end }) =>
+	start + messages.map(({ role, content }) => `{ role: "${role}", content: "${content}" }`).join(sep) + end;
+
+const formatGenerationConfig: GenerationConfigFormatter = ({ config, sep, start, end }) =>
+	start +
+	Object.entries(config)
+		.map(([key, val]) => `${key}: ${val}`)
+		.join(sep) +
+	end;
+
+export const snippetTextGeneration = (
+	model: ModelDataMinimal,
+	accessToken: string,
+	opts?: {
+		streaming?: boolean;
+		messages?: ChatCompletionInputMessage[];
+		temperature?: GenerationParameters["temperature"];
+		max_tokens?: GenerationParameters["max_tokens"];
+		top_p?: GenerationParameters["top_p"];
+	}
+): InferenceSnippet | InferenceSnippet[] => {
 	if (model.tags.includes("conversational")) {
 		// Conversational model detected, so we display a code snippet that features the Messages API
-		return {
-			content: `import { HfInference } from "@huggingface/inference";
+		const streaming = opts?.streaming ?? true;
+		const messages: ChatCompletionInputMessage[] = opts?.messages ?? [
+			{ role: "user", content: "What is the capital of France?" },
+		];
+		const messagesStr = formatGenerationMessages({ messages, sep: ",\n\t\t", start: "[\n\t\t", end: "\n\t]" });

-const inference = new HfInference("${accessToken || `{API_TOKEN}`}");
+		const config = {
+			temperature: opts?.temperature,
+			max_tokens: opts?.max_tokens ?? 500,
+			top_p: opts?.top_p,
+		};
+		const configStr = formatGenerationConfig({ config, sep: ",\n\t", start: "", end: "" });

-for await (const chunk of inference.chatCompletionStream({
+		if (streaming) {
+			return [
+				{
+					client: "huggingface_hub",
+					content: `import { HfInference } from "@huggingface/inference"
+
+const client = new HfInference("${accessToken || `{API_TOKEN}`}")
+
+let out = "";
+
+const stream = client.chatCompletionStream({
 	model: "${model.id}",
-	messages: [{ role: "user", content: "What is the capital of France?" }],
-	max_tokens: 500,
-})) {
-	process.stdout.write(chunk.choices[0]?.delta?.content || "");
+	messages: ${messagesStr},
+	${configStr}
+});
+
+for await (const chunk of stream) {
+	if (chunk.choices && chunk.choices.length > 0) {
+		const newContent = chunk.choices[0].delta.content;
+		out += newContent;
+		console.log(newContent);
+	}
 }`,
-		};
+				},
+				{
+					client: "openai",
+					content: `import { OpenAI } from "openai"
+
+const client = new OpenAI({
+	baseURL: "https://api-inference.huggingface.co/v1/",
+	apiKey: "${accessToken || `{API_TOKEN}`}"
+})
+
+let out = "";
+
+const stream = await client.chat.completions.create({
+	model: "${model.id}",
+	messages: ${messagesStr},
+	${configStr},
+	stream: true,
+});
+
+for await (const chunk of stream) {
+	if (chunk.choices && chunk.choices.length > 0) {
+		const newContent = chunk.choices[0].delta.content;
+		out += newContent;
+		console.log(newContent);
+	}
+}`,
+				},
+			];
+		} else {
+			return [
+				{
+					client: "huggingface_hub",
+					content: `import { HfInference } from '@huggingface/inference'
+
+const client = new HfInference("${accessToken || `{API_TOKEN}`}")
+
+const chatCompletion = await client.chatCompletion({
+	model: "${model.id}",
+	messages: ${messagesStr},
+	${configStr}
+});
+
+console.log(chatCompletion.choices[0].message);`,
+				},
+				{
+					client: "openai",
+					content: `import { OpenAI } from "openai"
+
+const client = new OpenAI({
+	baseURL: "https://api-inference.huggingface.co/v1/",
+	apiKey: "${accessToken || `{API_TOKEN}`}"
+})
+
+const chatCompletion = await client.chat.completions.create({
+	model: "${model.id}",
+	messages: ${messagesStr},
+	${configStr}
+});
+
+console.log(chatCompletion.choices[0].message);`,
+				},
+			];
+		}
 	} else {
 		return snippetBasic(model, accessToken);
 	}
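
For orientation (not part of the diff): a minimal sketch of how the widened snippetTextGeneration signature might be called. The model object is cast for brevity, and the token, model id, and relative import paths are placeholders.

// Hypothetical usage sketch of the new opts parameter and the single-or-array return type.
import { snippetTextGeneration } from "./js.js";
import type { InferenceSnippet, ModelDataMinimal } from "./types.js";

// Only the fields the snippet generator reads (id, tags) are filled in; the rest is cast away.
const model = { id: "meta-llama/Llama-3.1-8B-Instruct", tags: ["conversational"] } as ModelDataMinimal;

const result = snippetTextGeneration(model, "{API_TOKEN}", {
	streaming: false,
	temperature: 0.7,
	max_tokens: 256,
});

// Conversational models now yield one snippet per client; normalize before printing.
const snippets: InferenceSnippet[] = Array.isArray(result) ? result : [result];
for (const snippet of snippets) {
	console.log(snippet.content);
}
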
@@ -187,7 +299,11 @@ query(${getModelInputSnippet(model)}).then((response) => {
 export const jsSnippets: Partial<
 	Record<
 		PipelineType,
-		(model: ModelDataMinimal, accessToken: string, opts?: Record<string, string | boolean | number>) => InferenceSnippet
+		(
+			model: ModelDataMinimal,
+			accessToken: string,
+			opts?: Record<string, unknown>
+		) => InferenceSnippet | InferenceSnippet[]
 	>
 > = {
 	// Same order as in js/src/lib/interfaces/Types.ts
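
Also illustrative only (not part of the commit): because the jsSnippets value type now returns InferenceSnippet | InferenceSnippet[] and accepts opts?: Record<string, unknown>, a hypothetical caller that looks up a snippet generator by pipeline type has to normalize the result, roughly as follows.

// Hypothetical caller-side handling of the widened jsSnippets map (model, token, and import paths are placeholders).
import { jsSnippets } from "./js.js";
import type { ModelDataMinimal } from "./types.js";

const model = { id: "gpt2", pipeline_tag: "text-generation", tags: [] } as ModelDataMinimal;

const getSnippet = model.pipeline_tag ? jsSnippets[model.pipeline_tag] : undefined;
if (getSnippet) {
	const result = getSnippet(model, "{API_TOKEN}", { streaming: true });
	for (const { content } of Array.isArray(result) ? result : [result]) {
		console.log(content);
	}
}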