Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions packages/inference/src/lib/getProviderHelper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ import type {
import * as Replicate from "../providers/replicate.js";
import * as Sambanova from "../providers/sambanova.js";
import * as Scaleway from "../providers/scaleway.js";
import * as Systalyze from "../providers/systalyze.js";
import * as Together from "../providers/together.js";
import * as Wavespeed from "../providers/wavespeed.js";
import * as Zai from "../providers/zai-org.js";
Expand Down Expand Up @@ -177,6 +178,10 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
"text-generation": new Scaleway.ScalewayTextGenerationTask(),
"feature-extraction": new Scaleway.ScalewayFeatureExtractionTask(),
},
systalyze: {
conversational: new Systalyze.SystalyzeConversationalTask(),
"text-generation": new Systalyze.SystalyzeTextGenerationTask(),
},
together: {
"text-to-image": new Together.TogetherTextToImageTask(),
conversational: new Together.TogetherConversationalTask(),
Expand Down
1 change: 1 addition & 0 deletions packages/inference/src/providers/consts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
replicate: {},
sambanova: {},
scaleway: {},
systalyze: {},
together: {},
wavespeed: {},
"zai-org": {},
Expand Down
57 changes: 57 additions & 0 deletions packages/inference/src/providers/systalyze.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/**
* See the registered mapping of HF model ID => Systalyze model ID here:
*
* https://huggingface.co/api/partners/systalyze/models
*
* This is a publicly available mapping.
*
* If you want to try to run inference for a new model locally before it's registered on huggingface.co,
* you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
*
* - If you work at Systalyze and want to update this mapping, please use the model mapping API we provide on huggingface.co
* - If you're a community member and want to add a new supported HF model to Systalyze, please open an issue on the present repo
* and we will tag Systalyze team members.
*
* Thanks!
*/

import type { TextGenerationInput, TextGenerationOutput } from "@huggingface/tasks";
import { InferenceClientProviderOutputError } from "../errors.js";
import type { BaseArgs, BodyParams } from "../types.js";
import { BaseConversationalTask, BaseTextGenerationTask } from "./providerHelper.js";

// Base URL of Systalyze's inference API (OpenAI-style completions endpoint — see getResponse below).
const SYSTALYZE_API_BASE_URL = "https://api.systalyze.com";

/**
 * Chat-completion ("conversational") task routed to Systalyze.
 *
 * All request building and response parsing is inherited from
 * {@link BaseConversationalTask}; this subclass only binds the provider
 * name and API base URL.
 */
export class SystalyzeConversationalTask extends BaseConversationalTask {
	constructor() {
		super("systalyze", SYSTALYZE_API_BASE_URL);
	}
}

/**
 * Text-generation task routed to Systalyze's OpenAI-style completions API.
 *
 * Translates HF `TextGenerationInput` into the completions request shape
 * (`inputs` -> `prompt`, `max_new_tokens` -> `max_tokens`) and maps the
 * `choices[0].text` response back to `generated_text`.
 */
export class SystalyzeTextGenerationTask extends BaseTextGenerationTask {
	constructor() {
		super("systalyze", SYSTALYZE_API_BASE_URL);
	}

	/**
	 * Build the JSON body for the completions request.
	 *
	 * Client-side routing/credential fields from `BaseArgs` (`accessToken`,
	 * `endpointUrl`, `provider`, `model`) are destructured out so they are
	 * never serialized into the HTTP payload — the original `...rest` spread
	 * would have leaked `accessToken` to the provider in the request body.
	 */
	override preparePayload(params: BodyParams<TextGenerationInput & BaseArgs>): Record<string, unknown> {
		const { inputs, parameters, accessToken, endpointUrl, provider, model, ...rest } = params.args;
		// Explicitly discard client-side-only fields (not part of the wire payload).
		void accessToken;
		void endpointUrl;
		void provider;
		void model;
		return {
			...rest,
			model: params.model,
			prompt: inputs,
			// Only forward generation parameters the caller actually set.
			...(parameters?.max_new_tokens !== undefined && { max_tokens: parameters.max_new_tokens }),
			...(parameters?.temperature !== undefined && { temperature: parameters.temperature }),
			...(parameters?.top_p !== undefined && { top_p: parameters.top_p }),
			...(parameters?.repetition_penalty !== undefined && { repetition_penalty: parameters.repetition_penalty }),
			...(parameters?.stop !== undefined && { stop: parameters.stop }),
		};
	}

	/**
	 * Parse an OpenAI-style completions response.
	 *
	 * @throws InferenceClientProviderOutputError when `choices[0].text` is
	 *   missing or not a string.
	 */
	override async getResponse(response: unknown): Promise<TextGenerationOutput> {
		const r = response as { choices?: Array<{ text?: string }> };
		if (typeof r?.choices?.[0]?.text === "string") {
			return { generated_text: r.choices[0].text };
		}
		throw new InferenceClientProviderOutputError("Malformed response from Systalyze completions API");
	}
}
2 changes: 2 additions & 0 deletions packages/inference/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ export const INFERENCE_PROVIDERS = [
"replicate",
"sambanova",
"scaleway",
"systalyze",
"together",
"wavespeed",
"zai-org",
Expand Down Expand Up @@ -104,6 +105,7 @@ export const PROVIDERS_HUB_ORGS: Record<InferenceProvider, string> = {
replicate: "replicate",
sambanova: "sambanovasystems",
scaleway: "scaleway",
systalyze: "systalyze",
together: "togethercomputer",
wavespeed: "wavespeed",
"zai-org": "zai-org",
Expand Down
74 changes: 74 additions & 0 deletions packages/inference/test/InferenceClient.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1626,6 +1626,80 @@ describe.skip("InferenceClient", () => {
TIMEOUT,
);

describe.concurrent(
	"Systalyze",
	() => {
		const client = new InferenceClient(env.HF_SYSTALYZE_KEY ?? "dummy");

		// NOTE: keys of this mapping are HF model IDs, so each model may appear
		// only once per provider. The original literal used the same key twice
		// (duplicate object property), so the "text-generation" entry silently
		// overwrote the "conversational" one. Use a distinct model per task.
		HARDCODED_MODEL_INFERENCE_MAPPING["systalyze"] = {
			"meta-llama/Llama-3.1-8B-Instruct": {
				provider: "systalyze",
				hfModelId: "meta-llama/Llama-3.1-8B-Instruct",
				providerId: "meta-llama/Llama-3.1-8B-Instruct",
				status: "live",
				task: "conversational",
			},
			"meta-llama/Llama-3.1-8B": {
				provider: "systalyze",
				hfModelId: "meta-llama/Llama-3.1-8B",
				providerId: "meta-llama/Llama-3.1-8B",
				status: "live",
				task: "text-generation",
			},
		};

		it("chatCompletion", async () => {
			const res = await client.chatCompletion({
				model: "meta-llama/Llama-3.1-8B-Instruct",
				provider: "systalyze",
				messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
			});
			if (res.choices && res.choices.length > 0) {
				const completion = res.choices[0].message?.content;
				expect(completion).toContain("two");
			}
		});

		it("chatCompletion stream", async () => {
			const stream = client.chatCompletionStream({
				model: "meta-llama/Llama-3.1-8B-Instruct",
				provider: "systalyze",
				messages: [{ role: "user", content: "Say 'this is a test'" }],
				stream: true,
			}) as AsyncGenerator<ChatCompletionStreamOutput>;

			let fullResponse = "";
			for await (const chunk of stream) {
				if (chunk.choices && chunk.choices.length > 0) {
					const content = chunk.choices[0].delta?.content;
					if (content) {
						fullResponse += content;
					}
				}
			}

			// Verify we got a meaningful response
			expect(fullResponse).toBeTruthy();
			expect(fullResponse.length).toBeGreaterThan(0);
		});

		it("textGeneration", async () => {
			// Uses the base (non-Instruct) model registered above with
			// task "text-generation" so it routes through the completions API.
			const res = await client.textGeneration({
				model: "meta-llama/Llama-3.1-8B",
				provider: "systalyze",
				inputs: "The capital of France is",
				parameters: {
					max_new_tokens: 10,
				},
			});
			expect(res).toBeDefined();
			expect(typeof res.generated_text).toBe("string");
			expect(res.generated_text.length).toBeGreaterThan(0);
		});
	},
	TIMEOUT,
);

describe.concurrent("3rd party providers", () => {
it("chatCompletion - fails with unsupported model", async () => {
expect(
Expand Down