Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions packages/inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ Currently, we support the following providers:
- [Groq](https://groq.com)
- [Wavespeed.ai](https://wavespeed.ai/)
- [Z.ai](https://z.ai/)
- [TextCLF](https://textclf.com)


To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. The default value of the `provider` parameter is "auto", which will select the first of the providers available for the model, sorted by your preferred order in https://hf.co/settings/inference-providers.

Expand Down Expand Up @@ -109,6 +111,7 @@ Only a subset of models are supported when requesting third-party providers. You
- [Cerebras supported models](https://huggingface.co/api/partners/cerebras/models)
- [DeepInfra supported models](https://huggingface.co/api/partners/deepinfra/models)
- [Groq supported models](https://console.groq.com/docs/models)
- [TextCLF supported models](https://huggingface.co/api/partners/textclf/models)
- [Novita AI supported models](https://huggingface.co/api/partners/novita/models)
- [Wavespeed.ai supported models](https://huggingface.co/api/partners/wavespeed/models)
- [Z.ai supported models](https://huggingface.co/api/partners/zai-org/models)
Expand Down
5 changes: 5 additions & 0 deletions packages/inference/src/lib/getProviderHelper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import * as Nvidia from "../providers/nvidia.js";
import * as OpenAI from "../providers/openai.js";
import * as OvhCloud from "../providers/ovhcloud.js";
import * as PublicAI from "../providers/publicai.js";
import * as TextCLF from "../providers/textclf.js";
import type {
AudioClassificationTaskHelper,
AudioToAudioTaskHelper,
Expand Down Expand Up @@ -182,6 +183,10 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
"text-generation": new Scaleway.ScalewayTextGenerationTask(),
"feature-extraction": new Scaleway.ScalewayFeatureExtractionTask(),
},
textclf: {
conversational: new TextCLF.TextCLFConversationalTask(),
"text-generation": new TextCLF.TextCLFTextGenerationTask(),
},
together: {
"text-to-image": new Together.TogetherTextToImageTask(),
conversational: new Together.TogetherConversationalTask(),
Expand Down
1 change: 1 addition & 0 deletions packages/inference/src/providers/consts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
replicate: {},
sambanova: {},
scaleway: {},
textclf: {},
together: {},
wavespeed: {},
"zai-org": {},
Expand Down
2 changes: 2 additions & 0 deletions packages/inference/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ export const INFERENCE_PROVIDERS = [
"replicate",
"sambanova",
"scaleway",
"textclf",
"together",
"wavespeed",
"zai-org",
Expand Down Expand Up @@ -106,6 +107,7 @@ export const PROVIDERS_HUB_ORGS: Record<InferenceProvider, string> = {
replicate: "replicate",
sambanova: "sambanovasystems",
scaleway: "scaleway",
textclf: "textclf",
together: "togethercomputer",
wavespeed: "wavespeed",
"zai-org": "zai-org",
Expand Down
52 changes: 52 additions & 0 deletions packages/inference/test/InferenceClient.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2208,6 +2208,58 @@ describe.skip("InferenceClient", () => {
},
TIMEOUT,
);
describe.concurrent(
	"TextCLF",
	() => {
		// Use the real key when available so the suite can run against the live
		// provider; fall back to a dummy token for offline/CI runs.
		const client = new InferenceClient(env.HF_TEXTCLF_KEY ?? "dummy");

		// Register a hardcoded model mapping for the TextCLF provider.
		// NOTE: `provider` must be "textclf" — the original had "groq" here
		// (copy-paste from another provider's suite), which would route these
		// requests through the wrong provider mapping.
		HARDCODED_MODEL_INFERENCE_MAPPING["textclf"] = {
			"meta-llama/Llama-3.1-8B-Instruct": {
				provider: "textclf",
				hfModelId: "meta-llama/Llama-3.1-8B-Instruct",
				providerId: "meta-llama/Llama-3.1-8B-Instruct",
				status: "live",
				task: "conversational",
			},
		};

		it("chatCompletion", async () => {
			const res = await client.chatCompletion({
				model: "meta-llama/Llama-3.1-8B-Instruct",
				provider: "textclf",
				messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
			});
			// Fail loudly on an empty response instead of silently passing
			// (the original only asserted inside `if (res.choices …)`).
			expect(res.choices).toBeDefined();
			expect(res.choices.length).toBeGreaterThan(0);
			const completion = res.choices[0].message?.content;
			expect(completion).toContain("two");
		});

		it("chatCompletion stream", async () => {
			const stream = client.chatCompletionStream({
				model: "meta-llama/Llama-3.1-8B-Instruct",
				provider: "textclf",
				messages: [{ role: "user", content: "Say 'this is a test'" }],
				stream: true,
			}) as AsyncGenerator<ChatCompletionStreamOutput>;

			// Accumulate all streamed delta chunks into one string.
			let fullResponse = "";
			for await (const chunk of stream) {
				if (chunk.choices && chunk.choices.length > 0) {
					const content = chunk.choices[0].delta?.content;
					if (content) {
						fullResponse += content;
					}
				}
			}

			// Verify we got a meaningful response
			expect(fullResponse).toBeTruthy();
			expect(fullResponse.length).toBeGreaterThan(0);
		});
	},
	TIMEOUT,
);
describe.concurrent(
"ZAI",
() => {
Expand Down