11import type { PipelineType } from "../pipelines.js" ;
2+ import type { ChatCompletionInputMessage , GenerationParameters } from "../tasks/index.js" ;
23import { getModelInputSnippet } from "./inputs.js" ;
3- import type { InferenceSnippet , ModelDataMinimal } from "./types.js" ;
4+ import type {
5+ GenerationConfigFormatter ,
6+ GenerationMessagesFormatter ,
7+ InferenceSnippet ,
8+ ModelDataMinimal ,
9+ } from "./types.js" ;
410
511export const snippetBasic = ( model : ModelDataMinimal , accessToken : string ) : InferenceSnippet => ( {
612 content : `curl https://api-inference.huggingface.co/models/${ model . id } \\
@@ -10,20 +16,58 @@ export const snippetBasic = (model: ModelDataMinimal, accessToken: string): Infe
1016 -H "Authorization: Bearer ${ accessToken || `{API_TOKEN}` } "` ,
1117} ) ;
1218
13- export const snippetTextGeneration = ( model : ModelDataMinimal , accessToken : string ) : InferenceSnippet => {
/**
 * Renders chat messages as a sequence of JSON objects for embedding in the
 * single-quoted `--data '...'` body of a curl command.
 *
 * Each message becomes `{ "role": ..., "content": ... }`; the rendered messages
 * are joined with `sep` and wrapped between `start` and `end`.
 *
 * @param messages - Messages to render, in order.
 * @param sep - Separator inserted between rendered messages.
 * @param start - Text emitted before the first message.
 * @param end - Text emitted after the last message.
 * @returns The concatenated text; `start + end` when `messages` is empty.
 */
const formatGenerationMessages: GenerationMessagesFormatter = ({ messages, sep, start, end }) => {
	const rendered = messages.map(({ role, content }) => {
		// JSON.stringify escapes double quotes, backslashes and control characters,
		// so the result stays valid JSON whatever the message text contains.
		const roleJson = JSON.stringify(role);
		// A message without content is rendered with an empty string rather than
		// the literal text "undefined".
		const contentJson = JSON.stringify(content ?? "");
		// The body is wrapped in single quotes on the curl command line, so every
		// single quote must close the quoted string, emit an escaped quote, and
		// reopen it: ' -> '\''
		return `{ "role": ${roleJson}, "content": ${contentJson} }`.replace(/'/g, "'\\''");
	});
	return start + rendered.join(sep) + end;
};
30+
/**
 * Renders generation parameters as `"key": value` JSON members for embedding in
 * the single-quoted `--data '...'` body of a curl command.
 *
 * Entries whose value is `undefined` are omitted, because `undefined` has no
 * JSON representation and would make the request body unparsable.
 *
 * @param config - Parameter name/value pairs; unset parameters may be `undefined`.
 * @param sep - Separator inserted between rendered members.
 * @param start - Text emitted before the first member.
 * @param end - Text emitted after the last member.
 * @returns The concatenated text; `start + end` when no parameter is set.
 */
const formatGenerationConfig: GenerationConfigFormatter = ({ config, sep, start, end }) => {
	const members = Object.entries(config)
		.filter(([, val]) => val !== undefined)
		.map(([key, val]) =>
			// JSON.stringify keeps numbers, booleans and null as-is and quotes/escapes
			// strings; single quotes are then escaped for the surrounding shell string.
			`${JSON.stringify(key)}: ${JSON.stringify(val)}`.replace(/'/g, "'\\''")
		);
	return start + members.join(sep) + end;
};
37+
38+ export const snippetTextGeneration = (
39+ model : ModelDataMinimal ,
40+ accessToken : string ,
41+ opts ?: {
42+ streaming ?: boolean ;
43+ messages ?: ChatCompletionInputMessage [ ] ;
44+ temperature ?: GenerationParameters [ "temperature" ] ;
45+ max_tokens ?: GenerationParameters [ "max_tokens" ] ;
46+ top_p ?: GenerationParameters [ "top_p" ] ;
47+ }
48+ ) : InferenceSnippet => {
1449 if ( model . tags . includes ( "conversational" ) ) {
1550 // Conversational model detected, so we display a code snippet that features the Messages API
51+ const streaming = opts ?. streaming ?? true ;
52+ const messages : ChatCompletionInputMessage [ ] = opts ?. messages ?? [
53+ { role : "user" , content : "What is the capital of France?" } ,
54+ ] ;
55+
56+ const config = {
57+ temperature : opts ?. temperature ,
58+ max_tokens : opts ?. max_tokens ?? 500 ,
59+ top_p : opts ?. top_p ,
60+ } ;
1661 return {
1762 content : `curl 'https://api-inference.huggingface.co/models/${ model . id } /v1/chat/completions' \\
1863-H "Authorization: Bearer ${ accessToken || `{API_TOKEN}` } " \\
1964-H 'Content-Type: application/json' \\
20- -d '{
21- "model": "${ model . id } ",
22- "messages": [{"role": "user", "content": "What is the capital of France?"}],
23- "max_tokens": 500,
24- "stream": false
25- }'
26- ` ,
65+ --data '{
66+ "model": "${ model . id } ",
67+ "messages": ${ formatGenerationMessages ( { messages, sep : ",\n " , start : `[\n ` , end : `\n]` } ) } ,
68+ ${ formatGenerationConfig ( { config, sep : ",\n " , start : "" , end : "" } ) } ,
69+ "stream": ${ ! ! streaming }
70+ }'` ,
2771 } ;
2872 } else {
2973 return snippetBasic ( model , accessToken ) ;
@@ -76,7 +120,7 @@ export const snippetFile = (model: ModelDataMinimal, accessToken: string): Infer
76120export const curlSnippets : Partial <
77121 Record <
78122 PipelineType ,
79- ( model : ModelDataMinimal , accessToken : string , opts ?: Record < string , string | boolean | number > ) => InferenceSnippet
123+ ( model : ModelDataMinimal , accessToken : string , opts ?: Record < string , unknown > ) => InferenceSnippet
80124 >
81125> = {
82126 // Same order as in js/src/lib/interfaces/Types.ts
0 commit comments