
Commit f931db2

Wauplin and SBrandeis authored
[draft] Support providerModelId in inference snippets (#1210)
Goal of this PR is to correctly handle model id mapping when generating inference snippets for a given provider. For now it's a simple PR to showcase what I had in mind in huggingface-internal/moon-landing#12626 (comment) (private repo). I only implemented it for the chat-completion curl snippet, but the rest would follow the same pattern.

Note that we should use the mapping only for curl, pure Python, pure JS, and openai client snippets. For the `huggingface.js` and `huggingface_hub` ones, the model id from the Hub should be used.

---

**Note:** orthogonal to this PR, but I realized that some URLs are also incorrect depending on the provider.

Co-authored-by: SBrandeis <[email protected]>
1 parent b9c2bb8 · commit f931db2

25 files changed: +526 −56 lines
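
The routing rule in the commit message lends itself to a small helper. The sketch below is hypothetical, not code from this PR: the `SnippetClient` type and `modelIdForSnippet` function are invented here to illustrate which model id each kind of snippet should embed.

// Hypothetical sketch, not part of this PR: decide which model id a snippet embeds.
type SnippetClient =
    | "curl"
    | "python"          // pure Python (e.g. requests)
    | "js"              // pure JS (e.g. fetch)
    | "openai"          // openai client pointed at the provider route
    | "huggingface.js"
    | "huggingface_hub";

function modelIdForSnippet(client: SnippetClient, hubModelId: string, providerModelId: string): string {
    // huggingface.js / huggingface_hub snippets keep the Hub model id;
    // the other clients call the provider directly and need its alias.
    return client === "huggingface.js" || client === "huggingface_hub" ? hubModelId : providerModelId;
}

// Example: modelIdForSnippet("curl", "meta-llama/Llama-3.1-8B-Instruct",
//   "<together alias for meta-llama/Llama-3.1-8B-Instruct>")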

packages/tasks-gen/scripts/generate-snippets-fixtures.ts

Lines changed: 6 additions & 5 deletions

@@ -52,7 +52,7 @@ const TEST_CASES: {
 			inference: "",
 		},
 		languages: ["sh", "js", "py"],
-		providers: ["hf-inference"],
+		providers: ["hf-inference", "together"],
 		opts: { streaming: true },
 	},
 	{
@@ -64,7 +64,7 @@
 			inference: "",
 		},
 		languages: ["sh", "js", "py"],
-		providers: ["hf-inference"],
+		providers: ["hf-inference", "fireworks-ai"],
 		opts: { streaming: false },
 	},
 	{
@@ -76,7 +76,7 @@
 			inference: "",
 		},
 		languages: ["sh", "js", "py"],
-		providers: ["hf-inference"],
+		providers: ["hf-inference", "fireworks-ai"],
 		opts: { streaming: true },
 	},
 	{
@@ -87,7 +87,7 @@
 			tags: [],
 			inference: "",
 		},
-		providers: ["hf-inference"],
+		providers: ["hf-inference", "fal-ai"],
 		languages: ["sh", "js", "py"],
 	},
 	{
@@ -133,7 +133,8 @@ function generateInferenceSnippet(
 	provider: SnippetInferenceProvider,
 	opts?: Record<string, unknown>
 ): InferenceSnippet[] {
-	return GET_SNIPPET_FN[language](model, "api_token", provider, opts);
+	const providerModelId = provider === "hf-inference" ? model.id : `<${provider} alias for ${model.id}>`;
+	return GET_SNIPPET_FN[language](model, "api_token", provider, providerModelId, opts);
 }

 async function getExpectedInferenceSnippet(
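
For orientation, here is a hedged sketch of the pattern the generator above exercises: a chat-completion curl snippet builder that receives the new `providerModelId` argument and embeds it in place of the Hub id. The name `snippetCurlChatCompletion` and the exact signature are illustrative assumptions, not the PR's actual code; the URL scheme simply mirrors the fixtures below.

// Illustrative sketch; name and signature are assumptions, not this PR's code.
interface ModelDataMinimal {
    id: string; // Hub model id, e.g. "meta-llama/Llama-3.1-8B-Instruct"
}

function snippetCurlChatCompletion(
    model: ModelDataMinimal,
    accessToken: string,
    provider: string,
    providerModelId: string,
    opts?: { streaming?: boolean }
): string {
    const body = {
        // Provider alias, not model.id (for hf-inference the caller passes
        // providerModelId === model.id, so both cases collapse into one path).
        model: providerModelId,
        messages: [{ role: "user", content: "What is the capital of France?" }],
        max_tokens: 500,
        stream: opts?.streaming ?? false,
    };
    return [
        `curl 'https://router.huggingface.co/${provider}/v1/chat/completions' \\`,
        `    -H 'Authorization: Bearer ${accessToken}' \\`,
        `    -H 'Content-Type: application/json' \\`,
        `    --data '${JSON.stringify(body, null, 4)}'`,
    ].join("\n");
}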

packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/0.curl.together.sh

Lines changed: 1 addition & 1 deletion

@@ -2,7 +2,7 @@ curl 'https://router.huggingface.co/together/v1/chat/completions' \
     -H 'Authorization: Bearer api_token' \
     -H 'Content-Type: application/json' \
     --data '{
-        "model": "meta-llama/Llama-3.1-8B-Instruct",
+        "model": "<together alias for meta-llama/Llama-3.1-8B-Instruct>",
         "messages": [
             {
                 "role": "user",

packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/1.openai.together.js

Lines changed: 1 addition & 1 deletion

@@ -6,7 +6,7 @@ const client = new OpenAI({
 });

 const chatCompletion = await client.chat.completions.create({
-    model: "meta-llama/Llama-3.1-8B-Instruct",
+    model: "<together alias for meta-llama/Llama-3.1-8B-Instruct>",
     messages: [
         {
             role: "user",

packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/1.openai.together.py

Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@
 ]

 completion = client.chat.completions.create(
-    model="meta-llama/Llama-3.1-8B-Instruct",
+    model="<together alias for meta-llama/Llama-3.1-8B-Instruct>",
     messages=messages,
     max_tokens=500,
 )

Lines changed: 14 additions & 0 deletions

@@ -0,0 +1,14 @@
+curl 'https://router.huggingface.co/together/v1/chat/completions' \
+    -H 'Authorization: Bearer api_token' \
+    -H 'Content-Type: application/json' \
+    --data '{
+        "model": "<together alias for meta-llama/Llama-3.1-8B-Instruct>",
+        "messages": [
+            {
+                "role": "user",
+                "content": "What is the capital of France?"
+            }
+        ],
+        "max_tokens": 500,
+        "stream": true
+    }'

Lines changed: 25 additions & 0 deletions

@@ -0,0 +1,25 @@
+import { HfInference } from "@huggingface/inference";
+
+const client = new HfInference("api_token");
+
+let out = "";
+
+const stream = client.chatCompletionStream({
+    model: "meta-llama/Llama-3.1-8B-Instruct",
+    messages: [
+        {
+            role: "user",
+            content: "What is the capital of France?"
+        }
+    ],
+    provider: "together",
+    max_tokens: 500,
+});
+
+for await (const chunk of stream) {
+    if (chunk.choices && chunk.choices.length > 0) {
+        const newContent = chunk.choices[0].delta.content;
+        out += newContent;
+        console.log(newContent);
+    }
+}

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+from huggingface_hub import InferenceClient
+
+client = InferenceClient(
+    provider="together",
+    api_key="api_token"
+)
+
+messages = [
+    {
+        "role": "user",
+        "content": "What is the capital of France?"
+    }
+]
+
+stream = client.chat.completions.create(
+    model="meta-llama/Llama-3.1-8B-Instruct",
+    messages=messages,
+    max_tokens=500,
+    stream=True
+)
+
+for chunk in stream:
+    print(chunk.choices[0].delta.content, end="")

Lines changed: 28 additions & 0 deletions

@@ -0,0 +1,28 @@
+import { OpenAI } from "openai";
+
+const client = new OpenAI({
+    baseURL: "https://router.huggingface.co/together",
+    apiKey: "api_token"
+});
+
+let out = "";
+
+const stream = await client.chat.completions.create({
+    model: "<together alias for meta-llama/Llama-3.1-8B-Instruct>",
+    messages: [
+        {
+            role: "user",
+            content: "What is the capital of France?"
+        }
+    ],
+    max_tokens: 500,
+    stream: true,
+});
+
+for await (const chunk of stream) {
+    if (chunk.choices && chunk.choices.length > 0) {
+        const newContent = chunk.choices[0].delta.content;
+        out += newContent;
+        console.log(newContent);
+    }
+}

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+from openai import OpenAI
+
+client = OpenAI(
+    base_url="https://router.huggingface.co/together",
+    api_key="api_token"
+)
+
+messages = [
+    {
+        "role": "user",
+        "content": "What is the capital of France?"
+    }
+]
+
+stream = client.chat.completions.create(
+    model="<together alias for meta-llama/Llama-3.1-8B-Instruct>",
+    messages=messages,
+    max_tokens=500,
+    stream=True
+)
+
+for chunk in stream:
+    print(chunk.choices[0].delta.content, end="")

Lines changed: 25 additions & 0 deletions

@@ -0,0 +1,25 @@
+curl 'https://router.huggingface.co/fireworks-ai/v1/chat/completions' \
+    -H 'Authorization: Bearer api_token' \
+    -H 'Content-Type: application/json' \
+    --data '{
+        "model": "<fireworks-ai alias for meta-llama/Llama-3.2-11B-Vision-Instruct>",
+        "messages": [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": "Describe this image in one sentence."
+                    },
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
+                        }
+                    }
+                ]
+            }
+        ],
+        "max_tokens": 500,
+        "stream": false
+    }'
