diff --git a/packages/inference/README.md b/packages/inference/README.md
index 88d0328883..83bfdc7f01 100644
--- a/packages/inference/README.md
+++ b/packages/inference/README.md
@@ -59,6 +59,7 @@ Currently, we support the following providers:
 - [Blackforestlabs](https://blackforestlabs.ai)
 - [Cohere](https://cohere.com)
 - [Cerebras](https://cerebras.ai/)
+- [Groq](https://groq.com)
 
 To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. Make sure your request is authenticated with an access token.
 ```ts
@@ -86,6 +87,7 @@ Only a subset of models are supported when requesting third-party providers. You
 - [Together supported models](https://huggingface.co/api/partners/together/models)
 - [Cohere supported models](https://huggingface.co/api/partners/cohere/models)
 - [Cerebras supported models](https://huggingface.co/api/partners/cerebras/models)
+- [Groq supported models](https://console.groq.com/docs/models)
 - [HF Inference API (serverless)](https://huggingface.co/models?inference=warm&sort=trending)
 
 ❗**Important note:** To be compatible, the third-party API must adhere to the "standard" shape API we expect on HF model pages for each pipeline task type.
diff --git a/packages/inference/src/lib/getProviderHelper.ts b/packages/inference/src/lib/getProviderHelper.ts
index 060cddffb7..0c8de60326 100644
--- a/packages/inference/src/lib/getProviderHelper.ts
+++ b/packages/inference/src/lib/getProviderHelper.ts
@@ -3,6 +3,7 @@
 import * as Cerebras from "../providers/cerebras";
 import * as Cohere from "../providers/cohere";
 import * as FalAI from "../providers/fal-ai";
 import * as Fireworks from "../providers/fireworks-ai";
+import * as Groq from "../providers/groq";
 import * as HFInference from "../providers/hf-inference";
 import * as Hyperbolic from "../providers/hyperbolic";
@@ -96,6 +97,10 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask, TaskProviderHelper>>> = {
+	groq: {
+		conversational: new Groq.GroqConversationalTask(),
+		"text-generation": new Groq.GroqTextGenerationTask(),
+	},
diff --git a/packages/inference/src/providers/groq.ts b/packages/inference/src/providers/groq.ts
new file mode 100644
--- /dev/null
+++ b/packages/inference/src/providers/groq.ts
@@ -0,0 +1,40 @@
+import { BaseConversationalTask, BaseTextGenerationTask } from "./providerHelper";
+
+/**
+ * See the registered mapping of HF model ID => Groq model ID here:
+ *
+ * https://huggingface.co/api/partners/groq/models
+ *
+ * This is a publicly available mapping.
+ *
+ * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
+ * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
+ *
+ * - If you work at Groq and want to update this mapping, please use the model mapping API we provide on huggingface.co
+ * - If you're a community member and want to add a new supported HF model to Groq, please open an issue on the present repo
+ *   and we will tag Groq team members.
+ *
+ * Thanks!
+ */
+
+const GROQ_API_BASE_URL = "https://api.groq.com";
+
+export class GroqTextGenerationTask extends BaseTextGenerationTask {
+	constructor() {
+		super("groq", GROQ_API_BASE_URL);
+	}
+
+	override makeRoute(): string {
+		return "/openai/v1/chat/completions";
+	}
+}
+
+export class GroqConversationalTask extends BaseConversationalTask {
+	constructor() {
+		super("groq", GROQ_API_BASE_URL);
+	}
+
+	override makeRoute(): string {
+		return "/openai/v1/chat/completions";
+	}
+}
diff --git a/packages/inference/src/types.ts b/packages/inference/src/types.ts
index 032bd41901..b1db14c567 100644
--- a/packages/inference/src/types.ts
+++ b/packages/inference/src/types.ts
@@ -43,6 +43,7 @@ export const INFERENCE_PROVIDERS = [
 	"cohere",
 	"fal-ai",
 	"fireworks-ai",
+	"groq",
 	"hf-inference",
 	"hyperbolic",
 	"nebius",
diff --git a/packages/inference/test/InferenceClient.spec.ts b/packages/inference/test/InferenceClient.spec.ts
index 2105902bc0..60be17b186 100644
--- a/packages/inference/test/InferenceClient.spec.ts
+++ b/packages/inference/test/InferenceClient.spec.ts
@@ -1751,4 +1751,55 @@ describe.skip("InferenceClient", () => {
 		},
 		TIMEOUT
 	);
+	describe.concurrent(
+		"Groq",
+		() => {
+			const client = new InferenceClient(env.HF_GROQ_KEY ?? "dummy");
+
+			HARDCODED_MODEL_INFERENCE_MAPPING["groq"] = {
+				"meta-llama/Llama-3.3-70B-Instruct": {
+					hfModelId: "meta-llama/Llama-3.3-70B-Instruct",
+					providerId: "llama-3.3-70b-versatile",
+					status: "live",
+					task: "conversational",
+				},
+			};
+
+			it("chatCompletion", async () => {
+				const res = await client.chatCompletion({
+					model: "meta-llama/Llama-3.3-70B-Instruct",
+					provider: "groq",
+					messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
+				});
+				if (res.choices && res.choices.length > 0) {
+					const completion = res.choices[0].message?.content;
+					expect(completion).toContain("two");
+				}
+			});
+
+			it("chatCompletion stream", async () => {
+				const stream = client.chatCompletionStream({
+					model: "meta-llama/Llama-3.3-70B-Instruct",
+					provider: "groq",
+					messages: [{ role: "user", content: "Say 'this is a test'" }],
+					stream: true,
+				}) as AsyncGenerator<ChatCompletionStreamOutput>;
+
+				let fullResponse = "";
+				for await (const chunk of stream) {
+					if (chunk.choices && chunk.choices.length > 0) {
+						const content = chunk.choices[0].delta?.content;
+						if (content) {
+							fullResponse += content;
+						}
+					}
+				}
+
+				// Verify we got a meaningful response
+				expect(fullResponse).toBeTruthy();
+				expect(fullResponse.length).toBeGreaterThan(0);
+			});
+		},
+		TIMEOUT
+	);
 });
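Taken together, this diff registers `groq` as a provider and routes its requests to Groq's OpenAI-compatible endpoint (`/openai/v1/chat/completions`). A minimal usage sketch, assuming an `HF_TOKEN` environment variable holding a Hugging Face access token authorized for third-party inference (the model ID mirrors the test above):

```ts
import { InferenceClient } from "@huggingface/inference";

// Assumes HF_TOKEN is set in the environment; any HF access token
// enabled for third-party inference providers should work here.
const client = new InferenceClient(process.env.HF_TOKEN);

const res = await client.chatCompletion({
	model: "meta-llama/Llama-3.3-70B-Instruct",
	provider: "groq", // routed to https://api.groq.com/openai/v1/chat/completions
	messages: [{ role: "user", content: "One plus one is equal to" }],
});

console.log(res.choices[0].message.content);
```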