Skip to content

Commit 475e5a5

Browse files
committed
Use InferenceClient in (some) python inference snippets
1 parent f9ae194 commit 475e5a5

File tree

1 file changed

+83
-80
lines changed

1 file changed

+83
-80
lines changed

packages/tasks/src/snippets/python.ts

Lines changed: 83 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,40 @@ import type { PipelineType } from "../pipelines.js";
22
import { getModelInputSnippet } from "./inputs.js";
33
import type { ModelDataMinimal } from "./types.js";
44

5-
export const snippetConversational = (model: ModelDataMinimal, accessToken: string): string =>
5+
// Import snippets
6+
7+
// Shared import/setup preamble for snippets that call task-specific
// InferenceClient helpers (e.g. client.tabular_classification(...)).
// BUGFIX: the model id must be quoted — without quotes the generated Python
// (`client = InferenceClient(org/model, ...)`) is a syntax error.
const snippetImportInferenceClient = (model: ModelDataMinimal, accessToken: string): string =>
	`from huggingface_hub import InferenceClient

client = InferenceClient("${model.id}", token="${accessToken || "{API_TOKEN}"}")
`;
12+
13+
// Import/setup preamble for chat-completion snippets. Follows the OpenAI
// convention: the client is built with only an API key, and the model id is
// passed per call to chat_completion().
const snippetImportConversationalInferenceClient = (model: ModelDataMinimal, accessToken: string): string => {
	const token = accessToken || "{API_TOKEN}";
	return `from huggingface_hub import InferenceClient

client = InferenceClient(api_key="${token}")
`;
};
919

10-
for message in client.chat_completion(
20+
// Import/setup preamble for raw `requests`-based snippets.
const snippetImportRequests = (model: ModelDataMinimal, accessToken: string): string => {
	// With a concrete token we inline it; otherwise emit an f-string placeholder
	// so the user can substitute their own API_TOKEN variable.
	const authValue = accessToken ? `"Bearer ${accessToken}"` : `f"Bearer {API_TOKEN}"`;
	return `import requests

API_URL = "https://api-inference.huggingface.co/models/${model.id}"
headers = {"Authorization": ${authValue}}`;
};
25+
26+
// Streaming chat-completion example (text-only). The snippet assumes a
// `client` variable was created by one of the import preambles — TODO(review)
// confirm callers always prepend an import snippet.
export const snippetConversational = (model: ModelDataMinimal): string => {
	return `for message in client.chat_completion(
	model="${model.id}",
	messages=[{"role": "user", "content": "What is the capital of France?"}],
	max_tokens=500,
	stream=True,
):
    print(message.choices[0].delta.content, end="")`;
};
1734

18-
export const snippetConversationalWithImage = (model: ModelDataMinimal, accessToken: string): string =>
19-
`from huggingface_hub import InferenceClient
20-
21-
client = InferenceClient(api_key="${accessToken || "{API_TOKEN}"}")
35+
// InferenceClient-based snippets
2236

23-
image_url = "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
37+
export const snippetConversationalWithImage = (model: ModelDataMinimal): string =>
38+
`image_url = "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
2439
2540
for message in client.chat_completion(
2641
model="${model.id}",
@@ -38,31 +53,29 @@ for message in client.chat_completion(
3853
):
3954
print(message.choices[0].delta.content, end="")`;
4055

41-
export const snippetZeroShotClassification = (model: ModelDataMinimal): string =>
42-
`def query(payload):
43-
response = requests.post(API_URL, headers=headers, json=payload)
44-
return response.json()
56+
// Document question answering via InferenceClient, using a fixed example
// image path and question.
export const snippetDocumentQuestionAnswering = (): string => {
	const question = "What is in this image?";
	return `output = client.document_question_answering("cat.png", "${question}")`;
};
4558

46-
output = query({
47-
"inputs": ${getModelInputSnippet(model)},
48-
"parameters": {"candidate_labels": ["refund", "legal", "faq"]},
49-
})`;
59+
// Tabular classification via InferenceClient, fed with the model's example input.
export const snippetTabularClassification = (model: ModelDataMinimal): string => {
	const exampleInput = getModelInputSnippet(model);
	return `output = client.tabular_classification(${exampleInput})`;
};
61+
62+
// Tabular regression via InferenceClient, fed with the model's example input.
export const snippetTabularRegression = (model: ModelDataMinimal): string => {
	const exampleInput = getModelInputSnippet(model);
	return `output = client.tabular_regression(${exampleInput})`;
};
64+
// Text-to-image via InferenceClient; the emitted comment notes the return type.
export const snippetTextToImage = (model: ModelDataMinimal): string => {
	const prompt = getModelInputSnippet(model);
	return `# output is a PIL.Image object
image = client.text_to_image(${prompt})`;
};
67+
68+
// Zero-shot text classification via InferenceClient with fixed example labels.
export const snippetZeroShotClassification = (model: ModelDataMinimal): string => {
	const exampleText = getModelInputSnippet(model);
	return `text = ${exampleText}
labels = ["refund", "legal", "faq"]
output = client.zero_shot_classification(text, labels)`;
};
5072

5173
export const snippetZeroShotImageClassification = (model: ModelDataMinimal): string =>
52-
`def query(data):
53-
with open(data["image_path"], "rb") as f:
54-
img = f.read()
55-
payload={
56-
"parameters": data["parameters"],
57-
"inputs": base64.b64encode(img).decode("utf-8")
58-
}
59-
response = requests.post(API_URL, headers=headers, json=payload)
60-
return response.json()
74+
`image = ${getModelInputSnippet(model)}
75+
labels = ["cat", "dog", "llama"]
76+
output = client.zero_shot_image_classification(image, labels)`;
6177

62-
output = query({
63-
"image_path": ${getModelInputSnippet(model)},
64-
"parameters": {"candidate_labels": ["cat", "dog", "llama"]},
65-
})`;
78+
// requests-based snippets
6679

6780
export const snippetBasic = (model: ModelDataMinimal): string =>
6881
`def query(payload):
@@ -82,26 +95,6 @@ export const snippetFile = (model: ModelDataMinimal): string =>
8295
8396
output = query(${getModelInputSnippet(model)})`;
8497

85-
export const snippetTextToImage = (model: ModelDataMinimal): string =>
86-
`def query(payload):
87-
response = requests.post(API_URL, headers=headers, json=payload)
88-
return response.content
89-
image_bytes = query({
90-
"inputs": ${getModelInputSnippet(model)},
91-
})
92-
# You can access the image with PIL.Image for example
93-
import io
94-
from PIL import Image
95-
image = Image.open(io.BytesIO(image_bytes))`;
96-
97-
export const snippetTabular = (model: ModelDataMinimal): string =>
98-
`def query(payload):
99-
response = requests.post(API_URL, headers=headers, json=payload)
100-
return response.content
101-
response = query({
102-
"inputs": {"data": ${getModelInputSnippet(model)}},
103-
})`;
104-
10598
export const snippetTextToAudio = (model: ModelDataMinimal): string => {
10699
// Transformers TTS pipeline and api-inference-community (AIC) pipeline outputs are diverged
107100
// with the latest update to inference-api (IA).
@@ -131,19 +124,16 @@ Audio(audio, rate=sampling_rate)`;
131124
}
132125
};
133126

134-
export const snippetDocumentQuestionAnswering = (model: ModelDataMinimal): string =>
135-
`def query(payload):
136-
with open(payload["image"], "rb") as f:
137-
img = f.read()
138-
payload["image"] = base64.b64encode(img).decode("utf-8")
139-
response = requests.post(API_URL, headers=headers, json=payload)
140-
return response.json()
141-
142-
output = query({
143-
"inputs": ${getModelInputSnippet(model)},
144-
})`;
127+
// Pipelines whose snippets call task-specific InferenceClient helpers
// (client.text_to_image, client.tabular_classification, ...); every other
// pipeline falls back to the raw `requests`-based snippets.
const PIPELINES_USING_INFERENCE_CLIENT: PipelineType[] = [
	"document-question-answering",
	"tabular-classification",
	"tabular-regression",
	"text-to-image",
	"zero-shot-classification",
	"zero-shot-image-classification",
];
145135

146-
export const pythonSnippets: Partial<Record<PipelineType, (model: ModelDataMinimal, accessToken: string) => string>> = {
136+
export const pythonSnippets: Partial<Record<PipelineType, (model: ModelDataMinimal) => string>> = {
147137
// Same order as in tasks/src/pipelines.ts
148138
"text-classification": snippetBasic,
149139
"token-classification": snippetBasic,
@@ -165,8 +155,8 @@ export const pythonSnippets: Partial<Record<PipelineType, (model: ModelDataMinim
165155
"audio-to-audio": snippetFile,
166156
"audio-classification": snippetFile,
167157
"image-classification": snippetFile,
168-
"tabular-regression": snippetTabular,
169-
"tabular-classification": snippetTabular,
158+
"tabular-regression": snippetTabularRegression,
159+
"tabular-classification": snippetTabularClassification,
170160
"object-detection": snippetFile,
171161
"image-segmentation": snippetFile,
172162
"document-question-answering": snippetDocumentQuestionAnswering,
@@ -175,25 +165,38 @@ export const pythonSnippets: Partial<Record<PipelineType, (model: ModelDataMinim
175165
};
176166

177167
export function getPythonInferenceSnippet(model: ModelDataMinimal, accessToken: string): string {
178-
if (model.pipeline_tag === "text-generation" && model.tags.includes("conversational")) {
179-
// Conversational model detected, so we display a code snippet that features the Messages API
180-
return snippetConversational(model, accessToken);
181-
} else if (model.pipeline_tag === "image-text-to-text" && model.tags.includes("conversational")) {
182-
// Example sending an image to the Message API
183-
return snippetConversationalWithImage(model, accessToken);
184-
} else {
185-
const body =
186-
model.pipeline_tag && model.pipeline_tag in pythonSnippets
187-
? pythonSnippets[model.pipeline_tag]?.(model, accessToken) ?? ""
188-
: "";
189-
190-
return `import requests
191-
192-
API_URL = "https://api-inference.huggingface.co/models/${model.id}"
193-
headers = {"Authorization": ${accessToken ? `"Bearer ${accessToken}"` : `f"Bearer {API_TOKEN}"`}}
168+
// Specific case for chat completion snippets
169+
const isConversational =
170+
"conversational" in model.tags &&
171+
model.pipeline_tag &&
172+
model.pipeline_tag in ["text-generation", "image-text-to-text"];
173+
174+
// Determine the import snippet based on model tags and pipeline tag
175+
const getImportSnippet = () => {
176+
if (isConversational) {
177+
return snippetImportConversationalInferenceClient(model, accessToken);
178+
} else if (model.pipeline_tag && model.pipeline_tag in PIPELINES_USING_INFERENCE_CLIENT) {
179+
return snippetImportInferenceClient(model, accessToken);
180+
} else {
181+
return snippetImportRequests(model, accessToken);
182+
}
183+
};
184+
185+
// Determine the body snippet based on model tags and pipeline tag
186+
const getBodySnippet = () => {
187+
if (isConversational) {
188+
return model.pipeline_tag === "text-generation"
189+
? snippetConversational(model)
190+
: snippetConversationalWithImage(model);
191+
} else if (model.pipeline_tag && model.pipeline_tag in pythonSnippets) {
192+
return pythonSnippets[model.pipeline_tag]?.(model) ?? "";
193+
} else {
194+
return "";
195+
}
196+
};
194197

195-
${body}`;
196-
}
198+
// Combine import and body snippets with newline separation
199+
return `${getImportSnippet()}\n\n${getBodySnippet()}`;
197200
}
198201

199202
export function hasPythonInferenceSnippet(model: ModelDataMinimal): boolean {

0 commit comments

Comments
 (0)