Skip to content

Commit 0907d8f

Browse files
committed
Add base64 import when required
1 parent 6234469 commit 0907d8f

File tree

2 files changed

+21
-10
lines changed

2 files changed

+21
-10
lines changed

packages/tasks/src/snippets/python.spec.ts

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,8 @@ client = InferenceClient("impira/layoutlm-invoices", token="api_token")
155155
output = client.document_question_answering(cat.png, question=What is in this image?)`);
156156

157157
expect(snippets[1].client).toEqual("requests");
158-
expect(snippets[1].content).toEqual(`import requests
158+
expect(snippets[1].content).toEqual(`import base64
159+
import requests
159160
160161
API_URL = "https://api-inference.huggingface.co/models/impira/layoutlm-invoices"
161162
headers = {"Authorization": "Bearer api_token"}
@@ -194,7 +195,8 @@ client = InferenceClient("stabilityai/stable-diffusion-xl-refiner-1.0", token="a
194195
image = client.image_to_image("cat.png", prompt="Turn the cat into a tiger.")`);
195196

196197
expect(snippets[1].client).toEqual("requests");
197-
expect(snippets[1].content).toEqual(`import requests
198+
expect(snippets[1].content).toEqual(`import base64
199+
import requests
198200
199201
API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-refiner-1.0"
200202
headers = {"Authorization": "Bearer api_token"}

packages/tasks/src/snippets/python.ts

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,22 @@ const snippetImportInferenceClient = (model: ModelDataMinimal, accessToken: stri
1111
client = InferenceClient("${model.id}", token="${accessToken || "{API_TOKEN}"}")
1212
`;
1313

14+
const addImportsToSnippet = (snippet: string, model: ModelDataMinimal, accessToken: string): string => {
15+
if (snippet.includes("requests")) {
16+
snippet = `import requests
17+
18+
API_URL = "https://api-inference.huggingface.co/models/${model.id}"
19+
headers = {"Authorization": ${accessToken ? `"Bearer ${accessToken}"` : `f"Bearer {API_TOKEN}"`}}
20+
21+
${snippet}`;
22+
}
23+
if (snippet.includes("base64")) {
24+
snippet = `import base64
25+
${snippet}`;
26+
}
27+
return snippet;
28+
};
29+
1430
const snippetBasic = (model: ModelDataMinimal): InferenceSnippet => ({
1531
content: `def query(payload):
1632
response = requests.post(API_URL, headers=headers, json=payload)
@@ -368,14 +384,7 @@ export function getPythonInferenceSnippet(
368384
return snippets.map((snippet) => {
369385
return {
370386
...snippet,
371-
content: snippet.content.includes("requests")
372-
? `import requests
373-
374-
API_URL = "https://api-inference.huggingface.co/models/${model.id}"
375-
headers = {"Authorization": ${accessToken ? `"Bearer ${accessToken}"` : `f"Bearer {API_TOKEN}"`}}
376-
377-
${snippet.content}`
378-
: snippet.content,
387+
content: addImportsToSnippet(snippet.content, model, accessToken),
379388
};
380389
});
381390
}

0 commit comments

Comments
 (0)