11import type { PipelineType } from "../pipelines.js" ;
2+ import type { ChatCompletionInputMessage , GenerationParameters } from "../tasks/index.js" ;
3+ import { stringifyGenerationConfig , stringifyMessages } from "./common.js" ;
24import { getModelInputSnippet } from "./inputs.js" ;
3- import type { ModelDataMinimal } from "./types.js" ;
5+ import type { InferenceSnippet , ModelDataMinimal } from "./types.js" ;
46
5- export const snippetBasic = ( model : ModelDataMinimal , accessToken : string ) : string =>
6- `curl https://api-inference.huggingface.co/models/${ model . id } \\
7+ export const snippetBasic = ( model : ModelDataMinimal , accessToken : string ) : InferenceSnippet => ( {
8+ content : `curl https://api-inference.huggingface.co/models/${ model . id } \\
79 -X POST \\
810 -d '{"inputs": ${ getModelInputSnippet ( model , true ) } }' \\
911 -H 'Content-Type: application/json' \\
10- -H "Authorization: Bearer ${ accessToken || `{API_TOKEN}` } "` ;
12+ -H "Authorization: Bearer ${ accessToken || `{API_TOKEN}` } "` ,
13+ } ) ;
1114
12- export const snippetTextGeneration = ( model : ModelDataMinimal , accessToken : string ) : string => {
15+ export const snippetTextGeneration = (
16+ model : ModelDataMinimal ,
17+ accessToken : string ,
18+ opts ?: {
19+ streaming ?: boolean ;
20+ messages ?: ChatCompletionInputMessage [ ] ;
21+ temperature ?: GenerationParameters [ "temperature" ] ;
22+ max_tokens ?: GenerationParameters [ "max_tokens" ] ;
23+ top_p ?: GenerationParameters [ "top_p" ] ;
24+ }
25+ ) : InferenceSnippet => {
1326 if ( model . tags . includes ( "conversational" ) ) {
1427 // Conversational model detected, so we display a code snippet that features the Messages API
15- return `curl 'https://api-inference.huggingface.co/models/${ model . id } /v1/chat/completions' \\
28+ const streaming = opts ?. streaming ?? true ;
29+ const messages : ChatCompletionInputMessage [ ] = opts ?. messages ?? [
30+ { role : "user" , content : "What is the capital of France?" } ,
31+ ] ;
32+
33+ const config = {
34+ ...( opts ?. temperature ? { temperature : opts . temperature } : undefined ) ,
35+ max_tokens : opts ?. max_tokens ?? 500 ,
36+ ...( opts ?. top_p ? { top_p : opts . top_p } : undefined ) ,
37+ } ;
38+ return {
39+ content : `curl 'https://api-inference.huggingface.co/models/${ model . id } /v1/chat/completions' \\
1640-H "Authorization: Bearer ${ accessToken || `{API_TOKEN}` } " \\
1741-H 'Content-Type: application/json' \\
18- -d '{
19- "model": "${ model . id } ",
20- "messages": [{"role": "user", "content": "What is the capital of France?"}],
21- "max_tokens": 500,
22- "stream": false
23- }'
24- ` ;
42+ --data '{
43+ "model": "${ model . id } ",
44+ "messages": ${ stringifyMessages ( messages , {
45+ sep : ",\n\t\t" ,
46+ start : `[\n\t\t` ,
47+ end : `\n\t]` ,
48+ attributeKeyQuotes : true ,
49+ customContentEscaper : ( str ) => str . replace ( / ' / g, "'\\''" ) ,
50+ } ) } ,
51+ ${ stringifyGenerationConfig ( config , {
52+ sep : ",\n " ,
53+ start : "" ,
54+ end : "" ,
55+ attributeKeyQuotes : true ,
56+ attributeValueConnector : ": " ,
57+ } ) } ,
58+ "stream": ${ ! ! streaming }
59+ }'` ,
60+ } ;
2561 } else {
2662 return snippetBasic ( model , accessToken ) ;
2763 }
2864} ;
2965
30- export const snippetImageTextToTextGeneration = ( model : ModelDataMinimal , accessToken : string ) : string => {
66+ export const snippetImageTextToTextGeneration = ( model : ModelDataMinimal , accessToken : string ) : InferenceSnippet => {
3167 if ( model . tags . includes ( "conversational" ) ) {
3268 // Conversational model detected, so we display a code snippet that features the Messages API
33- return `curl 'https://api-inference.huggingface.co/models/${ model . id } /v1/chat/completions' \\
69+ return {
70+ content : `curl 'https://api-inference.huggingface.co/models/${ model . id } /v1/chat/completions' \\
3471-H "Authorization: Bearer ${ accessToken || `{API_TOKEN}` } " \\
3572-H 'Content-Type: application/json' \\
3673-d '{
@@ -47,26 +84,34 @@ export const snippetImageTextToTextGeneration = (model: ModelDataMinimal, access
4784 "max_tokens": 500,
4885 "stream": false
4986}'
50- ` ;
87+ ` ,
88+ } ;
5189 } else {
5290 return snippetBasic ( model , accessToken ) ;
5391 }
5492} ;
5593
56- export const snippetZeroShotClassification = ( model : ModelDataMinimal , accessToken : string ) : string =>
57- `curl https://api-inference.huggingface.co/models/${ model . id } \\
94+ export const snippetZeroShotClassification = ( model : ModelDataMinimal , accessToken : string ) : InferenceSnippet => ( {
95+ content : `curl https://api-inference.huggingface.co/models/${ model . id } \\
5896 -X POST \\
5997 -d '{"inputs": ${ getModelInputSnippet ( model , true ) } , "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}' \\
6098 -H 'Content-Type: application/json' \\
61- -H "Authorization: Bearer ${ accessToken || `{API_TOKEN}` } "` ;
99+ -H "Authorization: Bearer ${ accessToken || `{API_TOKEN}` } "` ,
100+ } ) ;
62101
63- export const snippetFile = ( model : ModelDataMinimal , accessToken : string ) : string =>
64- `curl https://api-inference.huggingface.co/models/${ model . id } \\
102+ export const snippetFile = ( model : ModelDataMinimal , accessToken : string ) : InferenceSnippet => ( {
103+ content : `curl https://api-inference.huggingface.co/models/${ model . id } \\
65104 -X POST \\
66105 --data-binary '@${ getModelInputSnippet ( model , true , true ) } ' \\
67- -H "Authorization: Bearer ${ accessToken || `{API_TOKEN}` } "` ;
106+ -H "Authorization: Bearer ${ accessToken || `{API_TOKEN}` } "` ,
107+ } ) ;
68108
69- export const curlSnippets : Partial < Record < PipelineType , ( model : ModelDataMinimal , accessToken : string ) => string > > = {
109+ export const curlSnippets : Partial <
110+ Record <
111+ PipelineType ,
112+ ( model : ModelDataMinimal , accessToken : string , opts ?: Record < string , unknown > ) => InferenceSnippet
113+ >
114+ > = {
70115 // Same order as in js/src/lib/interfaces/Types.ts
71116 "text-classification" : snippetBasic ,
72117 "token-classification" : snippetBasic ,
@@ -93,10 +138,10 @@ export const curlSnippets: Partial<Record<PipelineType, (model: ModelDataMinimal
93138 "image-segmentation" : snippetFile ,
94139} ;
95140
96- export function getCurlInferenceSnippet ( model : ModelDataMinimal , accessToken : string ) : string {
141+ export function getCurlInferenceSnippet ( model : ModelDataMinimal , accessToken : string ) : InferenceSnippet {
97142 return model . pipeline_tag && model . pipeline_tag in curlSnippets
98- ? curlSnippets [ model . pipeline_tag ] ?.( model , accessToken ) ?? ""
99- : "" ;
143+ ? curlSnippets [ model . pipeline_tag ] ?.( model , accessToken ) ?? { content : "" }
144+ : { content : "" } ;
100145}
101146
102147export function hasCurlInferenceSnippet ( model : Pick < ModelDataMinimal , "pipeline_tag" > ) : boolean {
0 commit comments