Commit d0674a3: Fireworks AI Conversational Models

1 parent 0a690a1

8 files changed: +96 additions, -4 deletions

.github/workflows/test.yml

Lines changed: 3 additions & 0 deletions

@@ -45,6 +45,7 @@ jobs:
       HF_REPLICATE_KEY: dummy
       HF_SAMBANOVA_KEY: dummy
       HF_TOGETHER_KEY: dummy
+      HF_FIREWORKS_AI_KEY: dummy

   browser:
     runs-on: ubuntu-latest
@@ -85,6 +86,7 @@ jobs:
       HF_REPLICATE_KEY: dummy
       HF_SAMBANOVA_KEY: dummy
       HF_TOGETHER_KEY: dummy
+      HF_FIREWORKS_AI_KEY: dummy

   e2e:
     runs-on: ubuntu-latest
@@ -152,3 +154,4 @@ jobs:
       HF_REPLICATE_KEY: dummy
       HF_SAMBANOVA_KEY: dummy
       HF_TOGETHER_KEY: dummy
+      HF_FIREWORKS_AI_KEY: dummy

packages/inference/README.md

Lines changed: 2 additions & 1 deletion

@@ -46,7 +46,7 @@ Your access token should be kept private. If you need to protect it in front-end

 You can send inference requests to third-party providers with the inference client.

-Currently, we support the following providers: [Fal.ai](https://fal.ai), [Replicate](https://replicate.com), [Together](https://together.xyz) and [Sambanova](https://sambanova.ai).
+Currently, we support the following providers: [Fal.ai](https://fal.ai), [Replicate](https://replicate.com), [Together](https://together.xyz), [Sambanova](https://sambanova.ai), and [Fireworks AI](https://fireworks.ai).

 To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. Make sure your request is authenticated with an access token.
 ```ts
@@ -68,6 +68,7 @@ Only a subset of models are supported when requesting third-party providers. You
 - [Replicate supported models](./src/providers/replicate.ts)
 - [Sambanova supported models](./src/providers/sambanova.ts)
 - [Together supported models](./src/providers/together.ts)
+- [Fireworks AI supported models](./src/providers/fireworks-ai.ts)
 - [HF Inference API (serverless)](https://huggingface.co/models?inference=warm&sort=trending)

 **Important note:** To be compatible, the third-party API must adhere to the "standard" shape API we expect on HF model pages for each pipeline task type.
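
For context, a minimal sketch of the documented flow with the new provider, mirroring the test added in this commit (`HF_TOKEN` is a placeholder for any valid access token):

```ts
import { HfInference } from "@huggingface/inference";

// Placeholder token; in practice an HF access token (requests are proxied
// through the Hub) or a provider key used directly.
const client = new HfInference(process.env.HF_TOKEN);

const res = await client.chatCompletion({
	model: "deepseek-ai/DeepSeek-R1", // HF model ID, remapped to a Fireworks ID internally
	provider: "fireworks-ai",
	messages: [{ role: "user", content: "Say this is a test" }],
});
console.log(res.choices[0].message?.content);
```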

packages/inference/src/index.ts

Lines changed: 1 addition & 0 deletions

@@ -5,5 +5,6 @@ export { FAL_AI_SUPPORTED_MODEL_IDS } from "./providers/fal-ai";
 export { REPLICATE_SUPPORTED_MODEL_IDS } from "./providers/replicate";
 export { SAMBANOVA_SUPPORTED_MODEL_IDS } from "./providers/sambanova";
 export { TOGETHER_SUPPORTED_MODEL_IDS } from "./providers/together";
+export { FIREWORKS_AI_SUPPORTED_MODEL_IDS } from "./providers/fireworks-ai";
 export * from "./types";
 export * from "./tasks";

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 12 additions & 0 deletions

@@ -4,6 +4,7 @@ import { FAL_AI_API_BASE_URL, FAL_AI_SUPPORTED_MODEL_IDS } from "../providers/fa
 import { REPLICATE_API_BASE_URL, REPLICATE_SUPPORTED_MODEL_IDS } from "../providers/replicate";
 import { SAMBANOVA_API_BASE_URL, SAMBANOVA_SUPPORTED_MODEL_IDS } from "../providers/sambanova";
 import { TOGETHER_API_BASE_URL, TOGETHER_SUPPORTED_MODEL_IDS } from "../providers/together";
+import { FIREWORKS_AI_API_BASE_URL, FIREWORKS_AI_SUPPORTED_MODEL_IDS } from "../providers/fireworks-ai";
 import type { InferenceProvider } from "../types";
 import type { InferenceTask, Options, RequestArgs } from "../types";
 import { isUrl } from "./isUrl";
@@ -177,6 +178,8 @@ function mapModel(params: {
 			return SAMBANOVA_SUPPORTED_MODEL_IDS[task]?.[params.model];
 		case "together":
 			return TOGETHER_SUPPORTED_MODEL_IDS[task]?.[params.model];
+		case "fireworks-ai":
+			return FIREWORKS_AI_SUPPORTED_MODEL_IDS[task]?.[params.model];
 	}
 })();

@@ -243,6 +246,15 @@ function makeUrl(params: {
 			}
 			return baseUrl;
 		}
+		case "fireworks-ai": {
+			const baseUrl = shouldProxy
+				? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
+				: FIREWORKS_AI_API_BASE_URL;
+			if (params.taskHint === "text-generation" && params.chatCompletion) {
+				return `${baseUrl}/v1/chat/completions`;
+			}
+			return baseUrl;
+		}
 		default: {
 			const baseUrl = HF_HUB_INFERENCE_PROXY_TEMPLATE.replaceAll("{{PROVIDER}}", "hf-inference");
 			const url = params.forceTask
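
Taken together, the two new branches rewrite the model ID and pick the endpoint. A self-contained sketch of that resolution with the proxy disabled (simplified stand-ins, not the library code; the real `mapModel` returns `undefined` for unmapped models and errors upstream):

```ts
const FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai/inference";

// Stand-in for the "conversational" slice of FIREWORKS_AI_SUPPORTED_MODEL_IDS.
const MODEL_MAP: Record<string, string> = {
	"deepseek-ai/DeepSeek-R1": "accounts/fireworks/models/deepseek-r1",
};

function resolveFireworks(hfModel: string, chatCompletion: boolean) {
	const model = MODEL_MAP[hfModel]; // mapModel step
	const url = chatCompletion
		? `${FIREWORKS_AI_API_BASE_URL}/v1/chat/completions` // makeUrl step
		: FIREWORKS_AI_API_BASE_URL;
	return { model, url };
}

console.log(resolveFireworks("deepseek-ai/DeepSeek-R1", true));
// { model: "accounts/fireworks/models/deepseek-r1",
//   url: "https://api.fireworks.ai/inference/v1/chat/completions" }
```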
packages/inference/src/providers/fireworks-ai.ts

Lines changed: 32 additions & 0 deletions

@@ -0,0 +1,32 @@
+import type { ProviderMapping } from "./types";
+
+export const FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai/inference";
+
+type FireworksAiId = string; // you can make this more specific if needed
+
+/**
+ * Mapping of HuggingFace model IDs to Fireworks model IDs
+ */
+export const FIREWORKS_AI_SUPPORTED_MODEL_IDS: ProviderMapping<FireworksAiId> = {
+	// Chat/Conversational models
+	"conversational": {
+		"meta-llama/Llama-3.3-70B-Instruct": "accounts/fireworks/models/llama-v3p3-70b-instruct",
+		"meta-llama/Llama-3.2-3B-Instruct": "accounts/fireworks/models/llama-v3p2-3b-instruct",
+		"meta-llama/Llama-3.1-8B-Instruct": "accounts/fireworks/models/llama-v3p1-8b-instruct",
+		"mistralai/Mixtral-8x7B-Instruct-v0.1": "accounts/fireworks/models/mixtral-8x7b-instruct",
+		"deepseek-ai/DeepSeek-R1": "accounts/fireworks/models/deepseek-r1",
+		"deepseek-ai/DeepSeek-V3": "accounts/fireworks/models/deepseek-v3",
+		"meta-llama/Llama-3.2-90B-Vision-Instruct": "accounts/fireworks/models/llama-v3p2-90b-vision-instruct",
+		"meta-llama/Llama-3.2-11B-Vision-Instruct": "accounts/fireworks/models/llama-v3p2-11b-vision-instruct",
+		"meta-llama/Meta-Llama-3-70B-Instruct": "accounts/fireworks/models/llama-v3-70b-instruct",
+		"meta-llama/Meta-Llama-3-8B-Instruct": "accounts/fireworks/models/llama-v3-8b-instruct",
+		"mistralai/Mistral-Small-24B-Instruct-2501": "accounts/fireworks/models/mistral-small-24b-instruct-2501",
+		"mistralai/Mixtral-8x22B-Instruct-v0.1": "accounts/fireworks/models/mixtral-8x22b-instruct",
+		"Qwen/QWQ-32B-Preview": "accounts/fireworks/models/qwen-qwq-32b-preview",
+		"Qwen/Qwen2.5-72B-Instruct": "accounts/fireworks/models/qwen2p5-72b-instruct",
+		"Qwen/Qwen2.5-Coder-32B-Instruct": "accounts/fireworks/models/qwen2p5-coder-32b-instruct",
+		"Qwen/Qwen2-VL-72B-Instruct": "accounts/fireworks/models/qwen2-vl-72b-instruct",
+		"Gryphe/MythoMax-L2-13b": "accounts/fireworks/models/mythomax-l2-13b",
+		"microsoft/Phi-3.5-vision-instruct": "accounts/fireworks/models/phi-3-vision-128k-instruct"
+	},
+};
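
Because the mapping is re-exported from the package index (see src/index.ts above), callers can check ahead of time whether a model is routable; a small sketch:

```ts
import { FIREWORKS_AI_SUPPORTED_MODEL_IDS } from "@huggingface/inference";

const hfId = "meta-llama/Llama-3.3-70B-Instruct";
const fireworksId = FIREWORKS_AI_SUPPORTED_MODEL_IDS["conversational"]?.[hfId];

console.log(
	fireworksId
		? `Supported: ${hfId} -> ${fireworksId}` // accounts/fireworks/models/llama-v3p3-70b-instruct
		: `${hfId} is not mapped for fireworks-ai`
);
```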

packages/inference/src/types.ts

Lines changed: 1 addition & 1 deletion

@@ -45,7 +45,7 @@ export interface Options {

 export type InferenceTask = Exclude<PipelineType, "other">;

-export const INFERENCE_PROVIDERS = ["fal-ai", "replicate", "sambanova", "together", "hf-inference"] as const;
+export const INFERENCE_PROVIDERS = ["fal-ai", "replicate", "sambanova", "together", "fireworks-ai", "hf-inference"] as const;
 export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];

 export interface BaseArgs {
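
Since `InferenceProvider` is derived from the `as const` array, `"fireworks-ai"` becomes a valid member of the union with no further changes; a quick type-level check:

```ts
import type { InferenceProvider } from "@huggingface/inference";

const ok: InferenceProvider = "fireworks-ai"; // accepted after this change
// const bad: InferenceProvider = "fireworks"; // would fail to compile
```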

packages/inference/test/HfInference.spec.ts

Lines changed: 44 additions & 1 deletion

@@ -1070,4 +1070,47 @@ describe.concurrent("HfInference", () => {
 			);
 		});
 	});
-});
+
+	describe.concurrent(
+		"Fireworks",
+		() => {
+			const client = new HfInference(env.HF_FIREWORKS_AI_KEY);
+
+			it("chatCompletion", async () => {
+				const res = await client.chatCompletion({
+					model: "deepseek-ai/DeepSeek-R1",
+					provider: "fireworks-ai",
+					messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
+				});
+				if (res.choices && res.choices.length > 0) {
+					const completion = res.choices[0].message?.content;
+					expect(completion).toContain("two");
+				}
+			});
+
+			it("chatCompletion stream", async () => {
+				const stream = client.chatCompletionStream({
+					model: "deepseek-ai/DeepSeek-R1",
+					provider: "fireworks-ai",
+					messages: [{ role: "user", content: "Say this is a test" }],
+					stream: true
+				}) as AsyncGenerator<ChatCompletionStreamOutput>;
+
+				let fullResponse = "";
+				for await (const chunk of stream) {
+					if (chunk.choices && chunk.choices.length > 0) {
+						const content = chunk.choices[0].delta?.content;
+						if (content) {
+							fullResponse += content;
+						}
+					}
+				}
+
+				// Verify we got a meaningful response
+				expect(fullResponse).toBeTruthy();
+				expect(fullResponse.length).toBeGreaterThan(0);
+			});
+		},
+		TIMEOUT
+	);
+});

packages/tasks/src/inference-providers.ts

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-export const INFERENCE_PROVIDERS = ["hf-inference", "fal-ai", "replicate", "sambanova", "together"] as const;
+export const INFERENCE_PROVIDERS = ["hf-inference", "fal-ai", "replicate", "sambanova", "together", "fireworks-ai"] as const;

 export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];