Merged
Changes from 5 commits
3 changes: 3 additions & 0 deletions .github/workflows/test.yml
@@ -45,6 +45,7 @@ jobs:
HF_REPLICATE_KEY: dummy
HF_SAMBANOVA_KEY: dummy
HF_TOGETHER_KEY: dummy
HF_FIREWORKS_KEY: dummy

browser:
runs-on: ubuntu-latest
@@ -85,6 +86,7 @@ jobs:
HF_REPLICATE_KEY: dummy
HF_SAMBANOVA_KEY: dummy
HF_TOGETHER_KEY: dummy
HF_FIREWORKS_KEY: dummy

e2e:
runs-on: ubuntu-latest
@@ -152,3 +154,4 @@ jobs:
HF_REPLICATE_KEY: dummy
HF_SAMBANOVA_KEY: dummy
HF_TOGETHER_KEY: dummy
HF_FIREWORKS_KEY: dummy
3 changes: 2 additions & 1 deletion packages/inference/README.md
@@ -46,7 +46,7 @@ Your access token should be kept private. If you need to protect it in front-end

You can send inference requests to third-party providers with the inference client.

Currently, we support the following providers: [Fal.ai](https://fal.ai), [Replicate](https://replicate.com), [Together](https://together.xyz) and [Sambanova](https://sambanova.ai).
Currently, we support the following providers: [Fal.ai](https://fal.ai), [Replicate](https://replicate.com), [Together](https://together.xyz), [Sambanova](https://sambanova.ai), and [Fireworks AI](https://fireworks.ai).
Contributor:
Suggested change
Currently, we support the following providers: [Fal.ai](https://fal.ai), [Replicate](https://replicate.com), [Together](https://together.xyz), [Sambanova](https://sambanova.ai), and [Fireworks AI](https://fireworks.ai).
Currently, we support the following providers:
- [Fal.ai](https://fal.ai)
- [Fireworks AI](https://fireworks.ai)
- [Replicate](https://replicate.com)
- [Sambanova](https://sambanova.ai)
- [Together](https://together.xyz)

Given that this list is expected to grow quite a lot, I'd turn the sentence into a bullet list and order it alphabetically.

Member:
I agree with your suggestions; can you push them directly to this PR?

Contributor:
done in 832a209


To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. Make sure your request is authenticated with an access token.
@@ -68,6 +68,7 @@ Only a subset of models are supported when requesting third-party providers. You
- [Replicate supported models](./src/providers/replicate.ts)
- [Sambanova supported models](./src/providers/sambanova.ts)
- [Together supported models](./src/providers/together.ts)
- [Fireworks AI supported models](./src/providers/fireworks-ai.ts)
Contributor:
The links in this list are no longer correct. They must now point to the API routes:

- [Fal.ai supported models](https://huggingface.co/api/partners/fal-ai/models)
- [Fireworks AI supported models](https://huggingface.co/api/partners/fireworks-ai/models)
- [Replicate supported models](https://huggingface.co/api/partners/replicate/models)
- [Sambanova supported models](https://huggingface.co/api/partners/sambanova/models)
- [Together supported models](https://huggingface.co/api/partners/together/models)
- [HF Inference API (serverless)](https://huggingface.co/models?inference=warm&sort=trending)

(Also, I would order the list alphabetically and put HF Inference in last position.)

- [HF Inference API (serverless)](https://huggingface.co/models?inference=warm&sort=trending)

❗**Important note:** To be compatible, the third-party API must adhere to the "standard" shape API we expect on HF model pages for each pipeline task type.
4 changes: 2 additions & 2 deletions packages/inference/src/lib/getProviderModelId.ts
@@ -30,8 +30,8 @@ export async function getProviderModelId(
options.taskHint === "text-generation" && options.chatCompletion ? "conversational" : options.taskHint;

// A dict called HARDCODED_MODEL_ID_MAPPING takes precedence in all cases (useful for dev purposes)
if (HARDCODED_MODEL_ID_MAPPING[params.model]) {
return HARDCODED_MODEL_ID_MAPPING[params.model];
if (HARDCODED_MODEL_ID_MAPPING[params.provider]?.[params.model]) {
return HARDCODED_MODEL_ID_MAPPING[params.provider][params.model];
}

let inferenceProviderMapping: InferenceProviderMapping | null;
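The fix above can be sketched in isolation: the old code indexed the mapping by model ID alone, so a model mapped for one provider could leak into requests for another; the new code scopes the lookup by provider first, with optional chaining guarding against providers that have no entries.

```typescript
// Sketch of the provider-scoped lookup (types simplified from the diff above).
type Mapping = Record<string, Record<string, string>>;

const mapping: Mapping = {
  "fireworks-ai": { "deepseek-ai/DeepSeek-R1": "accounts/fireworks/models/deepseek-r1" },
  together: {},
};

function hardcodedId(m: Mapping, provider: string, model: string): string | undefined {
  // `?.` returns undefined (rather than throwing) when the provider has no map.
  return m[provider]?.[model];
}

console.log(hardcodedId(mapping, "fireworks-ai", "deepseek-ai/DeepSeek-R1"));
// → accounts/fireworks/models/deepseek-r1
console.log(hardcodedId(mapping, "together", "deepseek-ai/DeepSeek-R1"));
// → undefined (no cross-provider leak)
```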
10 changes: 10 additions & 0 deletions packages/inference/src/lib/makeRequestOptions.ts
@@ -3,6 +3,7 @@ import { FAL_AI_API_BASE_URL } from "../providers/fal-ai";
import { REPLICATE_API_BASE_URL } from "../providers/replicate";
import { SAMBANOVA_API_BASE_URL } from "../providers/sambanova";
import { TOGETHER_API_BASE_URL } from "../providers/together";
import { FIREWORKS_AI_API_BASE_URL } from "../providers/fireworks-ai";
import type { InferenceProvider } from "../types";
import type { InferenceTask, Options, RequestArgs } from "../types";
import { isUrl } from "./isUrl";
@@ -208,6 +209,15 @@ function makeUrl(params: {
}
return baseUrl;
}
case "fireworks-ai": {
const baseUrl = shouldProxy
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
: FIREWORKS_AI_API_BASE_URL;
if (params.taskHint === "text-generation" && params.chatCompletion) {
return `${baseUrl}/v1/chat/completions`;
}
return baseUrl;
}
default: {
const baseUrl = HF_HUB_INFERENCE_PROXY_TEMPLATE.replaceAll("{{PROVIDER}}", "hf-inference");
const url = params.forceTask
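The `fireworks-ai` case added to `makeUrl` can be sketched as a standalone function. The Fireworks base URL is taken from the diff; the proxy-template value here is an assumption for illustration only:

```typescript
// Sketch of the fireworks-ai branch of makeUrl. FIREWORKS_AI_API_BASE_URL
// comes from this PR; the proxy template value is an assumed placeholder.
const FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai/inference";
const HF_HUB_INFERENCE_PROXY_TEMPLATE = "https://router.huggingface.co/{{PROVIDER}}";

function fireworksUrl(shouldProxy: boolean, taskHint: string, chatCompletion: boolean): string {
  const baseUrl = shouldProxy
    ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", "fireworks-ai")
    : FIREWORKS_AI_API_BASE_URL;
  // Only conversational text-generation gets the OpenAI-compatible route appended.
  return taskHint === "text-generation" && chatCompletion
    ? `${baseUrl}/v1/chat/completions`
    : baseUrl;
}

console.log(fireworksUrl(false, "text-generation", true));
// → https://api.fireworks.ai/inference/v1/chat/completions
```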
18 changes: 14 additions & 4 deletions packages/inference/src/providers/consts.ts
@@ -1,15 +1,25 @@
import type { ModelId } from "../types";
import type { InferenceProvider } from "../types";
import { type ModelId } from "../types";

type ProviderId = string;

/**
* If you want to try to run inference for a new model locally before it's registered on huggingface.co
* for a given Inference Provider,
* you can add it to the following dictionary, for dev purposes.
*
* We also inject into this dictionary from tests.
*/
export const HARDCODED_MODEL_ID_MAPPING: Record<ModelId, ProviderId> = {
export const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelId, ProviderId>> = {
/**
* "HF model ID" => "Model ID on Inference Provider's side"
*
* Example:
* "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
*/
// "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
"fal-ai": {},
"fireworks-ai": {},
"hf-inference": {},
replicate: {},
sambanova: {},
together: {},
};
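The new per-provider shape means a dev-time override is a two-level write, as the tests in this PR do. A minimal sketch (model IDs taken from the test diff; the mapping literal is reproduced from the change above):

```typescript
// Sketch: the per-provider hardcoded mapping from this PR, with a dev-only
// entry injected the way the test suite does it.
type ModelId = string;
type ProviderId = string;
type InferenceProvider = "fal-ai" | "fireworks-ai" | "hf-inference" | "replicate" | "sambanova" | "together";

const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelId, ProviderId>> = {
  "fal-ai": {},
  "fireworks-ai": {},
  "hf-inference": {},
  replicate: {},
  sambanova: {},
  together: {},
};

// Dev-only: map an HF model ID to the provider-side ID before it's registered
// on huggingface.co for this provider.
HARDCODED_MODEL_ID_MAPPING["fireworks-ai"]["deepseek-ai/DeepSeek-R1"] =
  "accounts/fireworks/models/deepseek-r1";

console.log(HARDCODED_MODEL_ID_MAPPING["fireworks-ai"]["deepseek-ai/DeepSeek-R1"]);
// → accounts/fireworks/models/deepseek-r1
```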
18 changes: 18 additions & 0 deletions packages/inference/src/providers/fireworks-ai.ts
@@ -0,0 +1,18 @@
export const FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai/inference";

/**
* See the registered mapping of HF model ID => Fireworks model ID here:
*
* https://huggingface.co/api/partners/fireworks/models
*
* This is a publicly available mapping.
*
* If you want to try to run inference for a new model locally before it's registered on huggingface.co,
* you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
*
* - If you work at Fireworks and want to update this mapping, please use the model mapping API we provide on huggingface.co
* - If you're a community member and want to add a new supported HF model to Fireworks, please open an issue on the present repo
* and we will tag Fireworks team members.
*
* Thanks!
*/
9 changes: 8 additions & 1 deletion packages/inference/src/types.ts
@@ -44,7 +44,14 @@ export interface Options {

export type InferenceTask = Exclude<PipelineType, "other">;

export const INFERENCE_PROVIDERS = ["fal-ai", "replicate", "sambanova", "together", "hf-inference"] as const;
export const INFERENCE_PROVIDERS = [
"fal-ai",
"fireworks-ai",
"hf-inference",
"replicate",
"sambanova",
"together",
] as const;
export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];

export interface BaseArgs {
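The `as const` tuple above is the single source of truth for both the runtime provider list and the `InferenceProvider` union type. A sketch of the pattern, with an illustrative type-guard (`isInferenceProvider` is not part of this PR):

```typescript
// Sketch: one `as const` array yields both a runtime list and a union type.
const INFERENCE_PROVIDERS = [
  "fal-ai",
  "fireworks-ai",
  "hf-inference",
  "replicate",
  "sambanova",
  "together",
] as const;
type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];
// InferenceProvider is "fal-ai" | "fireworks-ai" | ... | "together"

// Illustrative: runtime validation falls out of the same source of truth.
function isInferenceProvider(x: string): x is InferenceProvider {
  return (INFERENCE_PROVIDERS as readonly string[]).includes(x);
}

console.log(isInferenceProvider("fireworks-ai")); // → true
console.log(isInferenceProvider("openai")); // → false
```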
48 changes: 48 additions & 0 deletions packages/inference/test/HfInference.spec.ts
@@ -6,6 +6,7 @@ import { chatCompletion, HfInference } from "../src";
import { textToVideo } from "../src/tasks/cv/textToVideo";
import { readTestFile } from "./test-files";
import "./vcr";
import { HARDCODED_MODEL_ID_MAPPING } from "../src/providers/consts";

const TIMEOUT = 60000 * 3;
const env = import.meta.env;
@@ -1077,4 +1078,51 @@
);
});
});

describe.concurrent(
"Fireworks",
() => {
const client = new HfInference(env.HF_FIREWORKS_KEY);

HARDCODED_MODEL_ID_MAPPING["fireworks-ai"] = {
"deepseek-ai/DeepSeek-R1": "accounts/fireworks/models/deepseek-r1",
};

it("chatCompletion", async () => {
const res = await client.chatCompletion({
model: "deepseek-ai/DeepSeek-R1",
provider: "fireworks-ai",
messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
});
if (res.choices && res.choices.length > 0) {
const completion = res.choices[0].message?.content;
expect(completion).toContain("two");
}
});

it("chatCompletion stream", async () => {
const stream = client.chatCompletionStream({
model: "deepseek-ai/DeepSeek-R1",
provider: "fireworks-ai",
messages: [{ role: "user", content: "Say this is a test" }],
stream: true,
}) as AsyncGenerator<ChatCompletionStreamOutput>;

let fullResponse = "";
for await (const chunk of stream) {
if (chunk.choices && chunk.choices.length > 0) {
const content = chunk.choices[0].delta?.content;
if (content) {
fullResponse += content;
}
}
}

// Verify we got a meaningful response
expect(fullResponse).toBeTruthy();
expect(fullResponse.length).toBeGreaterThan(0);
});
},
TIMEOUT
);
});
65 changes: 65 additions & 0 deletions packages/inference/test/tapes.json

Large diffs are not rendered by default.