-import type { InferenceProvider } from "../inference-providers.js";
+import { HF_HUB_INFERENCE_PROXY_TEMPLATE, type InferenceProvider } from "../inference-providers.js";
 import type { PipelineType } from "../pipelines.js";
 import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks/index.js";
 import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
 import { getModelInputSnippet } from "./inputs.js";
 import type { InferenceSnippet, ModelDataMinimal } from "./types.js";
 
-export const snippetBasic = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => ({
-	content: `curl https://api-inference.huggingface.co/models/${model.id} \\
-	-X POST \\
-	-d '{"inputs": ${getModelInputSnippet(model, true)}}' \\
-	-H 'Content-Type: application/json' \\
-	-H 'Authorization: Bearer ${accessToken || `{API_TOKEN}`}'`,
-});
+export const snippetBasic = (
+	model: ModelDataMinimal,
+	accessToken: string,
+	provider: InferenceProvider
+): InferenceSnippet[] => {
+	if (provider !== "hf-inference") {
+		return [];
+	}
+	return [
+		{
+			content: `curl https://api-inference.huggingface.co/models/${model.id} \\
+	-X POST \\
+	-d '{"inputs": ${getModelInputSnippet(model, true)}}' \\
+	-H 'Content-Type: application/json' \\
+	-H 'Authorization: Bearer ${accessToken || `{API_TOKEN}`}'`,
+		},
+	];
+};
 
 export const snippetTextGeneration = (
 	model: ModelDataMinimal,
 	accessToken: string,
+	provider: InferenceProvider,
 	opts?: {
 		streaming?: boolean;
 		messages?: ChatCompletionInputMessage[];
 		temperature?: GenerationParameters["temperature"];
 		max_tokens?: GenerationParameters["max_tokens"];
 		top_p?: GenerationParameters["top_p"];
 	}
-): InferenceSnippet => {
+): InferenceSnippet[] => {
 	if (model.tags.includes("conversational")) {
+		const baseUrl =
+			provider === "hf-inference"
+				? `https://api-inference.huggingface.co/models/${model.id}/v1/chat/completions`
+				: HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider) + "/v1/chat/completions";
+
 		// Conversational model detected, so we display a code snippet that features the Messages API
 		const streaming = opts?.streaming ?? true;
 		const exampleMessages = getModelInputSnippet(model) as ChatCompletionInputMessage[];
@@ -35,8 +52,9 @@ export const snippetTextGeneration = (
 			max_tokens: opts?.max_tokens ?? 500,
 			...(opts?.top_p ? { top_p: opts.top_p } : undefined),
 		};
-		return {
-			content: `curl 'https://api-inference.huggingface.co/models/${model.id}/v1/chat/completions' \\
+		return [
+			{
+				content: `curl '${baseUrl}' \\
 -H 'Authorization: Bearer ${accessToken || `{API_TOKEN}`}' \\
 -H 'Content-Type: application/json' \\
 --data '{
@@ -53,31 +71,59 @@ export const snippetTextGeneration = (
 		})},
     "stream": ${!!streaming}
 }'`,
-		};
+			},
+		];
 	} else {
-		return snippetBasic(model, accessToken);
+		return snippetBasic(model, accessToken, provider);
 	}
 };
 
-export const snippetZeroShotClassification = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => ({
-	content: `curl https://api-inference.huggingface.co/models/${model.id} \\
+export const snippetZeroShotClassification = (
+	model: ModelDataMinimal,
+	accessToken: string,
+	provider: InferenceProvider
+): InferenceSnippet[] => {
+	if (provider !== "hf-inference") {
+		return [];
+	}
+	return [
+		{
+			content: `curl https://api-inference.huggingface.co/models/${model.id} \\
 	-X POST \\
 	-d '{"inputs": ${getModelInputSnippet(model, true)}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}' \\
 	-H 'Content-Type: application/json' \\
 	-H 'Authorization: Bearer ${accessToken || `{API_TOKEN}`}'`,
-});
+		},
+	];
+};
 
-export const snippetFile = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => ({
-	content: `curl https://api-inference.huggingface.co/models/${model.id} \\
+export const snippetFile = (
+	model: ModelDataMinimal,
+	accessToken: string,
+	provider: InferenceProvider
+): InferenceSnippet[] => {
+	if (provider !== "hf-inference") {
+		return [];
+	}
+	return [
+		{
+			content: `curl https://api-inference.huggingface.co/models/${model.id} \\
 	-X POST \\
 	--data-binary '@${getModelInputSnippet(model, true, true)}' \\
 	-H 'Authorization: Bearer ${accessToken || `{API_TOKEN}`}'`,
-});
+		},
+	];
+};
 
 export const curlSnippets: Partial<
 	Record<
 		PipelineType,
-		(model: ModelDataMinimal, accessToken: string, opts?: Record<string, unknown>) => InferenceSnippet
+		(
+			model: ModelDataMinimal,
+			accessToken: string,
+			provider: InferenceProvider,
+			opts?: Record<string, unknown>
+		) => InferenceSnippet[]
 	>
 > = {
 	// Same order as in tasks/src/pipelines.ts
@@ -112,11 +158,9 @@ export function getCurlInferenceSnippet(
 	provider: InferenceProvider,
 	opts?: Record<string, unknown>
 ): InferenceSnippet[] {
-	const snippets =
-		model.pipeline_tag && model.pipeline_tag in curlSnippets
-			? curlSnippets[model.pipeline_tag]?.(model, accessToken, opts) ?? [{ content: "" }]
-			: [{ content: "" }];
-	return Array.isArray(snippets) ? snippets : [snippets];
+	return model.pipeline_tag && model.pipeline_tag in curlSnippets
+		? curlSnippets[model.pipeline_tag]?.(model, accessToken, provider, opts) ?? []
+		: [];
 }
 
 export function hasCurlInferenceSnippet(model: Pick<ModelDataMinimal, "pipeline_tag">): boolean {
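
For context, here is a minimal standalone sketch of the chat-completions URL routing that `snippetTextGeneration` now performs. The concrete value of `HF_HUB_INFERENCE_PROXY_TEMPLATE` is an assumption here (a Hub proxy URL template containing a `{{PROVIDER}}` placeholder), as is the `sambanova` provider name used in the example:

```ts
// Standalone sketch of the baseUrl branch added in snippetTextGeneration.
// Assumption: HF_HUB_INFERENCE_PROXY_TEMPLATE is a URL template with a
// "{{PROVIDER}}" placeholder; the value below is illustrative only.
const HF_HUB_INFERENCE_PROXY_TEMPLATE = "https://huggingface.co/api/inference-proxy/{{PROVIDER}}";

function chatCompletionsUrl(modelId: string, provider: string): string {
	// "hf-inference" talks to the serverless Inference API directly;
	// any other provider is routed through the Hub proxy.
	return provider === "hf-inference"
		? `https://api-inference.huggingface.co/models/${modelId}/v1/chat/completions`
		: HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider) + "/v1/chat/completions";
}

console.log(chatCompletionsUrl("meta-llama/Llama-3.1-8B-Instruct", "sambanova"));
// e.g. https://huggingface.co/api/inference-proxy/sambanova/v1/chat/completions
```

Only the base URL differs between providers; the headers and JSON payload of the generated curl command are identical in both branches.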
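And a hedged usage sketch of the provider-aware entry point: conversational models yield a snippet for any provider, while the other task handlers now return an empty array for providers other than `"hf-inference"` instead of the old `[{ content: "" }]` placeholder. The model fixtures below are abbreviated, and the provider literal `"together"` is assumed to be part of the `InferenceProvider` union:

```ts
import { getCurlInferenceSnippet } from "./curl.js";
import type { ModelDataMinimal } from "./types.js";

// Abbreviated fixture; a real ModelDataMinimal may carry more fields.
const chatModel = {
	id: "meta-llama/Llama-3.1-8B-Instruct",
	pipeline_tag: "text-generation",
	tags: ["conversational"],
} as ModelDataMinimal;

// Conversational: one snippet whichever provider is selected.
getCurlInferenceSnippet(chatModel, "{API_TOKEN}", "hf-inference"); // [{ content: "curl ..." }]
getCurlInferenceSnippet(chatModel, "{API_TOKEN}", "together"); // [{ content: "curl ..." }] via the Hub proxy

// Non-conversational tasks short-circuit to [] for third-party providers.
const asrModel = {
	id: "openai/whisper-large-v3",
	pipeline_tag: "automatic-speech-recognition",
	tags: [],
} as ModelDataMinimal;
getCurlInferenceSnippet(asrModel, "{API_TOKEN}", "together"); // []
```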