@@ -2,25 +2,40 @@ import type { PipelineType } from "../pipelines.js";
22import { getModelInputSnippet } from "./inputs.js" ;
33import type { ModelDataMinimal } from "./types.js" ;
44
5- export const snippetConversational = ( model : ModelDataMinimal , accessToken : string ) : string =>
5+ // Import snippets
6+
7+ const snippetImportInferenceClient = ( model : ModelDataMinimal , accessToken : string ) : string =>
8+ `from huggingface_hub import InferenceClient
9+
10+ client = InferenceClient(${ model . id } , token="${ accessToken || "{API_TOKEN}" } ")
11+ ` ;
12+
13+ const snippetImportConversationalInferenceClient = ( model : ModelDataMinimal , accessToken : string ) : string =>
14+ // Same but uses OpenAI convention
615 `from huggingface_hub import InferenceClient
716
817client = InferenceClient(api_key="${ accessToken || "{API_TOKEN}" } ")
18+ ` ;
919
10- for message in client.chat_completion(
20+ const snippetImportRequests = ( model : ModelDataMinimal , accessToken : string ) : string =>
21+ `import requests
22+
23+ API_URL = "https://api-inference.huggingface.co/models/${ model . id } "
24+ headers = {"Authorization": ${ accessToken ? `"Bearer ${ accessToken } "` : `f"Bearer {API_TOKEN}"` } }` ;
25+
26+ export const snippetConversational = ( model : ModelDataMinimal ) : string =>
27+ `for message in client.chat_completion(
1128 model="${ model . id } ",
1229 messages=[{"role": "user", "content": "What is the capital of France?"}],
1330 max_tokens=500,
1431 stream=True,
1532):
1633 print(message.choices[0].delta.content, end="")` ;
1734
18- export const snippetConversationalWithImage = ( model : ModelDataMinimal , accessToken : string ) : string =>
19- `from huggingface_hub import InferenceClient
20-
21- client = InferenceClient(api_key="${ accessToken || "{API_TOKEN}" } ")
35+ // InferenceClient-based snippets
2236
23- image_url = "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
37+ export const snippetConversationalWithImage = ( model : ModelDataMinimal ) : string =>
38+ `image_url = "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
2439
2540for message in client.chat_completion(
2641 model="${ model . id } ",
@@ -38,31 +53,29 @@ for message in client.chat_completion(
3853):
3954 print(message.choices[0].delta.content, end="")` ;
4055
41- export const snippetZeroShotClassification = ( model : ModelDataMinimal ) : string =>
42- `def query(payload):
43- response = requests.post(API_URL, headers=headers, json=payload)
44- return response.json()
56+ export const snippetDocumentQuestionAnswering = ( ) : string =>
57+ `output = client.document_question_answering("cat.png", "What is in this image?")` ;
4558
46- output = query({
47- "inputs": ${ getModelInputSnippet ( model ) } ,
48- "parameters": {"candidate_labels": ["refund", "legal", "faq"]},
49- })` ;
59+ export const snippetTabularClassification = ( model : ModelDataMinimal ) : string =>
60+ `output = client.tabular_classification(${ getModelInputSnippet ( model ) } )` ;
61+
62+ export const snippetTabularRegression = ( model : ModelDataMinimal ) : string =>
63+ `output = client.tabular_regression(${ getModelInputSnippet ( model ) } )` ;
64+ export const snippetTextToImage = ( model : ModelDataMinimal ) : string =>
65+ `# output is a PIL.Image object
66+ image = client.text_to_image(${ getModelInputSnippet ( model ) } )` ;
67+
68+ export const snippetZeroShotClassification = ( model : ModelDataMinimal ) : string =>
69+ `text = ${ getModelInputSnippet ( model ) }
70+ labels = ["refund", "legal", "faq"]
71+ output = client.zero_shot_classification(text, labels)` ;
5072
5173export const snippetZeroShotImageClassification = ( model : ModelDataMinimal ) : string =>
52- `def query(data):
53- with open(data["image_path"], "rb") as f:
54- img = f.read()
55- payload={
56- "parameters": data["parameters"],
57- "inputs": base64.b64encode(img).decode("utf-8")
58- }
59- response = requests.post(API_URL, headers=headers, json=payload)
60- return response.json()
74+ `image = ${ getModelInputSnippet ( model ) }
75+ labels = ["cat", "dog", "llama"]
76+ output = client.zero_shot_image_classification(image, labels)` ;
6177
62- output = query({
63- "image_path": ${ getModelInputSnippet ( model ) } ,
64- "parameters": {"candidate_labels": ["cat", "dog", "llama"]},
65- })` ;
78+ // requests-based snippets
6679
6780export const snippetBasic = ( model : ModelDataMinimal ) : string =>
6881 `def query(payload):
@@ -82,26 +95,6 @@ export const snippetFile = (model: ModelDataMinimal): string =>
8295
8396output = query(${ getModelInputSnippet ( model ) } )` ;
8497
85- export const snippetTextToImage = ( model : ModelDataMinimal ) : string =>
86- `def query(payload):
87- response = requests.post(API_URL, headers=headers, json=payload)
88- return response.content
89- image_bytes = query({
90- "inputs": ${ getModelInputSnippet ( model ) } ,
91- })
92- # You can access the image with PIL.Image for example
93- import io
94- from PIL import Image
95- image = Image.open(io.BytesIO(image_bytes))` ;
96-
97- export const snippetTabular = ( model : ModelDataMinimal ) : string =>
98- `def query(payload):
99- response = requests.post(API_URL, headers=headers, json=payload)
100- return response.content
101- response = query({
102- "inputs": {"data": ${ getModelInputSnippet ( model ) } },
103- })` ;
104-
10598export const snippetTextToAudio = ( model : ModelDataMinimal ) : string => {
10699 // Transformers TTS pipeline and api-inference-community (AIC) pipeline outputs are diverged
107100 // with the latest update to inference-api (IA).
@@ -131,19 +124,16 @@ Audio(audio, rate=sampling_rate)`;
131124 }
132125} ;
133126
134- export const snippetDocumentQuestionAnswering = ( model : ModelDataMinimal ) : string =>
135- `def query(payload):
136- with open(payload["image"], "rb") as f:
137- img = f.read()
138- payload["image"] = base64.b64encode(img).decode("utf-8")
139- response = requests.post(API_URL, headers=headers, json=payload)
140- return response.json()
141-
142- output = query({
143- "inputs": ${ getModelInputSnippet ( model ) } ,
144- })` ;
127+ const PIPELINES_USING_INFERENCE_CLIENT : PipelineType [ ] = [
128+ "document-question-answering" ,
129+ "tabular-classification" ,
130+ "tabular-regression" ,
131+ "text-to-image" ,
132+ "zero-shot-classification" ,
133+ "zero-shot-image-classification" ,
134+ ] ;
145135
146- export const pythonSnippets : Partial < Record < PipelineType , ( model : ModelDataMinimal , accessToken : string ) => string > > = {
136+ export const pythonSnippets : Partial < Record < PipelineType , ( model : ModelDataMinimal ) => string > > = {
147137 // Same order as in tasks/src/pipelines.ts
148138 "text-classification" : snippetBasic ,
149139 "token-classification" : snippetBasic ,
@@ -165,8 +155,8 @@ export const pythonSnippets: Partial<Record<PipelineType, (model: ModelDataMinim
165155 "audio-to-audio" : snippetFile ,
166156 "audio-classification" : snippetFile ,
167157 "image-classification" : snippetFile ,
168- "tabular-regression" : snippetTabular ,
169- "tabular-classification" : snippetTabular ,
158+ "tabular-regression" : snippetTabularRegression ,
159+ "tabular-classification" : snippetTabularClassification ,
170160 "object-detection" : snippetFile ,
171161 "image-segmentation" : snippetFile ,
172162 "document-question-answering" : snippetDocumentQuestionAnswering ,
@@ -175,25 +165,38 @@ export const pythonSnippets: Partial<Record<PipelineType, (model: ModelDataMinim
175165} ;
176166
177167export function getPythonInferenceSnippet ( model : ModelDataMinimal , accessToken : string ) : string {
178- if ( model . pipeline_tag === "text-generation" && model . tags . includes ( "conversational" ) ) {
179- // Conversational model detected, so we display a code snippet that features the Messages API
180- return snippetConversational ( model , accessToken ) ;
181- } else if ( model . pipeline_tag === "image-text-to-text" && model . tags . includes ( "conversational" ) ) {
182- // Example sending an image to the Message API
183- return snippetConversationalWithImage ( model , accessToken ) ;
184- } else {
185- const body =
186- model . pipeline_tag && model . pipeline_tag in pythonSnippets
187- ? pythonSnippets [ model . pipeline_tag ] ?.( model , accessToken ) ?? ""
188- : "" ;
189-
190- return `import requests
191-
192- API_URL = "https://api-inference.huggingface.co/models/${ model . id } "
193- headers = {"Authorization": ${ accessToken ? `"Bearer ${ accessToken } "` : `f"Bearer {API_TOKEN}"` } }
168+ // Specific case for chat completion snippets
169+ const isConversational =
170+ "conversational" in model . tags &&
171+ model . pipeline_tag &&
172+ model . pipeline_tag in [ "text-generation" , "image-text-to-text" ] ;
173+
174+ // Determine the import snippet based on model tags and pipeline tag
175+ const getImportSnippet = ( ) => {
176+ if ( isConversational ) {
177+ return snippetImportConversationalInferenceClient ( model , accessToken ) ;
178+ } else if ( model . pipeline_tag && model . pipeline_tag in PIPELINES_USING_INFERENCE_CLIENT ) {
179+ return snippetImportInferenceClient ( model , accessToken ) ;
180+ } else {
181+ return snippetImportRequests ( model , accessToken ) ;
182+ }
183+ } ;
184+
185+ // Determine the body snippet based on model tags and pipeline tag
186+ const getBodySnippet = ( ) => {
187+ if ( isConversational ) {
188+ return model . pipeline_tag === "text-generation"
189+ ? snippetConversational ( model )
190+ : snippetConversationalWithImage ( model ) ;
191+ } else if ( model . pipeline_tag && model . pipeline_tag in pythonSnippets ) {
192+ return pythonSnippets [ model . pipeline_tag ] ?.( model ) ?? "" ;
193+ } else {
194+ return "" ;
195+ }
196+ } ;
194197
195- ${ body } ` ;
196- }
198+ // Combine import and body snippets with newline separation
199+ return ` ${ getImportSnippet ( ) } \n\n ${ getBodySnippet ( ) } ` ;
197200}
198201
199202export function hasPythonInferenceSnippet ( model : ModelDataMinimal ) : boolean {
0 commit comments