Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions packages/inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ Currently, we support the following providers:
- [NVIDIA](https://build.nvidia.com/)
- [OVHcloud](https://endpoints.ai.cloud.ovh.net/)
- [Public AI](https://publicai.co)
- [RegoloAI](https://regolo.ai)
- [Replicate](https://replicate.com)
- [Sambanova](https://sambanova.ai)
- [Scaleway](https://www.scaleway.com/en/generative-apis/)
Expand Down Expand Up @@ -98,6 +99,7 @@ Only a subset of models are supported when requesting third-party providers. You
- [Nscale supported models](https://huggingface.co/api/partners/nscale/models)
- [NVIDIA supported models](https://huggingface.co/api/partners/nvidia/models)
- [OVHcloud supported models](https://huggingface.co/api/partners/ovhcloud/models)
- [RegoloAI supported models](https://huggingface.co/api/partners/regoloai/models)
- [Replicate supported models](https://huggingface.co/api/partners/replicate/models)
- [Sambanova supported models](https://huggingface.co/api/partners/sambanova/models)
- [Scaleway supported models](https://huggingface.co/api/partners/scaleway/models)
Expand Down
7 changes: 7 additions & 0 deletions packages/inference/src/lib/getProviderHelper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ import type {
ZeroShotClassificationTaskHelper,
ZeroShotImageClassificationTaskHelper,
} from "../providers/providerHelper.js";
import * as RegoloAI from "../providers/regoloai.js"
import * as Replicate from "../providers/replicate.js";
import * as Sambanova from "../providers/sambanova.js";
import * as Scaleway from "../providers/scaleway.js";
Expand Down Expand Up @@ -161,6 +162,12 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
publicai: {
conversational: new PublicAI.PublicAIConversationalTask(),
},
regoloai: {
conversational: new RegoloAI.RegoloaiConversationalTask(),
"text-generation": new RegoloAI.RegoloaiTextGenerationTask(),
"text-to-image": new RegoloAI.RegoloaiTextToImageTask(),
"feature-extraction": new RegoloAI.RegoloaiFeatureExtractionTask(),
},
replicate: {
"text-to-image": new Replicate.ReplicateTextToImageTask(),
"text-to-speech": new Replicate.ReplicateTextToSpeechTask(),
Expand Down
1 change: 1 addition & 0 deletions packages/inference/src/providers/consts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
openai: {},
publicai: {},
ovhcloud: {},
regoloai: {},
replicate: {},
sambanova: {},
scaleway: {},
Expand Down
163 changes: 163 additions & 0 deletions packages/inference/src/providers/regoloai.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
import type {
ChatCompletionOutput,
TextGenerationOutput,
TextGenerationOutputFinishReason,
AutomaticSpeechRecognitionOutput,
FeatureExtractionOutput,
ChatCompletionInput,
} from "@huggingface/tasks";
import type { BodyParams, OutputType, RequestArgs } from "../types.js";
import { omit } from "../utils/omit.js";
import {
BaseConversationalTask,
BaseTextGenerationTask,
TaskProviderHelper,
type TextToImageTaskHelper,
type AutomaticSpeechRecognitionTaskHelper,
type FeatureExtractionTaskHelper,
} from "./providerHelper.js";
import { InferenceClientProviderOutputError } from "../errors.js";

// Base URL for RegoloAI's OpenAI-compatible REST API; routes below append `v1/...` paths.
const REGOLO_API_BASE_URL = "https://api.regolo.ai";

/**
 * Shape of a RegoloAI `/v1/completions` (text-completion) response.
 * Unlike chat completions, each choice carries a plain `text` field rather
 * than a `message` object, so `choices` is redefined on top of the shared
 * `ChatCompletionOutput` envelope.
 */
interface RegoloTextCompletionOutput extends Omit<ChatCompletionOutput, "choices"> {
	choices: Array<{
		text: string;
		finish_reason: TextGenerationOutputFinishReason;
		index: number;
		logprobs?: unknown;
	}>;
}

/**
 * Shape of a RegoloAI image-generation response (OpenAI-style).
 * Each entry carries either a base64-encoded image or a URL, depending on
 * the `response_format` requested in the payload.
 */
interface RegoloImageGeneration {
	data: Array<{
		b64_json?: string;
		url?: string;
	}>;
}

/** Chat-completion (conversational) task helper for RegoloAI. */
export class RegoloaiConversationalTask extends BaseConversationalTask {
	constructor() {
		super("regoloai", REGOLO_API_BASE_URL);
	}

	/**
	 * Builds the standard chat payload, then applies two RegoloAI-specific tweaks:
	 * keep special tokens for the `deepseek-ocr` model, and flatten the OpenAI-style
	 * `response_format.json_schema.schema` wrapper into a top-level `schema` field.
	 */
	override preparePayload(params: BodyParams<ChatCompletionInput>): Record<string, unknown> {
		const payload = super.preparePayload(params);

		// NOTE(review): presumably deepseek-ocr's output depends on its special
		// tokens being preserved — confirm with the provider docs.
		if (params.model === "deepseek-ocr") {
			payload.skip_special_tokens = false;
		}

		const responseFormat = payload.response_format as
			| { type: "json_schema"; json_schema: { schema: unknown } }
			| undefined;
		if (responseFormat?.type === "json_schema" && responseFormat.json_schema?.schema) {
			// Unwrap `json_schema.schema` into the flattened form RegoloAI expects.
			payload.response_format = {
				type: "json_schema",
				schema: responseFormat.json_schema.schema,
			};
		}

		return payload;
	}
}

export class RegoloaiTextGenerationTask extends BaseTextGenerationTask {
constructor() {
super("regoloai", REGOLO_API_BASE_URL);
}

override makeRoute(): string {
return "v1/completions";
}

override preparePayload(params: BodyParams): Record<string, unknown> {
return {
model: params.model,
...params.args,
prompt: params.args.inputs,
};
}

override async getResponse(response: RegoloTextCompletionOutput): Promise<TextGenerationOutput> {
if (typeof response?.choices?.[0]?.text === "string") {
return { generated_text: response.choices[0].text };
}
throw new InferenceClientProviderOutputError("Malformed response from RegoloAI completions API");
}
}

/** Text-to-image task helper for RegoloAI's OpenAI-style image-generation endpoint. */
export class RegoloaiTextToImageTask extends TaskProviderHelper implements TextToImageTaskHelper {
	constructor() {
		super("regoloai", REGOLO_API_BASE_URL);
	}

	makeRoute(): string {
		return "v1/images/generations";
	}

	/**
	 * Builds an OpenAI-style image payload. Defaults (n=1, size=1024x1024,
	 * response_format derived from the requested output type) are listed first
	 * so caller-supplied args spread afterwards can still override them.
	 */
	preparePayload(params: BodyParams): Record<string, unknown> {
		const extra = (params.args.parameters as Record<string, unknown>) ?? {};
		const wantsUrl = params.outputType === "url";
		return {
			model: params.model,
			prompt: params.args.inputs,
			n: extra.n ?? 1,
			size: extra.size ?? "1024x1024",
			response_format: wantsUrl ? "url" : "b64_json",
			...omit(params.args, ["inputs", "parameters"]),
		};
	}

	/**
	 * Normalizes the provider response into the requested output type: the raw
	 * JSON object, a URL string, a data-URL string, or a Blob decoded from base64.
	 * @throws InferenceClientProviderOutputError when no image data is present.
	 */
	async getResponse(
		response: RegoloImageGeneration,
		_url?: string,
		_headers?: HeadersInit,
		outputType?: OutputType,
	): Promise<string | Blob | Record<string, unknown>> {
		const first = response?.data?.[0];
		if (!first) {
			throw new InferenceClientProviderOutputError("No images received from RegoloAI");
		}

		if (outputType === "json") {
			return { ...response };
		}

		if (first.url) {
			return first.url;
		}

		if (first.b64_json) {
			const dataUrl = `data:image/png;base64,${first.b64_json}`;
			if (outputType === "dataUrl") {
				return dataUrl;
			}
			// fetch() on a data: URL is a dependency-free way to decode base64 into a Blob.
			const res = await fetch(dataUrl);
			return res.blob();
		}

		throw new InferenceClientProviderOutputError("Image format not recognized in response");
	}
}

/** Shape of a RegoloAI (OpenAI-compatible) embeddings response: one `data` entry per input. */
interface RegoloEmbeddingsOutput {
	data: Array<{
		embedding: number[];
	}>;
}

/** Feature-extraction (embeddings) task helper for RegoloAI. */
export class RegoloaiFeatureExtractionTask extends TaskProviderHelper implements FeatureExtractionTaskHelper {
	constructor() {
		super("regoloai", REGOLO_API_BASE_URL);
	}

	makeRoute(): string {
		return "v1/embeddings";
	}

	preparePayload(params: BodyParams): Record<string, unknown> {
		return {
			model: params.model,
			input: params.args.inputs,
		};
	}

	/**
	 * Returns one embedding per input. OpenAI-compatible embeddings APIs emit
	 * one `data` entry per input, so every entry is mapped — the previous
	 * version returned only `data[0].embedding`, silently dropping all
	 * embeddings past index 0 for batched inputs.
	 * @throws InferenceClientProviderOutputError when the response shape is invalid.
	 */
	async getResponse(response: RegoloEmbeddingsOutput): Promise<FeatureExtractionOutput> {
		if (Array.isArray(response?.data) && response.data.every((item) => Array.isArray(item?.embedding))) {
			return response.data.map((item) => item.embedding);
		}
		throw new InferenceClientProviderOutputError("Invalid embeddings response from RegoloAI");
	}
}
2 changes: 2 additions & 0 deletions packages/inference/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ export const INFERENCE_PROVIDERS = [
"openai",
"ovhcloud",
"publicai",
"regoloai",
"replicate",
"sambanova",
"scaleway",
Expand Down Expand Up @@ -101,6 +102,7 @@ export const PROVIDERS_HUB_ORGS: Record<InferenceProvider, string> = {
openai: "openai",
ovhcloud: "ovhcloud",
publicai: "publicai",
regoloai: "regoloai",
replicate: "replicate",
sambanova: "sambanovasystems",
scaleway: "scaleway",
Expand Down
Loading