diff --git a/packages/inference/src/lib/getProviderHelper.ts b/packages/inference/src/lib/getProviderHelper.ts index 2f8653f419..7f3844a74a 100644 --- a/packages/inference/src/lib/getProviderHelper.ts +++ b/packages/inference/src/lib/getProviderHelper.ts @@ -54,6 +54,7 @@ import type { import * as Replicate from "../providers/replicate.js"; import * as Sambanova from "../providers/sambanova.js"; import * as Scaleway from "../providers/scaleway.js"; +import * as Systalyze from "../providers/systalyze.js"; import * as Together from "../providers/together.js"; import * as Wavespeed from "../providers/wavespeed.js"; import * as Zai from "../providers/zai-org.js"; @@ -177,6 +178,10 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask, TaskProviderHelper>>> = { + systalyze: { + conversational: new Systalyze.SystalyzeConversationalTask(), + "text-generation": new Systalyze.SystalyzeTextGenerationTask(), + }, diff --git a/packages/inference/src/providers/systalyze.ts b/packages/inference/src/providers/systalyze.ts new file mode 100644 --- /dev/null +++ b/packages/inference/src/providers/systalyze.ts @@ -0,0 +1,58 @@ +/** + * See the registered mapping of HF model ID => Systalyze model ID here: + * + * https://huggingface.co/api/partners/systalyze/models + * + * This is a publicly available mapping. + * + * If you want to try to run inference for a new model locally before it's registered on huggingface.co, + * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes. + * + * - If you work at Systalyze and want to update this mapping, please use the model mapping API we provide on huggingface.co + * - If you're a community member and want to add a new supported HF model to Systalyze, please open an issue on the present repo + * and we will tag Systalyze team members. + * + * Thanks! 
+ */ + +import type { TextGenerationInput, TextGenerationOutput } from "@huggingface/tasks"; +import { InferenceClientProviderOutputError } from "../errors.js"; +import type { BaseArgs, BodyParams } from "../types.js"; +import { BaseConversationalTask, BaseTextGenerationTask } from "./providerHelper.js"; + +const SYSTALYZE_API_BASE_URL = "https://api.systalyze.com"; + +export class SystalyzeConversationalTask extends BaseConversationalTask { + constructor() { + super("systalyze", SYSTALYZE_API_BASE_URL); + } +} + +export class SystalyzeTextGenerationTask extends BaseTextGenerationTask { + constructor() { + super("systalyze", SYSTALYZE_API_BASE_URL); + } + + override preparePayload(params: BodyParams<TextGenerationInput & BaseArgs>): Record<string, unknown> { + const { inputs, parameters, ...rest } = params.args; + return { + ...rest, + model: params.model, + prompt: inputs, + ...(parameters?.max_new_tokens !== undefined && { max_tokens: parameters.max_new_tokens }), + ...(parameters?.temperature !== undefined && { temperature: parameters.temperature }), + ...(parameters?.top_p !== undefined && { top_p: parameters.top_p }), + ...(parameters?.repetition_penalty !== undefined && { repetition_penalty: parameters.repetition_penalty }), + ...(parameters?.stop !== undefined && { stop: parameters.stop }), + }; + } + + override async getResponse(response: unknown): Promise<TextGenerationOutput> { + const r = response as { choices?: Array<{ text?: string }> }; + if (typeof r?.choices?.[0]?.text === "string") { + return { generated_text: r.choices[0].text }; + } + throw new InferenceClientProviderOutputError("Malformed response from Systalyze completions API"); + } +} diff --git a/packages/inference/src/types.ts b/packages/inference/src/types.ts index cba3a364e1..85efc677a3 100644 --- a/packages/inference/src/types.ts +++ b/packages/inference/src/types.ts @@ -66,6 +66,7 @@ export const INFERENCE_PROVIDERS = [ "replicate", "sambanova", "scaleway", + "systalyze", "together", "wavespeed", "zai-org", @@ -104,6 +105,7 @@ export const 
PROVIDERS_HUB_ORGS: Record<InferenceProvider, string> = { replicate: "replicate", sambanova: "sambanovasystems", scaleway: "scaleway", + systalyze: "systalyze", together: "togethercomputer", wavespeed: "wavespeed", "zai-org": "zai-org", diff --git a/packages/inference/test/InferenceClient.spec.ts b/packages/inference/test/InferenceClient.spec.ts index 50e3ba49fa..8609aeea3c 100644 --- a/packages/inference/test/InferenceClient.spec.ts +++ b/packages/inference/test/InferenceClient.spec.ts @@ -1626,6 +1626,80 @@ describe.skip("InferenceClient", () => { TIMEOUT, ); + describe.concurrent( + "Systalyze", + () => { + const client = new InferenceClient(env.HF_SYSTALYZE_KEY ?? "dummy"); + + HARDCODED_MODEL_INFERENCE_MAPPING["systalyze"] = { + "meta-llama/Llama-3.1-8B-Instruct": { + provider: "systalyze", + hfModelId: "meta-llama/Llama-3.1-8B-Instruct", + providerId: "meta-llama/Llama-3.1-8B-Instruct", + status: "live", + task: "conversational", + }, + "meta-llama/Llama-3.1-8B": { + provider: "systalyze", + hfModelId: "meta-llama/Llama-3.1-8B", + providerId: "meta-llama/Llama-3.1-8B", + status: "live", + task: "text-generation", + }, + }; + + it("chatCompletion", async () => { + const res = await client.chatCompletion({ + model: "meta-llama/Llama-3.1-8B-Instruct", + provider: "systalyze", + messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }], + }); + if (res.choices && res.choices.length > 0) { + const completion = res.choices[0].message?.content; + expect(completion).toContain("two"); + } + }); + + it("chatCompletion stream", async () => { + const stream = client.chatCompletionStream({ + model: "meta-llama/Llama-3.1-8B-Instruct", + provider: "systalyze", + messages: [{ role: "user", content: "Say 'this is a test'" }], + stream: true, + }) as AsyncGenerator<ChatCompletionStreamOutput>; + + let fullResponse = ""; + for await (const chunk of stream) { + if (chunk.choices && chunk.choices.length > 0) { + const content = chunk.choices[0].delta?.content; + if 
(content) { + fullResponse += content; + } + } + } + + // Verify we got a meaningful response + expect(fullResponse).toBeTruthy(); + expect(fullResponse.length).toBeGreaterThan(0); + }); + + it("textGeneration", async () => { + const res = await client.textGeneration({ + model: "meta-llama/Llama-3.1-8B", + provider: "systalyze", + inputs: "The capital of France is", + parameters: { + max_new_tokens: 10, + }, + }); + expect(res).toBeDefined(); + expect(typeof res.generated_text).toBe("string"); + expect(res.generated_text.length).toBeGreaterThan(0); + }); + }, + TIMEOUT, + ); + describe.concurrent("3rd party providers", () => { it("chatCompletion - fails with unsupported model", async () => { expect(