Skip to content

Commit cea35c9

Browse files
julien-c and Wauplin authored
improve python snippets to add huggingface_hub in more tasks (#1185)
(the code for inference snippets is still quite ugly, but we'll improve later) --------- Co-authored-by: Wauplin <[email protected]>
1 parent 0429cdc commit cea35c9

32 files changed

+192
-73
lines changed

packages/tasks-gen/scripts/generate-snippets-fixtures.ts

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import { existsSync as pathExists } from "node:fs";
1919
import * as fs from "node:fs/promises";
2020
import * as path from "node:path/posix";
2121

22-
import type { InferenceProvider, InferenceSnippet } from "@huggingface/tasks";
22+
import type { SnippetInferenceProvider, InferenceSnippet } from "@huggingface/tasks";
2323
import { snippets } from "@huggingface/tasks";
2424

2525
type LANGUAGE = "sh" | "js" | "py";
@@ -28,7 +28,7 @@ const TEST_CASES: {
2828
testName: string;
2929
model: snippets.ModelDataMinimal;
3030
languages: LANGUAGE[];
31-
providers: InferenceProvider[];
31+
providers: SnippetInferenceProvider[];
3232
opts?: Record<string, unknown>;
3333
}[] = [
3434
{
@@ -90,6 +90,17 @@ const TEST_CASES: {
9090
providers: ["hf-inference"],
9191
languages: ["sh", "js", "py"],
9292
},
93+
{
94+
testName: "text-classification",
95+
model: {
96+
id: "distilbert/distilbert-base-uncased-finetuned-sst-2-english",
97+
pipeline_tag: "text-classification",
98+
tags: [],
99+
inference: "",
100+
},
101+
providers: ["hf-inference"],
102+
languages: ["sh", "js", "py"],
103+
},
93104
] as const;
94105

95106
const GET_SNIPPET_FN = {
@@ -119,17 +130,16 @@ function getFixtureFolder(testName: string): string {
119130
function generateInferenceSnippet(
120131
model: snippets.ModelDataMinimal,
121132
language: LANGUAGE,
122-
provider: InferenceProvider,
133+
provider: SnippetInferenceProvider,
123134
opts?: Record<string, unknown>
124135
): InferenceSnippet[] {
125-
const generatedSnippets = GET_SNIPPET_FN[language](model, "api_token", provider, opts);
126-
return Array.isArray(generatedSnippets) ? generatedSnippets : [generatedSnippets];
136+
return GET_SNIPPET_FN[language](model, "api_token", provider, opts);
127137
}
128138

129139
async function getExpectedInferenceSnippet(
130140
testName: string,
131141
language: LANGUAGE,
132-
provider: InferenceProvider
142+
provider: SnippetInferenceProvider
133143
): Promise<InferenceSnippet[]> {
134144
const fixtureFolder = getFixtureFolder(testName);
135145
const files = await fs.readdir(fixtureFolder);
@@ -146,7 +156,7 @@ async function getExpectedInferenceSnippet(
146156
async function saveExpectedInferenceSnippet(
147157
testName: string,
148158
language: LANGUAGE,
149-
provider: InferenceProvider,
159+
provider: SnippetInferenceProvider,
150160
snippets: InferenceSnippet[]
151161
) {
152162
const fixtureFolder = getFixtureFolder(testName);

packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/0.curl.hf-inference.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
curl 'https://api-inference.huggingface.co/models/meta-llama/Llama-3.1-8B-Instruct/v1/chat/completions' \
1+
curl 'https://router.huggingface.co/hf-inference/models/meta-llama/Llama-3.1-8B-Instruct/v1/chat/completions' \
22
-H 'Authorization: Bearer api_token' \
33
-H 'Content-Type: application/json' \
44
--data '{

packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/1.openai.hf-inference.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import { OpenAI } from "openai";
22

33
const client = new OpenAI({
4-
baseURL: "https://router.huggingface.co/hf-inference",
4+
baseURL: "https://router.huggingface.co/hf-inference/v1",
55
apiKey: "api_token"
66
});
77

packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/1.openai.hf-inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from openai import OpenAI
22

33
client = OpenAI(
4-
base_url="https://router.huggingface.co/hf-inference",
4+
base_url="https://router.huggingface.co/hf-inference/v1",
55
api_key="api_token"
66
)
77

packages/tasks-gen/snippets-fixtures/conversational-llm-stream/0.curl.hf-inference.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
curl 'https://api-inference.huggingface.co/models/meta-llama/Llama-3.1-8B-Instruct/v1/chat/completions' \
1+
curl 'https://router.huggingface.co/hf-inference/models/meta-llama/Llama-3.1-8B-Instruct/v1/chat/completions' \
22
-H 'Authorization: Bearer api_token' \
33
-H 'Content-Type: application/json' \
44
--data '{

packages/tasks-gen/snippets-fixtures/conversational-llm-stream/1.openai.hf-inference.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import { OpenAI } from "openai";
22

33
const client = new OpenAI({
4-
baseURL: "https://router.huggingface.co/hf-inference",
4+
baseURL: "https://router.huggingface.co/hf-inference/v1",
55
apiKey: "api_token"
66
});
77

packages/tasks-gen/snippets-fixtures/conversational-llm-stream/1.openai.hf-inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from openai import OpenAI
22

33
client = OpenAI(
4-
base_url="https://router.huggingface.co/hf-inference",
4+
base_url="https://router.huggingface.co/hf-inference/v1",
55
api_key="api_token"
66
)
77

packages/tasks-gen/snippets-fixtures/conversational-vlm-non-stream/0.curl.hf-inference.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
curl 'https://api-inference.huggingface.co/models/meta-llama/Llama-3.2-11B-Vision-Instruct/v1/chat/completions' \
1+
curl 'https://router.huggingface.co/hf-inference/models/meta-llama/Llama-3.2-11B-Vision-Instruct/v1/chat/completions' \
22
-H 'Authorization: Bearer api_token' \
33
-H 'Content-Type: application/json' \
44
--data '{

packages/tasks-gen/snippets-fixtures/conversational-vlm-non-stream/1.openai.hf-inference.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import { OpenAI } from "openai";
22

33
const client = new OpenAI({
4-
baseURL: "https://router.huggingface.co/hf-inference",
4+
baseURL: "https://router.huggingface.co/hf-inference/v1",
55
apiKey: "api_token"
66
});
77

packages/tasks-gen/snippets-fixtures/conversational-vlm-non-stream/1.openai.hf-inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from openai import OpenAI
22

33
client = OpenAI(
4-
base_url="https://router.huggingface.co/hf-inference",
4+
base_url="https://router.huggingface.co/hf-inference/v1",
55
api_key="api_token"
66
)
77

0 commit comments

Comments (0)