Skip to content

Commit 9452dc2

Browse files
committed
Merge branch 'main' into feat/novita
2 parents d82fc7b + c1a8dfc commit 9452dc2

File tree

11 files changed

+180
-15
lines changed

11 files changed

+180
-15
lines changed

.github/workflows/test.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ jobs:
4646
HF_SAMBANOVA_KEY: dummy
4747
HF_TOGETHER_KEY: dummy
4848
HF_NOVITA_KEY: dummy
49+
HF_FIREWORKS_KEY: dummy
4950

5051
browser:
5152
runs-on: ubuntu-latest
@@ -87,6 +88,7 @@ jobs:
8788
HF_SAMBANOVA_KEY: dummy
8889
HF_TOGETHER_KEY: dummy
8990
HF_NOVITA_KEY: dummy
91+
HF_FIREWORKS_KEY: dummy
9092

9193
e2e:
9294
runs-on: ubuntu-latest
@@ -154,4 +156,5 @@ jobs:
154156
HF_REPLICATE_KEY: dummy
155157
HF_SAMBANOVA_KEY: dummy
156158
HF_TOGETHER_KEY: dummy
157-
HF_NOVITA_KEY: dummy
159+
HF_NOVITA_KEY: dummy
160+
HF_FIREWORKS_KEY: dummy

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ You can run our packages with vanilla JS, without any bundler, by using a CDN or
9696

9797
```html
9898
<script type="module">
99-
import { HfInference } from 'https://cdn.jsdelivr.net/npm/@huggingface/inference@3.2.0/+esm';
99+
import { HfInference } from 'https://cdn.jsdelivr.net/npm/@huggingface/inference@3.3.0/+esm';
100100
import { createRepo, commit, deleteRepo, listFiles } from "https://cdn.jsdelivr.net/npm/@huggingface/[email protected]/+esm";
101101
</script>
102102
```

packages/inference/README.md

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ You can send inference requests to third-party providers with the inference clie
4848

4949
Currently, we support the following providers:
5050
- [Fal.ai](https://fal.ai)
51+
- [Fireworks AI](https://fireworks.ai)
5152
- [Replicate](https://replicate.com)
5253
- [Sambanova](https://sambanova.ai)
5354
- [Together](https://together.xyz)
@@ -69,10 +70,11 @@ When authenticated with a Hugging Face access token, the request is routed throu
6970
When authenticated with a third-party provider key, the request is made directly against that provider's inference API.
7071

7172
Only a subset of models are supported when requesting third-party providers. You can check the list of supported models per pipeline tasks here:
72-
- [Fal.ai supported models](./src/providers/fal-ai.ts)
73-
- [Replicate supported models](./src/providers/replicate.ts)
74-
- [Sambanova supported models](./src/providers/sambanova.ts)
75-
- [Together supported models](./src/providers/together.ts)
73+
- [Fal.ai supported models](https://huggingface.co/api/partners/fal-ai/models)
74+
- [Fireworks AI supported models](https://huggingface.co/api/partners/fireworks-ai/models)
75+
- [Replicate supported models](https://huggingface.co/api/partners/replicate/models)
76+
- [Sambanova supported models](https://huggingface.co/api/partners/sambanova/models)
77+
- [Together supported models](https://huggingface.co/api/partners/together/models)
7678
- [HF Inference API (serverless)](https://huggingface.co/models?inference=warm&sort=trending)
7779

7880
**Important note:** To be compatible, the third-party API must adhere to the "standard" shape API we expect on HF model pages for each pipeline task type.

packages/inference/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "@huggingface/inference",
3-
"version": "3.2.0",
3+
"version": "3.3.0",
44
"packageManager": "[email protected]",
55
"license": "MIT",
66
"author": "Tim Mikeladze <[email protected]>",

packages/inference/src/lib/getProviderModelId.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ export async function getProviderModelId(
3030
options.taskHint === "text-generation" && options.chatCompletion ? "conversational" : options.taskHint;
3131

3232
// A dict called HARDCODED_MODEL_ID_MAPPING takes precedence in all cases (useful for dev purposes)
33-
if (HARDCODED_MODEL_ID_MAPPING[params.model]) {
34-
return HARDCODED_MODEL_ID_MAPPING[params.model];
33+
if (HARDCODED_MODEL_ID_MAPPING[params.provider]?.[params.model]) {
34+
return HARDCODED_MODEL_ID_MAPPING[params.provider][params.model];
3535
}
3636

3737
let inferenceProviderMapping: InferenceProviderMapping | null;

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import { REPLICATE_API_BASE_URL } from "../providers/replicate";
44
import { SAMBANOVA_API_BASE_URL } from "../providers/sambanova";
55
import { TOGETHER_API_BASE_URL } from "../providers/together";
66
import { NOVITA_API_BASE_URL } from "../providers/novita";
7+
import { FIREWORKS_AI_API_BASE_URL } from "../providers/fireworks-ai";
78
import type { InferenceProvider } from "../types";
89
import type { InferenceTask, Options, RequestArgs } from "../types";
910
import { isUrl } from "./isUrl";
@@ -209,11 +210,20 @@ function makeUrl(params: {
209210
}
210211
return baseUrl;
211212
}
213+
214+
case "fireworks-ai": {
215+
const baseUrl = shouldProxy
216+
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
217+
: FIREWORKS_AI_API_BASE_URL;
218+
if (params.taskHint === "text-generation" && params.chatCompletion) {
219+
return `${baseUrl}/v1/chat/completions`;
220+
}
221+
return baseUrl;
222+
}
212223
case "novita": {
213224
const baseUrl = shouldProxy
214225
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
215226
: NOVITA_API_BASE_URL;
216-
/// Novita API matches OpenAI-like APIs: model is defined in the request body
217227
if (params.taskHint === "text-generation") {
218228
if (params.chatCompletion) {
219229
return `${baseUrl}/chat/completions`;
Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,25 @@
1-
import type { ModelId } from "../types";
1+
import type { InferenceProvider } from "../types";
2+
import { type ModelId } from "../types";
23

34
type ProviderId = string;
4-
55
/**
66
* If you want to try to run inference for a new model locally before it's registered on huggingface.co
77
* for a given Inference Provider,
88
* you can add it to the following dictionary, for dev purposes.
9+
*
10+
* We also inject into this dictionary from tests.
911
*/
10-
export const HARDCODED_MODEL_ID_MAPPING: Record<ModelId, ProviderId> = {
12+
export const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelId, ProviderId>> = {
1113
/**
1214
* "HF model ID" => "Model ID on Inference Provider's side"
15+
*
16+
* Example:
17+
* "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
1318
*/
14-
// "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
19+
"fal-ai": {},
20+
"fireworks-ai": {},
21+
"hf-inference": {},
22+
replicate: {},
23+
sambanova: {},
24+
together: {},
1525
};
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
export const FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai/inference";
2+
3+
/**
4+
* See the registered mapping of HF model ID => Fireworks model ID here:
5+
*
6+
 * https://huggingface.co/api/partners/fireworks-ai/models
7+
*
8+
* This is a publicly available mapping.
9+
*
10+
* If you want to try to run inference for a new model locally before it's registered on huggingface.co,
11+
* you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
12+
*
13+
* - If you work at Fireworks and want to update this mapping, please use the model mapping API we provide on huggingface.co
14+
* - If you're a community member and want to add a new supported HF model to Fireworks, please open an issue on the present repo
15+
* and we will tag Fireworks team members.
16+
*
17+
* Thanks!
18+
*/

packages/inference/src/types.ts

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,16 @@ export interface Options {
4444

4545
export type InferenceTask = Exclude<PipelineType, "other">;
4646

47-
export const INFERENCE_PROVIDERS = ["fal-ai", "replicate", "sambanova", "together", "hf-inference", "novita"] as const;
47+
export const INFERENCE_PROVIDERS = [
48+
"fal-ai",
49+
"fireworks-ai",
50+
"hf-inference",
51+
"replicate",
52+
"sambanova",
53+
"together",
54+
"novita",
55+
] as const;
56+
4857
export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];
4958

5059
export interface BaseArgs {

packages/inference/test/HfInference.spec.ts

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import { chatCompletion, HfInference } from "../src";
66
import { textToVideo } from "../src/tasks/cv/textToVideo";
77
import { readTestFile } from "./test-files";
88
import "./vcr";
9+
import { HARDCODED_MODEL_ID_MAPPING } from "../src/providers/consts";
910

1011
const TIMEOUT = 60000 * 3;
1112
const env = import.meta.env;
@@ -1078,6 +1079,53 @@ describe.concurrent("HfInference", () => {
10781079
});
10791080
});
10801081

1082+
describe.concurrent(
1083+
"Fireworks",
1084+
() => {
1085+
const client = new HfInference(env.HF_FIREWORKS_KEY);
1086+
1087+
HARDCODED_MODEL_ID_MAPPING["fireworks-ai"] = {
1088+
"deepseek-ai/DeepSeek-R1": "accounts/fireworks/models/deepseek-r1",
1089+
};
1090+
1091+
it("chatCompletion", async () => {
1092+
const res = await client.chatCompletion({
1093+
model: "deepseek-ai/DeepSeek-R1",
1094+
provider: "fireworks-ai",
1095+
messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
1096+
});
1097+
if (res.choices && res.choices.length > 0) {
1098+
const completion = res.choices[0].message?.content;
1099+
expect(completion).toContain("two");
1100+
}
1101+
});
1102+
1103+
it("chatCompletion stream", async () => {
1104+
const stream = client.chatCompletionStream({
1105+
model: "deepseek-ai/DeepSeek-R1",
1106+
provider: "fireworks-ai",
1107+
messages: [{ role: "user", content: "Say this is a test" }],
1108+
stream: true,
1109+
}) as AsyncGenerator<ChatCompletionStreamOutput>;
1110+
1111+
let fullResponse = "";
1112+
for await (const chunk of stream) {
1113+
if (chunk.choices && chunk.choices.length > 0) {
1114+
const content = chunk.choices[0].delta?.content;
1115+
if (content) {
1116+
fullResponse += content;
1117+
}
1118+
}
1119+
}
1120+
1121+
// Verify we got a meaningful response
1122+
expect(fullResponse).toBeTruthy();
1123+
expect(fullResponse.length).toBeGreaterThan(0);
1124+
});
1125+
},
1126+
TIMEOUT
1127+
);
1128+
10811129
describe.concurrent(
10821130
"Novita",
10831131
() => {

0 commit comments

Comments
 (0)