Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion packages/inference/src/lib/getProviderHelper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import * as Cohere from "../providers/cohere";
import * as FalAI from "../providers/fal-ai";
import * as Fireworks from "../providers/fireworks-ai";
import * as HFInference from "../providers/hf-inference";

import * as FeatherlessAI from "../providers/featherless-ai";
import * as Hyperbolic from "../providers/hyperbolic";
import * as Nebius from "../providers/nebius";
import * as Novita from "../providers/novita";
Expand Down Expand Up @@ -62,6 +62,10 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
"text-to-video": new FalAI.FalAITextToVideoTask(),
"automatic-speech-recognition": new FalAI.FalAIAutomaticSpeechRecognitionTask(),
},
"featherless-ai": {
conversational: new FeatherlessAI.FeatherlessAIConversationalTask(),
"text-generation": new FeatherlessAI.FeatherlessAITextGenerationTask(),
},
"hf-inference": {
"text-to-image": new HFInference.HFInferenceTextToImageTask(),
conversational: new HFInference.HFInferenceConversationalTask(),
Expand Down
51 changes: 51 additions & 0 deletions packages/inference/src/providers/featherless-ai.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import type { ChatCompletionOutput, TextGenerationOutput, TextGenerationOutputFinishReason } from "@huggingface/tasks";
import { InferenceOutputError } from "../lib/InferenceOutputError";
import type { BodyParams } from "../types";
import { BaseConversationalTask, BaseTextGenerationTask } from "./providerHelper";

/**
 * Shape of a Featherless AI text-completion response.
 *
 * Featherless returns an OpenAI-style completion payload: identical to
 * `ChatCompletionOutput` except that each choice carries a raw `text`
 * string rather than a chat `message`.
 */
interface FeatherlessAITextCompletionOutput extends Omit<ChatCompletionOutput, "choices"> {
	choices: Array<{
		text: string;
		finish_reason: TextGenerationOutputFinishReason;
		seed: number;
		// Provider does not document the logprobs payload — kept opaque on purpose.
		logprobs: unknown;
		index: number;
	}>;
}

const FEATHERLESS_API_BASE_URL = "https://api.featherless.ai";

/**
 * Conversational (chat-completion) task for the Featherless AI provider.
 *
 * Featherless exposes an OpenAI-compatible chat endpoint, so the inherited
 * `BaseConversationalTask` request/response handling is used unchanged —
 * only the provider name and base URL are supplied here.
 */
export class FeatherlessAIConversationalTask extends BaseConversationalTask {
	constructor() {
		super("featherless-ai", FEATHERLESS_API_BASE_URL);
	}
}

/**
 * Text-generation (raw completion) task for the Featherless AI provider.
 *
 * Maps the Hugging Face `text-generation` payload (`inputs` + `parameters`)
 * onto Featherless's OpenAI-style completion API (`prompt` + flattened
 * parameters), and converts the completion response back into a
 * `TextGenerationOutput`.
 */
export class FeatherlessAITextGenerationTask extends BaseTextGenerationTask {
	constructor() {
		super("featherless-ai", FEATHERLESS_API_BASE_URL);
	}

	override preparePayload(params: BodyParams): Record<string, unknown> {
		return {
			model: params.model,
			...params.args,
			// Flatten HF-style nested `parameters` (temperature, top_p, …) to top level,
			// and expose HF `inputs` under the OpenAI-compatible `prompt` key.
			...params.args.parameters,
			prompt: params.args.inputs,
		};
	}

	/**
	 * Validates the provider response and extracts the generated text.
	 *
	 * @throws InferenceOutputError when the payload is not a completion object
	 *         with at least one choice (e.g. an error body or empty result).
	 */
	override async getResponse(response: FeatherlessAITextCompletionOutput): Promise<TextGenerationOutput> {
		if (
			typeof response === "object" &&
			"choices" in response &&
			Array.isArray(response?.choices) &&
			// Guard against an empty choices array — indexing [0] would otherwise
			// silently produce `generated_text: undefined`.
			response.choices.length > 0 &&
			typeof response?.model === "string"
		) {
			const completion = response.choices[0];
			return {
				generated_text: completion.text,
			};
		}
		// Name the right provider in the error (was mistakenly "Together", copied from the Together helper).
		throw new InferenceOutputError("Expected Featherless AI text generation response format");
	}
}
1 change: 1 addition & 0 deletions packages/inference/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ export const INFERENCE_PROVIDERS = [
"cerebras",
"cohere",
"fal-ai",
"featherless-ai",
"fireworks-ai",
"hf-inference",
"hyperbolic",
Expand Down
63 changes: 63 additions & 0 deletions packages/inference/test/InferenceClient.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,69 @@ describe.concurrent("InferenceClient", () => {
TIMEOUT
);

describe.concurrent(
	"Featherless",
	() => {
		// Map canonical Hub model IDs to the IDs Featherless AI serves them under.
		// Double-quoted key for consistency with the rest of this file (Prettier style).
		HARDCODED_MODEL_ID_MAPPING["featherless-ai"] = {
			"meta-llama/Llama-3.1-8B": "meta-llama/Meta-Llama-3.1-8B",
			"meta-llama/Llama-3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct",
		};

		it("chatCompletion", async () => {
			const res = await chatCompletion({
				accessToken: env.HF_FEATHERLESS_KEY ?? "dummy",
				model: "meta-llama/Llama-3.1-8B-Instruct",
				provider: "featherless-ai",
				messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
				temperature: 0.1,
			});

			expect(res).toBeDefined();
			expect(res.choices).toBeDefined();
			expect(res.choices?.length).toBeGreaterThan(0);

			if (res.choices && res.choices.length > 0) {
				const completion = res.choices[0].message?.content;
				expect(completion).toBeDefined();
				expect(typeof completion).toBe("string");
				expect(completion).toContain("two");
			}
		});

		it("chatCompletion stream", async () => {
			const stream = chatCompletionStream({
				accessToken: env.HF_FEATHERLESS_KEY ?? "dummy",
				model: "meta-llama/Llama-3.1-8B-Instruct",
				provider: "featherless-ai",
				messages: [{ role: "user", content: "Complete the equation 1 + 1 = , just the answer" }],
			}) as AsyncGenerator<ChatCompletionStreamOutput>;
			let out = "";
			for await (const chunk of stream) {
				if (chunk.choices && chunk.choices.length > 0) {
					out += chunk.choices[0].delta.content;
				}
			}
			expect(out).toContain("2");
		});

		it("textGeneration", async () => {
			const res = await textGeneration({
				accessToken: env.HF_FEATHERLESS_KEY ?? "dummy",
				model: "meta-llama/Llama-3.1-8B",
				provider: "featherless-ai",
				inputs: "Paris is a city of ",
				parameters: {
					// Near-deterministic sampling so the exact-match assertion below is stable.
					temperature: 0,
					top_p: 0.01,
					max_tokens: 10,
				},
			});
			expect(res).toMatchObject({ generated_text: "2.2 million people, and it is the" });
		});
	},
	TIMEOUT
);

describe.concurrent(
"Replicate",
() => {
Expand Down