
Commit 0f82415

Merge branch 'main' into provider-featherless-ai
2 parents 591b6ae + d119c63 commit 0f82415

File tree: 6 files changed, +100 −0 lines


packages/inference/README.md

Lines changed: 2 additions & 0 deletions
@@ -60,6 +60,7 @@ Currently, we support the following providers:
 - [Blackforestlabs](https://blackforestlabs.ai)
 - [Cohere](https://cohere.com)
 - [Cerebras](https://cerebras.ai/)
+- [Groq](https://groq.com)
 
 To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. Make sure your request is authenticated with an access token.
 ```ts

@@ -88,6 +89,7 @@ Only a subset of models are supported when requesting third-party providers. You
 - [Together supported models](https://huggingface.co/api/partners/together/models)
 - [Cohere supported models](https://huggingface.co/api/partners/cohere/models)
 - [Cerebras supported models](https://huggingface.co/api/partners/cerebras/models)
+- [Groq supported models](https://console.groq.com/docs/models)
 - [HF Inference API (serverless)](https://huggingface.co/models?inference=warm&sort=trending)
 
 **Important note:** To be compatible, the third-party API must adhere to the "standard" shape API we expect on HF model pages for each pipeline task type.
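
To illustrate the `provider` parameter with the newly added Groq entry, here is a minimal sketch; the model ID is taken from the test file further down in this commit, while the token variable is an assumption:

```ts
import { InferenceClient } from "@huggingface/inference";

// Assumption: HF_TOKEN holds a Hugging Face access token with inference permissions.
const client = new InferenceClient(process.env.HF_TOKEN);

const res = await client.chatCompletion({
	model: "meta-llama/Llama-3.3-70B-Instruct", // HF model ID, mapped to Groq's own model ID
	provider: "groq", // route the request to Groq instead of the default provider
	messages: [{ role: "user", content: "Hello!" }],
});
console.log(res.choices[0]?.message?.content);
```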

packages/inference/src/lib/getProviderHelper.ts

Lines changed: 5 additions & 0 deletions
@@ -4,6 +4,7 @@ import * as Cohere from "../providers/cohere";
 import * as FalAI from "../providers/fal-ai";
 import * as FeatherlessAI from "../providers/featherless-ai";
 import * as Fireworks from "../providers/fireworks-ai";
+import * as Groq from "../providers/groq";
 import * as HFInference from "../providers/hf-inference";
 import * as Hyperbolic from "../providers/hyperbolic";
 import * as Nebius from "../providers/nebius";

@@ -100,6 +101,10 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
 	"fireworks-ai": {
 		conversational: new Fireworks.FireworksConversationalTask(),
 	},
+	groq: {
+		conversational: new Groq.GroqConversationalTask(),
+		"text-generation": new Groq.GroqTextGenerationTask(),
+	},
 	hyperbolic: {
 		"text-to-image": new Hyperbolic.HyperbolicTextToImageTask(),
 		conversational: new Hyperbolic.HyperbolicConversationalTask(),
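
The `PROVIDERS` record maps each provider to the task helpers it exposes. A hypothetical sketch of the lookup this registry enables (the actual `getProviderHelper` body is outside this diff, and the import paths and type exports are assumptions):

```ts
import { PROVIDERS } from "./lib/getProviderHelper"; // exported above; path assumed
import type { InferenceProvider, InferenceTask } from "./types"; // assumed exported here

// Resolve the helper registered for a provider/task pair, as getProviderHelper presumably does.
function resolveHelper(provider: InferenceProvider, task: InferenceTask) {
	const helper = PROVIDERS[provider]?.[task];
	if (!helper) {
		throw new Error(`Provider '${provider}' has no helper for task '${task}'`);
	}
	return helper;
}

// After this commit, both Groq tasks resolve:
resolveHelper("groq", "conversational"); // GroqConversationalTask
resolveHelper("groq", "text-generation"); // GroqTextGenerationTask
```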

packages/inference/src/providers/consts.ts

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
 	"fal-ai": {},
 	"featherless-ai": {},
 	"fireworks-ai": {},
+	groq: {},
 	"hf-inference": {},
 	hyperbolic: {},
 	nebius: {},
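
The new `groq: {}` slot starts empty; for local development it can be populated before a model mapping is registered on huggingface.co (see the doc comment in groq.ts below). A sketch, with the entry shape copied from the test file in this commit and the import path assumed:

```ts
import { HARDCODED_MODEL_INFERENCE_MAPPING } from "./consts"; // import path assumed

// Dev-only override: map an HF model ID to Groq's own model ID.
HARDCODED_MODEL_INFERENCE_MAPPING["groq"] = {
	"meta-llama/Llama-3.3-70B-Instruct": {
		hfModelId: "meta-llama/Llama-3.3-70B-Instruct",
		providerId: "llama-3.3-70b-versatile",
		status: "live",
		task: "conversational",
	},
};
```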
packages/inference/src/providers/groq.ts

Lines changed: 40 additions & 0 deletions

@@ -0,0 +1,40 @@
+import { BaseConversationalTask, BaseTextGenerationTask } from "./providerHelper";
+
+/**
+ * See the registered mapping of HF model ID => Groq model ID here:
+ *
+ * https://huggingface.co/api/partners/groq/models
+ *
+ * This is a publicly available mapping.
+ *
+ * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
+ * you can add it to the dictionary "HARDCODED_MODEL_INFERENCE_MAPPING" in consts.ts, for dev purposes.
+ *
+ * - If you work at Groq and want to update this mapping, please use the model mapping API we provide on huggingface.co
+ * - If you're a community member and want to add a new supported HF model to Groq, please open an issue on the present repo
+ *   and we will tag Groq team members.
+ *
+ * Thanks!
+ */
+
+const GROQ_API_BASE_URL = "https://api.groq.com";
+
+export class GroqTextGenerationTask extends BaseTextGenerationTask {
+	constructor() {
+		super("groq", GROQ_API_BASE_URL);
+	}
+
+	override makeRoute(): string {
+		return "/openai/v1/chat/completions";
+	}
+}
+
+export class GroqConversationalTask extends BaseConversationalTask {
+	constructor() {
+		super("groq", GROQ_API_BASE_URL);
+	}
+
+	override makeRoute(): string {
+		return "/openai/v1/chat/completions";
+	}
+}
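
Both task classes point at Groq's OpenAI-compatible endpoint. Assuming the base helpers in providerHelper.ts (not shown in this diff) join the constructor's base URL with `makeRoute()`, requests would target:

```ts
// Assumption: request URL = base URL + route; providerHelper.ts is outside this diff.
const url = "https://api.groq.com" + "/openai/v1/chat/completions";
// => "https://api.groq.com/openai/v1/chat/completions"
```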

packages/inference/src/types.ts

Lines changed: 1 addition & 0 deletions
@@ -44,6 +44,7 @@ export const INFERENCE_PROVIDERS = [
 	"fal-ai",
 	"featherless-ai",
 	"fireworks-ai",
+	"groq",
 	"hf-inference",
 	"hyperbolic",
 	"nebius",

packages/inference/test/InferenceClient.spec.ts

Lines changed: 51 additions & 0 deletions
@@ -1824,4 +1824,55 @@ describe.skip("InferenceClient", () => {
 		},
 		TIMEOUT
 	);
+	describe.concurrent(
+		"Groq",
+		() => {
+			const client = new InferenceClient(env.HF_GROQ_KEY ?? "dummy");
+
+			HARDCODED_MODEL_INFERENCE_MAPPING["groq"] = {
+				"meta-llama/Llama-3.3-70B-Instruct": {
+					hfModelId: "meta-llama/Llama-3.3-70B-Instruct",
+					providerId: "llama-3.3-70b-versatile",
+					status: "live",
+					task: "conversational",
+				},
+			};
+
+			it("chatCompletion", async () => {
+				const res = await client.chatCompletion({
+					model: "meta-llama/Llama-3.3-70B-Instruct",
+					provider: "groq",
+					messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
+				});
+				if (res.choices && res.choices.length > 0) {
+					const completion = res.choices[0].message?.content;
+					expect(completion).toContain("two");
+				}
+			});
+
+			it("chatCompletion stream", async () => {
+				const stream = client.chatCompletionStream({
+					model: "meta-llama/Llama-3.3-70B-Instruct",
+					provider: "groq",
+					messages: [{ role: "user", content: "Say 'this is a test'" }],
+					stream: true,
+				}) as AsyncGenerator<ChatCompletionStreamOutput>;
+
+				let fullResponse = "";
+				for await (const chunk of stream) {
+					if (chunk.choices && chunk.choices.length > 0) {
+						const content = chunk.choices[0].delta?.content;
+						if (content) {
+							fullResponse += content;
+						}
+					}
+				}
+
+				// Verify we got a meaningful response
+				expect(fullResponse).toBeTruthy();
+				expect(fullResponse.length).toBeGreaterThan(0);
+			});
+		},
+		TIMEOUT
+	);
 });
