diff --git a/packages/inference/README.md b/packages/inference/README.md
index 664c224583..a13f7ab37a 100644
--- a/packages/inference/README.md
+++ b/packages/inference/README.md
@@ -47,6 +47,7 @@ Your access token should be kept private. If you need to protect it in front-end
 
 You can send inference requests to third-party providers with the inference client. Currently, we support the following providers:
 
+- [AlphaNeural](https://alphaneural.ai)
 - [Fal.ai](https://fal.ai)
 - [Featherless AI](https://featherless.ai)
 - [Fireworks AI](https://fireworks.ai)
@@ -88,6 +89,7 @@ When authenticated with a Hugging Face access token, the request is routed throu
 
 When authenticated with a third-party provider key, the request is made directly against that provider's inference API. Only a subset of models are supported when requesting third-party providers. You can check the list of supported models per pipeline tasks here:
 
+- [AlphaNeural supported models](https://huggingface.co/api/partners/alphaneural/models)
 - [Fal.ai supported models](https://huggingface.co/api/partners/fal-ai/models)
 - [Featherless AI supported models](https://huggingface.co/api/partners/featherless-ai/models)
 - [Fireworks AI supported models](https://huggingface.co/api/partners/fireworks-ai/models)
diff --git a/packages/inference/src/lib/getProviderHelper.ts b/packages/inference/src/lib/getProviderHelper.ts
index a733a37b65..4304933c08 100644
--- a/packages/inference/src/lib/getProviderHelper.ts
+++ b/packages/inference/src/lib/getProviderHelper.ts
@@ -1,3 +1,4 @@
+import * as Alphaneural from "../providers/alphaneural.js";
 import * as Baseten from "../providers/baseten.js";
 import * as Clarifai from "../providers/clarifai.js";
 import * as BlackForestLabs from "../providers/black-forest-labs.js";
@@ -60,6 +61,10 @@ import type { InferenceProvider, InferenceProviderOrPolicy, InferenceTask } from
 import { InferenceClientInputError } from "../errors.js";
 
 export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask, TaskProviderHelper>>> = {
+	alphaneural: {
+		conversational: new Alphaneural.AlphaneuralConversationalTask(),
+		"text-generation": new Alphaneural.AlphaneuralTextGenerationTask(),
+	},
 	baseten: {
 		conversational: new Baseten.BasetenConversationalTask(),
 	},
diff --git a/packages/inference/src/providers/alphaneural.ts b/packages/inference/src/providers/alphaneural.ts
new file mode 100644
index 0000000000..cedf384c97
--- /dev/null
+++ b/packages/inference/src/providers/alphaneural.ts
@@ -0,0 +1,73 @@
+/**
+ * See the registered mapping of HF model ID => AlphaNeural model ID here:
+ *
+ * https://huggingface.co/api/partners/alphaneural/models
+ *
+ * This is a publicly available mapping.
+ *
+ * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
+ * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
+ *
+ * - If you work at AlphaNeural and want to update this mapping, please use the model mapping API we provide on huggingface.co
+ * - If you're a community member and want to add a new supported HF model to AlphaNeural, please open an issue on the present repo
+ * and we will tag AlphaNeural team members.
+ *
+ * Thanks!
+ */
+import type { TextGenerationInput, TextGenerationOutput, TextGenerationOutputFinishReason } from "@huggingface/tasks";
+import type { BodyParams } from "../types.js";
+import { BaseConversationalTask, BaseTextGenerationTask } from "./providerHelper.js";
+import { omit } from "../utils/omit.js";
+import { InferenceClientProviderOutputError } from "../errors.js";
+
+const ALPHANEURAL_API_BASE_URL = "https://proxy.alfnrl.io";
+
+interface AlphaneuralTextCompletionOutput {
+	choices: Array<{
+		text: string;
+		finish_reason: TextGenerationOutputFinishReason;
+		index: number;
+	}>;
+	model: string;
+}
+
+export class AlphaneuralConversationalTask extends BaseConversationalTask {
+	constructor() {
+		super("alphaneural", ALPHANEURAL_API_BASE_URL);
+	}
+}
+
+export class AlphaneuralTextGenerationTask extends BaseTextGenerationTask {
+	constructor() {
+		super("alphaneural", ALPHANEURAL_API_BASE_URL);
+	}
+
+	override preparePayload(params: BodyParams<TextGenerationInput>): Record<string, unknown> {
+		return {
+			model: params.model,
+			...omit(params.args, ["inputs", "parameters"]),
+			...(params.args.parameters
+				? {
+						max_tokens: params.args.parameters.max_new_tokens,
+						...omit(params.args.parameters, "max_new_tokens"),
+				  }
+				: undefined),
+			prompt: params.args.inputs,
+		};
+	}
+
+	override async getResponse(response: AlphaneuralTextCompletionOutput): Promise<TextGenerationOutput> {
+		if (
+			typeof response === "object" &&
+			"choices" in response &&
+			Array.isArray(response?.choices) &&
+			response.choices.length > 0 &&
+			typeof response.choices[0]?.text === "string"
+		) {
+			return {
+				generated_text: response.choices[0].text,
+			};
+		}
+		throw new InferenceClientProviderOutputError("Received malformed response from AlphaNeural text generation API");
+	}
+}
diff --git a/packages/inference/src/providers/consts.ts b/packages/inference/src/providers/consts.ts
index 6fcf5e0715..2c2f1f10d7 100644
--- a/packages/inference/src/providers/consts.ts
+++ b/packages/inference/src/providers/consts.ts
@@ -18,6 +18,7 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
 	 * Example:
 	 * "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
 	 */
+	alphaneural: {},
 	baseten: {},
 	"black-forest-labs": {},
 	cerebras: {},
diff --git a/packages/inference/src/types.ts b/packages/inference/src/types.ts
index e4165914d6..6daaa99082 100644
--- a/packages/inference/src/types.ts
+++ b/packages/inference/src/types.ts
@@ -45,6 +45,7 @@ export interface Options {
 export type InferenceTask = Exclude<PipelineType, "other"> | "conversational";
 
 export const INFERENCE_PROVIDERS = [
+	"alphaneural",
 	"baseten",
 	"black-forest-labs",
 	"cerebras",
@@ -82,6 +83,7 @@ export type InferenceProviderOrPolicy = (typeof PROVIDERS_OR_POLICIES)[number];
  * Whenever possible, InferenceProvider should == org namespace
  */
 export const PROVIDERS_HUB_ORGS: Record<InferenceProvider, string> = {
+	alphaneural: "AlphaNeural",
 	baseten: "baseten",
 	"black-forest-labs": "black-forest-labs",
 	cerebras: "cerebras",
diff --git a/packages/inference/test/InferenceClient.spec.ts b/packages/inference/test/InferenceClient.spec.ts
index 00f9d9953f..c2b7994715 100644
--- a/packages/inference/test/InferenceClient.spec.ts
+++ b/packages/inference/test/InferenceClient.spec.ts
@@ -1641,6 +1641,82 @@ describe.skip("InferenceClient", () => {
 		});
 	});
 
+	describe.concurrent(
+		"AlphaNeural",
+		() => {
+			const client = new InferenceClient(env.HF_ALPHANEURAL_KEY ?? "dummy");
+
+			HARDCODED_MODEL_INFERENCE_MAPPING["alphaneural"] = {
+				"qwen/qwen3": {
+					provider: "alphaneural",
+					hfModelId: "qwen/qwen3",
+					providerId: "qwen3",
+					status: "live",
+					task: "conversational",
+				},
+				"Qwen/Qwen3-8B": {
+					provider: "alphaneural",
+					hfModelId: "Qwen/Qwen3-8B",
+					providerId: "qwen3",
+					status: "live",
+					task: "text-generation",
+				},
+			};
+
+			it("chatCompletion", async () => {
+				const res = await client.chatCompletion({
+					model: "qwen/qwen3",
+					provider: "alphaneural",
+					messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
+				});
+				if (res.choices && res.choices.length > 0) {
+					const completion = res.choices[0].message?.content;
+					expect(completion).toContain("two");
+				}
+			});
+
+			it("chatCompletion stream", async () => {
+				const stream = client.chatCompletionStream({
+					model: "qwen/qwen3",
+					provider: "alphaneural",
+					messages: [{ role: "user", content: "Say 'this is a test'" }],
+					stream: true,
+				}) as AsyncGenerator<ChatCompletionStreamOutput>;
+
+				let fullResponse = "";
+				for await (const chunk of stream) {
+					if (chunk.choices && chunk.choices.length > 0) {
+						const content = chunk.choices[0].delta?.content;
+						if (content) {
+							fullResponse += content;
+						}
+					}
+				}
+
+				// Verify we got a meaningful response
+				expect(fullResponse).toBeTruthy();
+				expect(fullResponse.length).toBeGreaterThan(0);
+			});
+
+			it("textGeneration", async () => {
+				const res = await textGeneration({
+					accessToken: env.HF_ALPHANEURAL_KEY ?? "dummy",
+					model: "Qwen/Qwen3-8B",
+					provider: "alphaneural",
+					inputs: "The capital of France is",
+					parameters: {
+						temperature: 0,
+						max_new_tokens: 10,
+					},
+				});
+				expect(res).toBeDefined();
+				expect(res.generated_text).toBeDefined();
+				expect(typeof res.generated_text).toBe("string");
+			});
+		},
+		TIMEOUT
+	);
+
 	describe.concurrent(
 		"Fireworks",
 		() => {