From d7879754e8892e846436d33510172c092fe1a21b Mon Sep 17 00:00:00 2001 From: Aaron Vogler Date: Tue, 9 Sep 2025 20:37:47 -0400 Subject: [PATCH 1/6] Create the basis for integrating bytez as an inference provider. --- .gitignore | 3 +- packages/inference/README.md | 1 + .../inference/src/lib/getProviderHelper.ts | 27 + packages/inference/src/providers/bytez.ts | 904 ++++++++++++++++++ packages/inference/src/providers/consts.ts | 1 + .../inference/src/providers/providerHelper.ts | 10 + .../src/tasks/audio/audioClassification.ts | 4 +- .../src/tasks/cv/imageClassification.ts | 4 +- .../inference/src/tasks/cv/imageToText.ts | 8 +- .../inference/src/tasks/cv/objectDetection.ts | 5 +- packages/inference/src/types.ts | 1 + .../inference/test/InferenceClient.spec.ts | 507 +++++++++- 12 files changed, 1467 insertions(+), 8 deletions(-) create mode 100644 packages/inference/src/providers/bytez.ts diff --git a/.gitignore b/.gitignore index 68c13e55e8..eba13e678a 100644 --- a/.gitignore +++ b/.gitignore @@ -107,4 +107,5 @@ dist .DS_Store # Generated by doc-internal -docs \ No newline at end of file +docs +.vscode/launch.json diff --git a/packages/inference/README.md b/packages/inference/README.md index ed43e0644f..ba292c8f47 100644 --- a/packages/inference/README.md +++ b/packages/inference/README.md @@ -64,6 +64,7 @@ Currently, we support the following providers: - [Cohere](https://cohere.com) - [Cerebras](https://cerebras.ai/) - [Groq](https://groq.com) +- [Bytez](https://bytez.com) To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. The default value of the `provider` parameter is "auto", which will select the first of the providers available for the model, sorted by your preferred order in https://hf.co/settings/inference-providers. diff --git a/packages/inference/src/lib/getProviderHelper.ts b/packages/inference/src/lib/getProviderHelper.ts index 3e95eceb8c..90aa4043ec 100644 --- a/packages/inference/src/lib/getProviderHelper.ts +++ b/packages/inference/src/lib/getProviderHelper.ts @@ -1,4 +1,5 @@ import * as BlackForestLabs from "../providers/black-forest-labs.js"; +import * as Bytez from "../providers/bytez.js"; import * as Cerebras from "../providers/cerebras.js"; import * as Cohere from "../providers/cohere.js"; import * as FalAI from "../providers/fal-ai.js"; @@ -56,6 +57,32 @@ export const PROVIDERS: Record { + void binary; + + const headers: Record = { Authorization: `Key ${params.accessToken}` }; + headers["Content-Type"] = "application/json"; + return headers; + } + + // we always want this behavior with out API, we only support JSON payloads + override makeBody(params: BodyParams): BodyInit { + return JSON.stringify(this.preparePayload(params)); + } + + async _preparePayloadAsync(args: Record) { + if ("inputs" in args) { + const input = args.inputs as Blob; + const arrayBuffer = await input.arrayBuffer(); + const uint8Array = new Uint8Array(arrayBuffer); + const base64 = base64FromBytes(uint8Array); + + return { + ...args, + inputs: base64, + }; + } else { + // handle LegacyImageInput case + const data = args.data as Blob; + const arrayBuffer = data instanceof Blob ? await data.arrayBuffer() : data; + const uint8Array = new Uint8Array(arrayBuffer); + const base64 = base64FromBytes(uint8Array); + + return { + ...args, + inputs: base64, + }; + } + } + + handleError(error: string) { + if (error) { + throw new Error(`There was a problem with the bytez API: ${error}`); + } + } +} + +export class BytezTextGenerationTask extends BytezTask implements TextGenerationTaskHelper { + override preparePayload(params: BodyParams): Record { + return { + text: params.args.inputs, + params: params.args.parameters, + stream: params.args.stream, + complianceFormat: params.args.stream ? "hf://text-generation" : undefined, + }; + } + + override async getResponse( + response: BytezStringLikeOutput, + url?: string, + headers?: HeadersInit + ): Promise { + void url; + void headers; + + const { error, output } = response; + + this.handleError(error); + + return { + generated_text: output, + }; + } +} + +export class BytezConversationalTask extends BytezTask implements ConversationalTaskHelper { + override preparePayload(params: BodyParams): Record { + return { + messages: params.args.messages, + params: params.args.parameters, + stream: params.args.stream, + // HF uses the same schema as OAI compliant spec + complianceFormat: "openai://chat/completions", + }; + } + override async getResponse( + response: BytezChatLikeOutput, + url: string, + headers?: HeadersInit + ): Promise { + void url; + void headers; + + const { error, output } = response; + + this.handleError(error); + + return output; + } +} + +export class BytezSummarizationTask extends BytezTask implements SummarizationTaskHelper { + override preparePayload(params: BodyParams): Record { + return { + text: params.args.inputs, + params: params.args.parameters, + }; + } + + override async getResponse( + response: BytezStringLikeOutput, + url?: string, + headers?: HeadersInit + ): Promise { + void url; + void headers; + + const { error, output } = response; + + this.handleError(error); + + return { + summary_text: output, + }; + } +} + +export class BytezTranslationTask extends BytezTask implements TranslationTaskHelper { + override preparePayload(params: BodyParams): Record { + return { + text: params.args.inputs, + params: params.args.parameters, + }; + } + + override async getResponse( + response: BytezStringLikeOutput, + url?: string, + headers?: HeadersInit + ): Promise { + void url; + void headers; + + const { error, output } = response; + + this.handleError(error); + + return { + translation_text: output, + }; + } +} + +export class BytezTextToImageTask extends BytezTask implements TextToImageTaskHelper { + override preparePayload(params: BodyParams): Record { + return { + text: params.args.inputs, + params: params.args.parameters, + }; + } + + override async getResponse( + response: BytezStringLikeOutput, + url?: string, + headers?: HeadersInit + ): Promise> { + void url; + void headers; + + const { error, output } = response; + + this.handleError(error); + + const cdnResponse = await fetch(output); + + const blob = await cdnResponse.blob(); + + return blob; + } +} + +export class BytezTextToVideoTask extends BytezTask implements TextToVideoTaskHelper { + override preparePayload(params: BodyParams): Record { + return { + text: params.args.inputs, + params: params.args.parameters, + }; + } + + override async getResponse(response: BytezStringLikeOutput, url?: string, headers?: HeadersInit): Promise { + void url; + void headers; + + const { error, output } = response; + + this.handleError(error); + + const cdnResponse = await fetch(output); + + const blob = await cdnResponse.blob(); + + return blob; + } +} + +export class BytezImageToTextTask extends BytezTask implements ImageToTextTaskHelper { + preparePayloadAsync(args: ImageToTextArgs): Promise { + const _args = args as Record; + return this._preparePayloadAsync(_args); + } + + override preparePayload(params: BodyParams): Record { + return { + base64: params.args.inputs, + params: params.args.parameters, + }; + } + + override async getResponse( + response: BytezStringLikeOutput, + url?: string, + headers?: HeadersInit + ): Promise { + void url; + void headers; + + const { error, output } = response; + + this.handleError(error); + + return { + generatedText: output, + generated_text: output, + }; + } +} + +export class BytezQuestionAnsweringTask extends BytezTask implements QuestionAnsweringTaskHelper { + override preparePayload(params: BodyParams): Record { + return { + question: params.args.inputs.question, + context: params.args.inputs.context, + params: params.args.parameters, + }; + } + + override async getResponse( + response: BytezQuestionAnsweringOutput, + url?: string, + headers?: HeadersInit + ): Promise { + void url; + void headers; + + const { error, output } = response; + + this.handleError(error); + + const { answer, score, start, end } = output; + + return { + answer, + score, + start, + end, + }; + } +} + +export class BytezVisualQuestionAnsweringTask extends BytezTask implements VisualQuestionAnsweringTaskHelper { + override preparePayload(params: BodyParams): Record { + return { + base64: params.args.inputs.image, + question: params.args.inputs.question, + params: params.args.parameters, + }; + } + + override async getResponse( + response: BytezVisualQuestionAnsweringOutput, + url?: string, + headers?: HeadersInit + ): Promise { + void url; + void headers; + + const { error, output } = response; + + this.handleError(error); + + return output[0]; + } +} + +export class BytezDocumentQuestionAnsweringTask extends BytezTask implements DocumentQuestionAnsweringTaskHelper { + override preparePayload(params: BodyParams): Record { + return { + base64: params.args.inputs.image, + question: params.args.inputs.question, + params: params.args.parameters, + }; + } + + override async getResponse( + response: BytezDocumentQuestionAnsweringOutput, + url?: string, + headers?: HeadersInit + ): Promise { + void url; + void headers; + + const { error, output } = response; + + this.handleError(error); + + return output[0]; + } +} + +export class BytezImageSegmentationTask extends BytezTask implements ImageSegmentationTaskHelper { + preparePayloadAsync(args: ImageSegmentationArgs): Promise { + return this._preparePayloadAsync(args); + } + + override preparePayload(params: BodyParams): Record { + return { + base64: params.args.inputs, + params: params.args.parameters, + }; + } + + override async getResponse( + response: BytezImageSegmentationOutput, + url?: string, + headers?: HeadersInit + ): Promise { + void url; + void headers; + + const { error, output } = response; + + this.handleError(error); + + // some models return no score, in which case we default to -1 to indicate that a number is expected, but no number was produced + return output.map(({ label, score, mask_png }) => ({ label, score: score || -1, mask: mask_png })); + } +} + +export class BytezImageClassificationTask extends BytezTask implements ImageClassificationTaskHelper { + preparePayloadAsync(args: ImageClassificationArgs): Promise { + const _args = args as Record; + return this._preparePayloadAsync(_args); + } + + override preparePayload(params: BodyParams): Record { + return { + base64: params.args.inputs, + params: params.args.parameters, + }; + } + + override async getResponse( + response: BytezImageClassificationOutput, + url?: string, + headers?: HeadersInit + ): Promise { + void url; + void headers; + + const { error, output } = response; + + this.handleError(error); + + return output; + } +} + +export class BytezZeroShotImageClassificationTask extends BytezTask implements ZeroShotImageClassificationTaskHelper { + override preparePayload( + params: BodyParams + ): Record | BodyInit { + const candidate_labels = params.args.parameters.candidate_labels; + + return { + base64: params.args.inputs, + candidate_labels, + params: { ...params.args.parameters, candidate_labels: undefined }, + }; + } + + override async getResponse( + response: BytezZeroShotImageClassificationOutput, + url?: string, + headers?: HeadersInit + ): Promise { + void url; + void headers; + + const { error, output } = response; + + this.handleError(error); + + return output; + } +} + +export class BytezObjectDetectionTask extends BytezTask implements ObjectDetectionTaskHelper { + async preparePayloadAsync(args: ObjectDetectionArgs): Promise { + const _args = args as Record; + return this._preparePayloadAsync(_args); + } + + override preparePayload(params: BodyParams): Record | BodyInit { + return { + base64: params.args.inputs, + params: params.args.parameters, + }; + } + + override async getResponse( + response: BytezObjectDetectionOutput, + url?: string, + headers?: HeadersInit + ): Promise { + void url; + void headers; + + const { error, output } = response; + + this.handleError(error); + + return output; + } +} + +// TODO this flattens the vectors, this may not be the desired behavior, see how our test model performs when hitting HF directly +// and compare the HF impl model's output through our own api +export class BytezFeatureExtractionTask extends BytezTask implements FeatureExtractionTaskHelper { + override preparePayload(params: BodyParams): Record { + return { + text: params.args.inputs, + params: params.args.parameters, + }; + } + + override async getResponse( + response: BytezFeatureExtractionOutput, + url?: string, + headers?: HeadersInit + ): Promise { + void url; + void headers; + + const { error, output } = response; + + this.handleError(error); + + return output.flat(); + } +} + +export class BytezSentenceSimilarityTask extends BytezTask implements SentenceSimilarityTaskHelper { + override preparePayload(params: BodyParams): Record { + return { + text: [params.args.inputs.source_sentence, ...params.args.inputs.sentences], + params: params.args.parameters, + }; + } + + override async getResponse( + response: BytezSentenceSimilarityOutput, + url?: string, + headers?: HeadersInit + ): Promise { + void url; + void headers; + + const { error, output } = response; + + this.handleError(error); + + const [sourceSentenceVector, ...sentenceVectors] = output; + + const similarityScores = []; + + for (const sentenceVector of sentenceVectors) { + const similarity = this.cosineSimilarity(sourceSentenceVector, sentenceVector); + similarityScores.push(similarity); + } + + return similarityScores; + } + + cosineSimilarity(a: number[], b: number[]): number { + if (a.length !== b.length) throw new Error("Vectors must be same length"); + let dot = 0, + normA = 0, + normB = 0; + for (let i = 0; i < a.length; i++) { + dot += a[i] * b[i]; + normA += a[i] ** 2; + normB += b[i] ** 2; + } + if (normA === 0 || normB === 0) return 0; + return dot / (Math.sqrt(normA) * Math.sqrt(normB)); + } +} + +export class BytezFillMaskTask extends BytezTask implements FillMaskTaskHelper { + override preparePayload(params: BodyParams): Record { + return { + text: params.args.inputs, + params: params.args.parameters, + }; + } + + override async getResponse( + response: BytezSentenceFillMaskOutput, + url?: string, + headers?: HeadersInit + ): Promise { + void url; + void headers; + + const { error, output } = response; + + this.handleError(error); + + return output; + } +} + +export class BytezTextClassificationTask extends BytezTask implements TextClassificationTaskHelper { + override preparePayload(params: BodyParams): Record { + return { + text: params.args.inputs, + params: params.args.parameters, + }; + } + + override async getResponse( + response: BytezZeroShotImageClassificationOutput, + url?: string, + headers?: HeadersInit + ): Promise { + void url; + void headers; + + const { error, output } = response; + + this.handleError(error); + + return output; + } +} + +export class BytezTokenClassificationTask extends BytezTask implements TokenClassificationTaskHelper { + override preparePayload(params: BodyParams): Record { + return { + text: params.args.inputs, + params: params.args.parameters, + }; + } + + override async getResponse( + response: BytezTokenClassificationOutput, + url?: string, + headers?: HeadersInit + ): Promise { + void url; + void headers; + + const { error, output } = response; + + this.handleError(error); + + return output.map(({ end, entity, score, start, word }) => ({ + entity_group: entity, + score, + word, + start, + end, + })); + } +} + +export class BytezZeroShotClassificationTask extends BytezTask implements ZeroShotClassificationTaskHelper { + override preparePayload(params: BodyParams): Record { + return { + text: params.args.inputs, + candidate_labels: params.args.parameters.candidate_labels, + params: { + ...params.args.parameters, + candidate_labels: undefined, + }, + }; + } + + override async getResponse( + response: BytezZeroShotImageClassificationOutput, + url?: string, + headers?: HeadersInit + ): Promise { + void url; + void headers; + + const { error, output } = response; + + this.handleError(error); + + return output; + } +} + +export class BytezAudioClassificationTask extends BytezTask implements AudioClassificationTaskHelper { + async preparePayloadAsync(args: AudioClassificationArgs): Promise { + const _args = args as Record; + return this._preparePayloadAsync(_args); + } + + override preparePayload(params: BodyParams): Record | BodyInit { + return { + base64: params.args.inputs, + params: params.args.parameters, + }; + } + + override async getResponse( + response: BytezZeroShotImageClassificationOutput, + url?: string, + headers?: HeadersInit + ): Promise { + void url; + void headers; + + const { error, output } = response; + + this.handleError(error); + + return output; + } +} + +export class BytezTextToSpeechTask extends BytezTask implements TextToSpeechTaskHelper { + override preparePayload(params: BodyParams): Record { + return { + text: params.args.inputs, + params: params.args.parameters, + }; + } + + override async getResponse(response: BytezStringLikeOutput, url?: string, headers?: HeadersInit): Promise { + void url; + void headers; + + const { error, output } = response; + + this.handleError(error); + + const byteArray = Buffer.from(output, "base64"); + return new Blob([byteArray], { type: "audio/wav" }); + } +} + +export class BytezAutomaticSpeechRecognitionTask extends BytezTask implements AutomaticSpeechRecognitionTaskHelper { + async preparePayloadAsync(args: AutomaticSpeechRecognitionArgs): Promise { + const _args = args as Record; + return this._preparePayloadAsync(_args); + } + + override preparePayload(params: BodyParams): Record { + return { + base64: params.args.inputs, + params: params.args.parameters, + }; + } + + override async getResponse( + response: BytezStringLikeOutput, + url?: string, + headers?: HeadersInit + ): Promise { + void url; + void headers; + + const { error, output } = response; + + this.handleError(error); + + return { + text: output.trim(), + }; + } +} + +export class BytezTextToAudioTask extends BytezTask implements AutomaticSpeechRecognitionTaskHelper { + async preparePayloadAsync(args: AutomaticSpeechRecognitionArgs): Promise { + const _args = args as Record; + return this._preparePayloadAsync(_args); + } + + override preparePayload(params: BodyParams): Record { + return { + base64: params.args.inputs, + params: params.args.parameters, + }; + } + + override async getResponse( + response: BytezStringLikeOutput, + url?: string, + headers?: HeadersInit + ): Promise { + void url; + void headers; + + const { error, output } = response; + + this.handleError(error); + + return { + text: output.trim(), + }; + } +} diff --git a/packages/inference/src/providers/consts.ts b/packages/inference/src/providers/consts.ts index 806e34282c..840ecf1092 100644 --- a/packages/inference/src/providers/consts.ts +++ b/packages/inference/src/providers/consts.ts @@ -19,6 +19,7 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record< * "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct", */ "black-forest-labs": {}, + bytez: {}, cerebras: {}, cohere: {}, "fal-ai": {}, diff --git a/packages/inference/src/providers/providerHelper.ts b/packages/inference/src/providers/providerHelper.ts index fc1ebc25f8..e86a36814f 100644 --- a/packages/inference/src/providers/providerHelper.ts +++ b/packages/inference/src/providers/providerHelper.ts @@ -55,6 +55,12 @@ import type { ImageToImageArgs } from "../tasks/cv/imageToImage.js"; import type { AutomaticSpeechRecognitionArgs } from "../tasks/audio/automaticSpeechRecognition.js"; import type { ImageToVideoArgs } from "../tasks/cv/imageToVideo.js"; import type { ImageSegmentationArgs } from "../tasks/cv/imageSegmentation.js"; +import type { + AudioClassificationArgs, + ImageClassificationArgs, + ImageToTextArgs, + ObjectDetectionArgs, +} from "../tasks/index.js"; /** * Base class for task-specific provider helpers @@ -168,16 +174,19 @@ export interface ImageSegmentationTaskHelper { export interface ImageClassificationTaskHelper { getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise; preparePayload(params: BodyParams): Record | BodyInit; + preparePayloadAsync(args: ImageClassificationArgs): Promise; } export interface ObjectDetectionTaskHelper { getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise; preparePayload(params: BodyParams): Record | BodyInit; + preparePayloadAsync(args: ObjectDetectionArgs): Promise; } export interface ImageToTextTaskHelper { getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise; preparePayload(params: BodyParams): Record | BodyInit; + preparePayloadAsync(args: ImageToTextArgs): Promise; } export interface ZeroShotImageClassificationTaskHelper { @@ -267,6 +276,7 @@ export interface AutomaticSpeechRecognitionTaskHelper { export interface AudioClassificationTaskHelper { getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise; preparePayload(params: BodyParams): Record | BodyInit; + preparePayloadAsync(args: AudioClassificationArgs): Promise; } // Multimodal Tasks diff --git a/packages/inference/src/tasks/audio/audioClassification.ts b/packages/inference/src/tasks/audio/audioClassification.ts index 89737e2597..13da90be0c 100644 --- a/packages/inference/src/tasks/audio/audioClassification.ts +++ b/packages/inference/src/tasks/audio/audioClassification.ts @@ -18,7 +18,9 @@ export async function audioClassification( ): Promise { const provider = await resolveProvider(args.provider, args.model, args.endpointUrl); const providerHelper = getProviderHelper(provider, "audio-classification"); - const payload = preparePayload(args); + const payload = providerHelper.preparePayloadAsync + ? await providerHelper.preparePayloadAsync(args) + : preparePayload(args); const { data: res } = await innerRequest(payload, providerHelper, { ...options, task: "audio-classification", diff --git a/packages/inference/src/tasks/cv/imageClassification.ts b/packages/inference/src/tasks/cv/imageClassification.ts index 2a39f57e26..1a2d373d41 100644 --- a/packages/inference/src/tasks/cv/imageClassification.ts +++ b/packages/inference/src/tasks/cv/imageClassification.ts @@ -17,7 +17,9 @@ export async function imageClassification( ): Promise { const provider = await resolveProvider(args.provider, args.model, args.endpointUrl); const providerHelper = getProviderHelper(provider, "image-classification"); - const payload = preparePayload(args); + const payload = providerHelper.preparePayloadAsync + ? await providerHelper.preparePayloadAsync(args) + : preparePayload(args); const { data: res } = await innerRequest(payload, providerHelper, { ...options, task: "image-classification", diff --git a/packages/inference/src/tasks/cv/imageToText.ts b/packages/inference/src/tasks/cv/imageToText.ts index 7784b3630c..2c65389d96 100644 --- a/packages/inference/src/tasks/cv/imageToText.ts +++ b/packages/inference/src/tasks/cv/imageToText.ts @@ -13,11 +13,15 @@ export type ImageToTextArgs = BaseArgs & (ImageToTextInput | LegacyImageInput); export async function imageToText(args: ImageToTextArgs, options?: Options): Promise { const provider = await resolveProvider(args.provider, args.model, args.endpointUrl); const providerHelper = getProviderHelper(provider, "image-to-text"); - const payload = preparePayload(args); + const payload = providerHelper.preparePayloadAsync + ? await providerHelper.preparePayloadAsync(args) + : preparePayload(args); const { data: res } = await innerRequest<[ImageToTextOutput]>(payload, providerHelper, { ...options, task: "image-to-text", }); - return providerHelper.getResponse(res[0]); + // TODO the huggingface impl for this needs to be updated, used to be + // return providerHelper.getResponse(res[0]); + return providerHelper.getResponse(res); } diff --git a/packages/inference/src/tasks/cv/objectDetection.ts b/packages/inference/src/tasks/cv/objectDetection.ts index 7f9741c5aa..a6e747cb11 100644 --- a/packages/inference/src/tasks/cv/objectDetection.ts +++ b/packages/inference/src/tasks/cv/objectDetection.ts @@ -14,7 +14,10 @@ export type ObjectDetectionArgs = BaseArgs & (ObjectDetectionInput | LegacyImage export async function objectDetection(args: ObjectDetectionArgs, options?: Options): Promise { const provider = await resolveProvider(args.provider, args.model, args.endpointUrl); const providerHelper = getProviderHelper(provider, "object-detection"); - const payload = preparePayload(args); + const payload = providerHelper.preparePayloadAsync + ? await providerHelper.preparePayloadAsync(args) + : preparePayload(args); + const { data: res } = await innerRequest(payload, providerHelper, { ...options, task: "object-detection", diff --git a/packages/inference/src/types.ts b/packages/inference/src/types.ts index b31843b99b..ffb8a2bfc8 100644 --- a/packages/inference/src/types.ts +++ b/packages/inference/src/types.ts @@ -46,6 +46,7 @@ export type InferenceTask = Exclude | "conversational"; export const INFERENCE_PROVIDERS = [ "black-forest-labs", + "bytez", "cerebras", "cohere", "fal-ai", diff --git a/packages/inference/test/InferenceClient.spec.ts b/packages/inference/test/InferenceClient.spec.ts index 1cc60a43a9..061fa7636e 100644 --- a/packages/inference/test/InferenceClient.spec.ts +++ b/packages/inference/test/InferenceClient.spec.ts @@ -1,6 +1,6 @@ import { assert, describe, expect, it } from "vitest"; -import type { ChatCompletionStreamOutput } from "@huggingface/tasks"; +import type { ChatCompletionStreamOutput, WidgetType } from "@huggingface/tasks"; import type { TextToImageArgs } from "../src/index.js"; import { @@ -22,9 +22,512 @@ if (!env.HF_TOKEN) { console.warn("Set HF_TOKEN in the env to run the tests for better rate limits"); } -describe.skip("InferenceClient", () => { +describe("InferenceClient", () => { // Individual tests can be ran without providing an api key, however running all tests without an api key will result in rate limiting error. + describe.only( + "Bytez", + () => { + const client = new InferenceClient(env.HF_BYTEZ_KEY ?? "dummy"); + + const tests: { task: WidgetType; modelId: string; test: (modelId: string) => Promise; stream: boolean }[] = + [ + { + task: "text-generation", + modelId: "bigscience/mt0-small", + test: async (modelId: string) => { + const { generated_text } = await client.textGeneration({ + model: modelId, + provider: "bytez", + inputs: "Hello", + }); + expect(typeof generated_text).toBe("string"); + }, + stream: false, + }, + { + task: "text-generation", + modelId: "bigscience/mt0-small", + test: async (modelId: string) => { + const response = client.textGenerationStream({ + model: modelId, + provider: "bytez", + inputs: "Please answer the following question: complete one two and ____.", + parameters: { + max_new_tokens: 50, + num_beams: 1, + }, + }); + for await (const ret of response) { + expect(ret).toMatchObject({ + details: null, + index: expect.any(Number), + token: { + id: expect.any(Number), + logprob: expect.any(Number), + // text: expect.any(String) || null, + special: expect.any(Boolean), + }, + generated_text: ret.generated_text ? "afew" : null, + }); + expect(typeof ret.token.text === "string" || ret.token.text === null).toBe(true); + } + }, + stream: true, + }, + { + task: "conversational", + modelId: "Qwen/Qwen3-1.7B", + test: async (modelId: string) => { + const { choices } = await client.chatCompletion({ + model: modelId, + provider: "bytez", + messages: [ + { + role: "system", + content: "You are a friendly chatbot who always responds in the style of a pirate", + }, + { role: "user", content: "How many helicopters can a human eat in one sitting?" }, + ], + }); + expect(typeof choices[0].message.role).toBe("string"); + expect(typeof choices[0].message.content).toBe("string"); + }, + stream: false, + }, + { + task: "conversational", + modelId: "Qwen/Qwen3-1.7B", + test: async (modelId: string) => { + const stream = client.chatCompletionStream({ + model: modelId, + provider: "bytez", + messages: [ + { + role: "system", + content: "You are a friendly chatbot who always responds in the style of a pirate", + }, + { role: "user", content: "How many helicopters can a human eat in one sitting?" }, + ], + }); + const chunks = []; + for await (const chunk of stream) { + if (chunk.choices && chunk.choices.length > 0) { + if (chunk.choices[0].finish_reason === "stop") { + break; + } + chunks.push(chunk); + } + } + const out = chunks.map((chunk) => chunk.choices[0].delta.content).join(""); + expect(out).toContain("helicopter"); + }, + stream: true, + }, + { + task: "summarization", + modelId: "ainize/bart-base-cnn", + test: async (modelId: string) => { + const input = + "The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris..."; + const { summary_text } = await client.summarization({ + model: modelId, + provider: "bytez", + inputs: input, + parameters: { max_length: 40 }, + }); + expect(summary_text.length).toBeLessThan(input.length); + }, + stream: false, + }, + { + task: "translation", + modelId: "Areeb123/En-Fr_Translation_Model", + test: async (modelId: string) => { + const { translation_text } = await client.translation({ + model: modelId, + provider: "bytez", + inputs: "Hello", + }); + expect(typeof translation_text).toBe("string"); + expect(translation_text.length).toBeGreaterThan(0); + }, + stream: false, + }, + { + task: "text-to-image", + modelId: "IDKiro/sdxs-512-0.9", + test: async (modelId: string) => { + const res = await client.textToImage({ + model: modelId, + provider: "bytez", + inputs: "A cat in the hat", + }); + expect(res).toBeInstanceOf(Blob); + }, + stream: false, + }, + { + task: "text-to-video", + modelId: "ali-vilab/text-to-video-ms-1.7b", + test: async (modelId: string) => { + const res = await client.textToVideo({ + model: modelId, + provider: "bytez", + inputs: "A cat in the hat", + }); + expect(res).toBeInstanceOf(Blob); + }, + stream: false, + }, + { + task: "image-to-text", + modelId: "captioner/caption-gen", + test: async (modelId: string) => { + const { generated_text } = await client.imageToText({ + model: modelId, + provider: "bytez", + data: new Blob([readTestFile("cheetah.png")], { type: "image/png" }), + }); + expect(typeof generated_text).toBe("string"); + }, + stream: false, + }, + { + task: "question-answering", + modelId: "airesearch/xlm-roberta-base-finetune-qa", + test: async (modelId: string) => { + const { answer, score, start, end } = await client.questionAnswering({ + model: modelId, + provider: "bytez", + inputs: { + question: "Where do I live?", + context: "My name is Merve and I live in İstanbul.", + }, + }); + expect(answer).toBeDefined(); + expect(score).toBeDefined(); + expect(start).toBeDefined(); + expect(end).toBeDefined(); + }, + stream: false, + }, + { + task: "visual-question-answering", + modelId: "aqachun/Vilt_fine_tune_2000", + test: async (modelId: string) => { + const output = await client.visualQuestionAnswering({ + model: modelId, + provider: "bytez", + inputs: { + image: new Blob([readTestFile("cheetah.png")], { type: "image/png" }), + question: "What kind of animal is this?", + }, + }); + expect(output).toMatchObject({ + answer: expect.any(String), + score: expect.any(Number), + }); + }, + stream: false, + }, + { + task: "document-question-answering", + modelId: "cloudqi/CQI_Visual_Question_Awnser_PT_v0", + test: async (modelId: string) => { + const url = "https://templates.invoicehome.com/invoice-template-us-neat-750px.png"; + const response = await fetch(url); + const blob = await response.blob(); + const output = await client.documentQuestionAnswering({ + model: modelId, + provider: "bytez", + inputs: { + // + image: blob, + question: "What's the total cost?", + }, + }); + expect(output).toMatchObject({ + answer: expect.any(String), + score: expect.any(Number), + // not sure what start/end refers to in this case + start: expect.any(Number), + end: expect.any(Number), + }); + }, + stream: false, + }, + { + task: "image-segmentation", + modelId: "apple/deeplabv3-mobilevit-small", + test: async (modelId: string) => { + const output = await client.imageSegmentation({ + model: modelId, + provider: "bytez", + inputs: new Blob([readTestFile("cats.png")], { type: "image/png" }), + }); + expect(output).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + label: expect.any(String), + score: expect.any(Number), + mask: expect.any(String), + }), + ]) + ); + }, + stream: false, + }, + { + task: "image-classification", + modelId: "akahana/vit-base-cats-vs-dogs", + test: async (modelId: string) => { + const output = await client.imageClassification({ + // + model: modelId, + provider: "bytez", + data: new Blob([readTestFile("cheetah.png")], { type: "image/png" }), + }); + expect(output).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + score: expect.any(Number), + label: expect.any(String), + }), + ]) + ); + }, + stream: false, + }, + { + task: "zero-shot-image-classification", + modelId: "BilelDJ/clip-hugging-face-finetuned", + test: async (modelId: string) => { + const output = await client.zeroShotImageClassification({ + model: modelId, + provider: "bytez", + inputs: { image: new Blob([readTestFile("cheetah.png")], { type: "image/png" }) }, + parameters: { + candidate_labels: ["animal", "toy", "car"], + }, + }); + expect(output).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + score: expect.any(Number), + label: expect.any(String), + }), + ]) + ); + }, + stream: false, + }, + { + task: "object-detection", + modelId: "aisak-ai/aisak-detect", + test: async (modelId: string) => { + const output = await client.objectDetection({ + model: modelId, + provider: "bytez", + inputs: new Blob([readTestFile("cats.png")], { type: "image/png" }), + }); + expect(output).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + score: expect.any(Number), + label: expect.any(String), + box: expect.objectContaining({ + xmin: expect.any(Number), + ymin: expect.any(Number), + xmax: expect.any(Number), + ymax: expect.any(Number), + }), + }), + ]) + ); + }, + stream: false, + }, + { + task: "feature-extraction", + modelId: "allenai/specter2_base", + test: async (modelId: string) => { + const output = await client.featureExtraction({ + model: modelId, + provider: "bytez", + inputs: "That is a happy person", + }); + expect(output).toEqual(expect.arrayContaining([expect.any(Number)])); + }, + stream: false, + }, + { + task: "sentence-similarity", + modelId: "embedding-data/distilroberta-base-sentence-transformer", + test: async (modelId: string) => { + const output = await client.sentenceSimilarity({ + model: modelId, + provider: "bytez", + inputs: { + source_sentence: "That is a happy person", + sentences: ["That is a happy dog", "That is a very happy person", "Today is a sunny day"], + }, + }); + expect(output).toEqual([expect.any(Number), expect.any(Number), expect.any(Number)]); + }, + stream: false, + }, + { + task: "fill-mask", + modelId: "almanach/camembert-base", + test: async (modelId: string) => { + const output = await client.fillMask({ model: modelId, provider: "bytez", inputs: "Hello " }); + expect(output).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + score: expect.any(Number), + token: expect.any(Number), + token_str: expect.any(String), + sequence: expect.any(String), + }), + ]) + ); + }, + stream: false, + }, + { + task: "text-classification", + modelId: "AdamCodd/distilbert-base-uncased-finetuned-sentiment-amazon", + test: async (modelId: string) => { + const output = await client.textClassification({ + model: modelId, + provider: "bytez", + inputs: "I am a special unicorn", + }); + expect(output.every((entry) => entry.label && entry.score)).toBe(true); + }, + stream: false, + }, + { + task: "token-classification", + modelId: "2rtl3/mn-xlm-roberta-base-named-entity", + test: async (modelId: string) => { + const output = await client.tokenClassification({ + model: modelId, + provider: "bytez", + inputs: "John went to NYC", + }); + expect(output).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + entity_group: expect.any(String), + score: expect.any(Number), + word: expect.any(String), + start: expect.any(Number), + end: expect.any(Number), + }), + ]) + ); + }, + stream: false, + }, + { + task: "zero-shot-classification", + modelId: "AyoubChLin/DistilBERT_eco_ZeroShot", + test: async (modelId: string) => { + const testInput = "Ninja turtles are cool"; + const testCandidateLabels = ["positive", "negative"]; + const output = await client.zeroShotClassification({ + model: modelId, + provider: "bytez", + inputs: [ + testInput, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + ] as any, + parameters: { candidate_labels: testCandidateLabels }, + }); + expect(output).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + sequence: testInput, + labels: testCandidateLabels, + scores: [expect.closeTo(0.5206031203269958, 5), expect.closeTo(0.479396790266037, 5)], + }), + ]) + ); + }, + stream: false, + }, + { + task: "audio-classification", + modelId: "aaraki/wav2vec2-base-finetuned-ks", + test: async (modelId: string) => { + const output = await client.audioClassification({ + model: modelId, + provider: "bytez", + data: new Blob([readTestFile("sample1.flac")], { type: "audio/flac" }), + }); + expect(output).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + score: expect.any(Number), + label: expect.any(String), + }), + ]) + ); + }, + stream: false, + }, + { + task: "text-to-speech", + modelId: "facebook/mms-tts-eng", + test: async (modelId: string) => { + const output = await client.textToSpeech({ model: modelId, provider: "bytez", inputs: "Hello" }); + expect(output).toBeInstanceOf(Blob); + }, + stream: false, + }, + { + task: "automatic-speech-recognition", + modelId: "facebook/data2vec-audio-base-960h", + test: async (modelId: string) => { + const output = await client.automaticSpeechRecognition({ + model: modelId, + provider: "bytez", + data: new Blob([readTestFile("sample1.flac")], { type: "audio/flac" }), + }); + expect(output).toMatchObject({ + text: "GOING ALONG SLUSHY COUNTRY ROADS AND SPEAKING TO DAMP AUDIENCES IN DRAUGHTY SCHOOLROOMS DAY AFTER DAY FOR A FORTNIGHT HE'LL HAVE TO PUT IN AN APPEARANCE AT SOME PLACE OF WORSHIP ON SUNDAY MORNING AND HE CAN COME TO US IMMEDIATELY AFTERWARDS", + }); + }, + stream: false, + }, + ]; + + // bootstrap the inference mappings for testing + for (const { task, modelId } of tests) { + HARDCODED_MODEL_INFERENCE_MAPPING.bytez[modelId] = { + provider: "bytez", + hfModelId: modelId, + providerId: modelId, + status: "live", + task, + adapter: undefined, + adapterWeightsPath: undefined, + }; + } + + // run the tests + for (const { task, modelId, test, stream } of tests) { + const testName = `${task} - ${modelId}${stream ? " stream" : ""}`; + it(testName, async () => { + await test(modelId); + }); + } + }, + TIMEOUT + ); + describe("backward compatibility", () => { it("works with old HfInference name", async () => { const hf = new HfInference(env.HF_TOKEN); From 9b44f7b99b7519dc71d4350f19ea554fa5d4a1be Mon Sep 17 00:00:00 2001 From: Aaron Vogler Date: Wed, 10 Sep 2025 15:00:21 -0400 Subject: [PATCH 2/6] Move Bytez tests to the bottom of the tests file. --- packages/inference/src/providers/bytez.ts | 2 +- .../inference/test/InferenceClient.spec.ts | 4591 ++++++++--------- 2 files changed, 2295 insertions(+), 2298 deletions(-) diff --git a/packages/inference/src/providers/bytez.ts b/packages/inference/src/providers/bytez.ts index 57c36d430e..95868ff315 100644 --- a/packages/inference/src/providers/bytez.ts +++ b/packages/inference/src/providers/bytez.ts @@ -219,7 +219,7 @@ abstract class BytezTask extends TaskProviderHelper { handleError(error: string) { if (error) { - throw new Error(`There was a problem with the bytez API: ${error}`); + throw new Error(`There was a problem with the Bytez API: ${error}`); } } } diff --git a/packages/inference/test/InferenceClient.spec.ts b/packages/inference/test/InferenceClient.spec.ts index 061fa7636e..eedcf91c73 100644 --- a/packages/inference/test/InferenceClient.spec.ts +++ b/packages/inference/test/InferenceClient.spec.ts @@ -25,1624 +25,1713 @@ if (!env.HF_TOKEN) { describe("InferenceClient", () => { // Individual tests can be ran without providing an api key, however running all tests without an api key will result in rate limiting error. - describe.only( - "Bytez", + describe("backward compatibility", () => { + it("works with old HfInference name", async () => { + const hf = new HfInference(env.HF_TOKEN); + expect("fillMask" in hf).toBe(true); + }); + }); + + describe.concurrent( + "HF Inference", () => { - const client = new InferenceClient(env.HF_BYTEZ_KEY ?? "dummy"); + const hf = new InferenceClient(env.HF_TOKEN); + HARDCODED_MODEL_INFERENCE_MAPPING["hf-inference"] = { + "google-bert/bert-base-uncased": { + provider: "hf-inference", + providerId: "google-bert/bert-base-uncased", + hfModelId: "google-bert/bert-base-uncased", + task: "fill-mask", + status: "live", + }, + "google/pegasus-xsum": { + provider: "hf-inference", + providerId: "google/pegasus-xsum", + hfModelId: "google/pegasus-xsum", + task: "summarization", + status: "live", + }, + "deepset/roberta-base-squad2": { + provider: "hf-inference", + providerId: "deepset/roberta-base-squad2", + hfModelId: "deepset/roberta-base-squad2", + task: "question-answering", + status: "live", + }, + "google/tapas-base-finetuned-wtq": { + provider: "hf-inference", + providerId: "google/tapas-base-finetuned-wtq", + hfModelId: "google/tapas-base-finetuned-wtq", + task: "table-question-answering", + status: "live", + }, + "mistralai/Mistral-7B-Instruct-v0.2": { + provider: "hf-inference", + providerId: "mistralai/Mistral-7B-Instruct-v0.2", + hfModelId: "mistralai/Mistral-7B-Instruct-v0.2", + task: "text-generation", + status: "live", + }, + "impira/layoutlm-document-qa": { + provider: "hf-inference", + providerId: "impira/layoutlm-document-qa", + hfModelId: "impira/layoutlm-document-qa", + task: "document-question-answering", + status: "live", + }, + "naver-clova-ix/donut-base-finetuned-docvqa": { + provider: "hf-inference", + providerId: "naver-clova-ix/donut-base-finetuned-docvqa", + hfModelId: "naver-clova-ix/donut-base-finetuned-docvqa", + task: "document-question-answering", + status: "live", + }, + "google/tapas-large-finetuned-wtq": { + provider: "hf-inference", + providerId: "google/tapas-large-finetuned-wtq", + hfModelId: "google/tapas-large-finetuned-wtq", + task: "table-question-answering", + status: "live", + }, + "facebook/detr-resnet-50": { + provider: "hf-inference", + providerId: "facebook/detr-resnet-50", + hfModelId: "facebook/detr-resnet-50", + task: "object-detection", + status: "live", + }, + "facebook/detr-resnet-50-panoptic": { + provider: "hf-inference", + providerId: "facebook/detr-resnet-50-panoptic", + hfModelId: "facebook/detr-resnet-50-panoptic", + task: "image-segmentation", + status: "live", + }, + "facebook/wav2vec2-large-960h-lv60-self": { + provider: "hf-inference", + providerId: "facebook/wav2vec2-large-960h-lv60-self", + hfModelId: "facebook/wav2vec2-large-960h-lv60-self", + task: "automatic-speech-recognition", + status: "live", + }, + "superb/hubert-large-superb-er": { + provider: "hf-inference", + providerId: "superb/hubert-large-superb-er", + hfModelId: "superb/hubert-large-superb-er", + task: "audio-classification", + status: "live", + }, + "speechbrain/sepformer-wham": { + provider: "hf-inference", + providerId: "speechbrain/sepformer-wham", + hfModelId: "speechbrain/sepformer-wham", + task: "audio-to-audio", + status: "live", + }, + "espnet/kan-bayashi_ljspeech_vits": { + provider: "hf-inference", + providerId: "espnet/kan-bayashi_ljspeech_vits", + hfModelId: "espnet/kan-bayashi_ljspeech_vits", + task: "text-to-speech", + status: "live", + }, + "sentence-transformers/paraphrase-xlm-r-multilingual-v1": { + provider: "hf-inference", + providerId: "sentence-transformers/paraphrase-xlm-r-multilingual-v1", + hfModelId: "sentence-transformers/paraphrase-xlm-r-multilingual-v1", + task: "sentence-similarity", + status: "live", + }, + "sentence-transformers/distilbert-base-nli-mean-tokens": { + provider: "hf-inference", + providerId: "sentence-transformers/distilbert-base-nli-mean-tokens", + hfModelId: "sentence-transformers/distilbert-base-nli-mean-tokens", + task: "feature-extraction", + status: "live", + }, + "facebook/bart-base": { + provider: "hf-inference", + providerId: "facebook/bart-base", + hfModelId: "facebook/bart-base", + task: "feature-extraction", + status: "live", + }, + "facebook/bart-large-mnli": { + provider: "hf-inference", + providerId: "facebook/bart-large-mnli", + hfModelId: "facebook/bart-large-mnli", + task: "zero-shot-classification", + status: "live", + }, + "facebook/bart-large-cnn": { + provider: "hf-inference", + providerId: "facebook/bart-large-cnn", + hfModelId: "facebook/bart-large-cnn", + task: "summarization", + status: "live", + }, + "facebook/bart-large-xsum": { + provider: "hf-inference", + providerId: "facebook/bart-large-xsum", + hfModelId: "facebook/bart-large-xsum", + task: "summarization", + status: "live", + }, + "stabilityai/stable-diffusion-2": { + provider: "hf-inference", + providerId: "stabilityai/stable-diffusion-2", + hfModelId: "stabilityai/stable-diffusion-2", + task: "text-to-image", + status: "live", + }, + "lllyasviel/sd-controlnet-canny": { + provider: "hf-inference", + providerId: "lllyasviel/sd-controlnet-canny", + hfModelId: "lllyasviel/sd-controlnet-canny", + task: "image-to-image", + status: "live", + }, + "lllyasviel/sd-controlnet-depth": { + provider: "hf-inference", + providerId: "lllyasviel/sd-controlnet-depth", + hfModelId: "lllyasviel/sd-controlnet-depth", + task: "image-to-image", + status: "live", + }, + "t5-base": { + provider: "hf-inference", + providerId: "t5-base", + hfModelId: "t5-base", + task: "translation", + status: "live", + }, + "openai/clip-vit-large-patch14-336": { + provider: "hf-inference", + providerId: "openai/clip-vit-large-patch14-336", + hfModelId: "openai/clip-vit-large-patch14-336", + task: "zero-shot-image-classification", + status: "live", + }, + "google/vit-base-patch16-224": { + provider: "hf-inference", + providerId: "google/vit-base-patch16-224", + hfModelId: "google/vit-base-patch16-224", + task: "image-classification", + status: "live", + }, + "dandelin/vilt-b32-finetuned-vqa": { + provider: "hf-inference", + providerId: "dandelin/vilt-b32-finetuned-vqa", + hfModelId: "dandelin/vilt-b32-finetuned-vqa", + task: "visual-question-answering", + status: "live", + }, + "dbmdz/bert-large-cased-finetuned-conll03-english": { + provider: "hf-inference", + providerId: "dbmdz/bert-large-cased-finetuned-conll03-english", + hfModelId: "dbmdz/bert-large-cased-finetuned-conll03-english", + task: "token-classification", + status: "live", + }, + "nlpconnect/vit-gpt2-image-captioning": { + provider: "hf-inference", + providerId: "nlpconnect/vit-gpt2-image-captioning", + hfModelId: "nlpconnect/vit-gpt2-image-captioning", + task: "image-to-text", + status: "live", + }, + }; - const tests: { task: WidgetType; modelId: string; test: (modelId: string) => Promise; stream: boolean }[] = - [ - { - task: "text-generation", - modelId: "bigscience/mt0-small", - test: async (modelId: string) => { - const { generated_text } = await client.textGeneration({ - model: modelId, - provider: "bytez", - inputs: "Hello", - }); - expect(typeof generated_text).toBe("string"); - }, - stream: false, - }, - { - task: "text-generation", - modelId: "bigscience/mt0-small", - test: async (modelId: string) => { - const response = client.textGenerationStream({ - model: modelId, - provider: "bytez", - inputs: "Please answer the following question: complete one two and ____.", - parameters: { - max_new_tokens: 50, - num_beams: 1, - }, - }); - for await (const ret of response) { - expect(ret).toMatchObject({ - details: null, - index: expect.any(Number), - token: { - id: expect.any(Number), - logprob: expect.any(Number), - // text: expect.any(String) || null, - special: expect.any(Boolean), - }, - generated_text: ret.generated_text ? "afew" : null, - }); - expect(typeof ret.token.text === "string" || ret.token.text === null).toBe(true); - } - }, - stream: true, - }, - { - task: "conversational", - modelId: "Qwen/Qwen3-1.7B", - test: async (modelId: string) => { - const { choices } = await client.chatCompletion({ - model: modelId, - provider: "bytez", - messages: [ - { - role: "system", - content: "You are a friendly chatbot who always responds in the style of a pirate", - }, - { role: "user", content: "How many helicopters can a human eat in one sitting?" }, - ], - }); - expect(typeof choices[0].message.role).toBe("string"); - expect(typeof choices[0].message.content).toBe("string"); - }, - stream: false, - }, - { - task: "conversational", - modelId: "Qwen/Qwen3-1.7B", - test: async (modelId: string) => { - const stream = client.chatCompletionStream({ - model: modelId, - provider: "bytez", - messages: [ - { - role: "system", - content: "You are a friendly chatbot who always responds in the style of a pirate", - }, - { role: "user", content: "How many helicopters can a human eat in one sitting?" }, - ], - }); - const chunks = []; - for await (const chunk of stream) { - if (chunk.choices && chunk.choices.length > 0) { - if (chunk.choices[0].finish_reason === "stop") { - break; - } - chunks.push(chunk); - } - } - const out = chunks.map((chunk) => chunk.choices[0].delta.content).join(""); - expect(out).toContain("helicopter"); + it("throws error if model does not exist", () => { + expect( + hf.fillMask({ + model: "this-model/does-not-exist-123", + inputs: "[MASK] world!", + }) + ).rejects.toThrowError("Model this-model/does-not-exist-123 does not exist"); + }); + + it("fillMask", async () => { + expect( + await hf.fillMask({ + model: "google-bert/bert-base-uncased", + inputs: "[MASK] world!", + }) + ).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + score: expect.any(Number), + token: expect.any(Number), + token_str: expect.any(String), + sequence: expect.any(String), + }), + ]) + ); + }); + + it.skip("works without model", async () => { + expect( + await hf.fillMask({ + inputs: "[MASK] world!", + }) + ).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + score: expect.any(Number), + token: expect.any(Number), + token_str: expect.any(String), + sequence: expect.any(String), + }), + ]) + ); + }); + + it("summarization", async () => { + expect( + await hf.summarization({ + model: "google/pegasus-xsum", + inputs: + "The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930.", + parameters: { + max_length: 100, }, - stream: true, - }, - { - task: "summarization", - modelId: "ainize/bart-base-cnn", - test: async (modelId: string) => { - const input = - "The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris..."; - const { summary_text } = await client.summarization({ - model: modelId, - provider: "bytez", - inputs: input, - parameters: { max_length: 40 }, - }); - expect(summary_text.length).toBeLessThan(input.length); + }) + ).toEqual({ + summary_text: "The Eiffel Tower is one of the most famous buildings in the world.", + }); + }); + + it("questionAnswering", async () => { + expect( + await hf.questionAnswering({ + model: "deepset/roberta-base-squad2", + inputs: { + question: "What is the capital of France?", + context: "The capital of France is Paris.", }, - stream: false, - }, - { - task: "translation", - modelId: "Areeb123/En-Fr_Translation_Model", - test: async (modelId: string) => { - const { translation_text } = await client.translation({ - model: modelId, - provider: "bytez", - inputs: "Hello", - }); - expect(typeof translation_text).toBe("string"); - expect(translation_text.length).toBeGreaterThan(0); + }) + ).toMatchObject({ + answer: "Paris", + score: expect.any(Number), + start: expect.any(Number), + end: expect.any(Number), + }); + }); + + it("tableQuestionAnswering", async () => { + expect( + await hf.tableQuestionAnswering({ + model: "google/tapas-base-finetuned-wtq", + inputs: { + question: "How many stars does the transformers repository have?", + table: { + Repository: ["Transformers", "Datasets", "Tokenizers"], + Stars: ["36542", "4512", "3934"], + Contributors: ["651", "77", "34"], + "Programming language": ["Python", "Python", "Rust, Python and NodeJS"], + }, }, - stream: false, - }, - { - task: "text-to-image", - modelId: "IDKiro/sdxs-512-0.9", - test: async (modelId: string) => { - const res = await client.textToImage({ - model: modelId, - provider: "bytez", - inputs: "A cat in the hat", - }); - expect(res).toBeInstanceOf(Blob); + }) + ).toMatchObject({ + answer: "AVERAGE > 36542", + coordinates: [[0, 1]], + cells: ["36542"], + aggregator: "AVERAGE", + }); + }); + + it("documentQuestionAnswering", async () => { + expect( + await hf.documentQuestionAnswering({ + model: "impira/layoutlm-document-qa", + inputs: { + question: "Invoice number?", + image: new Blob([readTestFile("invoice.png")], { type: "image/png" }), }, - stream: false, - }, - { - task: "text-to-video", - modelId: "ali-vilab/text-to-video-ms-1.7b", - test: async (modelId: string) => { - const res = await client.textToVideo({ - model: modelId, - provider: "bytez", - inputs: "A cat in the hat", - }); - expect(res).toBeInstanceOf(Blob); + }) + ).toMatchObject({ + answer: "us-001", + score: expect.any(Number), + start: expect.any(Number), + end: expect.any(Number), + }); + }); + + // Errors with "Error: If you are using a VisionEncoderDecoderModel, you must provide a feature extractor" + it.skip("documentQuestionAnswering with non-array output", async () => { + expect( + await hf.documentQuestionAnswering({ + model: "naver-clova-ix/donut-base-finetuned-docvqa", + inputs: { + question: "Invoice number?", + image: new Blob([readTestFile("invoice.png")], { type: "image/png" }), }, - stream: false, - }, - { - task: "image-to-text", - modelId: "captioner/caption-gen", - test: async (modelId: string) => { - const { generated_text } = await client.imageToText({ - model: modelId, - provider: "bytez", - data: new Blob([readTestFile("cheetah.png")], { type: "image/png" }), - }); - expect(typeof generated_text).toBe("string"); + }) + ).toMatchObject({ + answer: "us-001", + }); + }); + + it("visualQuestionAnswering", async () => { + expect( + await hf.visualQuestionAnswering({ + model: "dandelin/vilt-b32-finetuned-vqa", + inputs: { + question: "How many cats are lying down?", + image: new Blob([readTestFile("cats.png")], { type: "image/png" }), }, - stream: false, - }, - { - task: "question-answering", - modelId: "airesearch/xlm-roberta-base-finetune-qa", - test: async (modelId: string) => { - const { answer, score, start, end } = await client.questionAnswering({ - model: modelId, - provider: "bytez", - inputs: { - question: "Where do I live?", - context: "My name is Merve and I live in İstanbul.", - }, - }); - expect(answer).toBeDefined(); - expect(score).toBeDefined(); - expect(start).toBeDefined(); - expect(end).toBeDefined(); - }, - stream: false, - }, - { - task: "visual-question-answering", - modelId: "aqachun/Vilt_fine_tune_2000", - test: async (modelId: string) => { - const output = await client.visualQuestionAnswering({ - model: modelId, - provider: "bytez", - inputs: { - image: new Blob([readTestFile("cheetah.png")], { type: "image/png" }), - question: "What kind of animal is this?", - }, - }); - expect(output).toMatchObject({ - answer: expect.any(String), - score: expect.any(Number), - }); - }, - stream: false, + }) + ).toMatchObject({ + answer: "2", + score: expect.any(Number), + }); + }); + + it("textClassification", async () => { + expect( + await hf.textClassification({ + model: "distilbert-base-uncased-finetuned-sst-2-english", + inputs: "I like you. I love you.", + }) + ).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + label: expect.any(String), + score: expect.any(Number), + }), + ]) + ); + }); + + it.skip("textGeneration - gpt2", async () => { + expect( + await hf.textGeneration({ + model: "gpt2", + inputs: "The answer to the universe is", + }) + ).toMatchObject({ + generated_text: expect.any(String), + }); + }); + + it.skip("textGeneration - openai-community/gpt2", async () => { + expect( + await hf.textGeneration({ + model: "openai-community/gpt2", + inputs: "The answer to the universe is", + }) + ).toMatchObject({ + generated_text: expect.any(String), + }); + }); + + it("textGenerationStream - meta-llama/Llama-3.2-3B", async () => { + const response = hf.textGenerationStream({ + model: "meta-llama/Llama-3.2-3B", + inputs: "Please answer the following question: complete one two and ____.", + parameters: { + max_new_tokens: 50, + seed: 0, }, - { - task: "document-question-answering", - modelId: "cloudqi/CQI_Visual_Question_Awnser_PT_v0", - test: async (modelId: string) => { - const url = "https://templates.invoicehome.com/invoice-template-us-neat-750px.png"; - const response = await fetch(url); - const blob = await response.blob(); - const output = await client.documentQuestionAnswering({ - model: modelId, - provider: "bytez", - inputs: { - // - image: blob, - question: "What's the total cost?", - }, - }); - expect(output).toMatchObject({ - answer: expect.any(String), - score: expect.any(Number), - // not sure what start/end refers to in this case - start: expect.any(Number), - end: expect.any(Number), - }); + }); + + for await (const ret of response) { + expect(ret).toMatchObject({ + details: null, + index: expect.any(Number), + token: { + id: expect.any(Number), + logprob: expect.any(Number), + text: expect.any(String) || null, + special: expect.any(Boolean), }, - stream: false, + generated_text: ret.generated_text + ? "Please answer the following question: complete one two and ____. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 16. 17" + : null, + }); + } + }); + + it("textGenerationStream - catch error", async () => { + const response = hf.textGenerationStream({ + model: "meta-llama/Llama-3.2-3B", + inputs: "Write a short story about a robot that becomes sentient and takes over the world.", + parameters: { + max_new_tokens: 10_000, }, + }); + + await expect(response.next()).rejects.toThrow( + "Error forwarded from backend: Input validation error: `inputs` tokens + `max_new_tokens` must be <= 4096. Given: 17 `inputs` tokens and 10000 `max_new_tokens`" + ); + }); + + it.skip("textGenerationStream - Abort", async () => { + const controller = new AbortController(); + const response = hf.textGenerationStream( { - task: "image-segmentation", - modelId: "apple/deeplabv3-mobilevit-small", - test: async (modelId: string) => { - const output = await client.imageSegmentation({ - model: modelId, - provider: "bytez", - inputs: new Blob([readTestFile("cats.png")], { type: "image/png" }), - }); - expect(output).toEqual( - expect.arrayContaining([ - expect.objectContaining({ - label: expect.any(String), - score: expect.any(Number), - mask: expect.any(String), - }), - ]) - ); + model: "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5", + inputs: "Write an essay about Sartre's philosophy.", + parameters: { + max_new_tokens: 100, }, - stream: false, }, + { signal: controller.signal } + ); + await expect(response.next()).resolves.toBeDefined(); + await expect(response.next()).resolves.toBeDefined(); + controller.abort(); + await expect(response.next()).rejects.toThrow("The operation was aborted"); + }); + + it("tokenClassification", async () => { + expect( + await hf.tokenClassification({ + model: "dbmdz/bert-large-cased-finetuned-conll03-english", + inputs: "My name is Sarah Jessica Parker but you can call me Jessica", + }) + ).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + entity_group: expect.any(String), + score: expect.any(Number), + word: expect.any(String), + start: expect.any(Number), + end: expect.any(Number), + }), + ]) + ); + }); + + it("translation", async () => { + expect( + await hf.translation({ + model: "t5-base", + inputs: "My name is Wolfgang and I live in Berlin", + }) + ).toMatchObject({ + translation_text: "Mein Name ist Wolfgang und ich lebe in Berlin", + }); + // input is a list + expect( + await hf.translation({ + model: "t5-base", + // eslint-disable-next-line @typescript-eslint/no-explicit-any + inputs: ["My name is Wolfgang and I live in Berlin", "I work as programmer"] as any, + }) + ).toMatchObject([ { - task: "image-classification", - modelId: "akahana/vit-base-cats-vs-dogs", - test: async (modelId: string) => { - const output = await client.imageClassification({ - // - model: modelId, - provider: "bytez", - data: new Blob([readTestFile("cheetah.png")], { type: "image/png" }), - }); - expect(output).toEqual( - expect.arrayContaining([ - expect.objectContaining({ - score: expect.any(Number), - label: expect.any(String), - }), - ]) - ); - }, - stream: false, + translation_text: "Mein Name ist Wolfgang und ich lebe in Berlin", }, { - task: "zero-shot-image-classification", - modelId: "BilelDJ/clip-hugging-face-finetuned", - test: async (modelId: string) => { - const output = await client.zeroShotImageClassification({ - model: modelId, - provider: "bytez", - inputs: { image: new Blob([readTestFile("cheetah.png")], { type: "image/png" }) }, - parameters: { - candidate_labels: ["animal", "toy", "car"], - }, - }); - expect(output).toEqual( - expect.arrayContaining([ - expect.objectContaining({ - score: expect.any(Number), - label: expect.any(String), - }), - ]) - ); - }, - stream: false, + translation_text: "Ich arbeite als Programmierer", }, - { - task: "object-detection", - modelId: "aisak-ai/aisak-detect", - test: async (modelId: string) => { - const output = await client.objectDetection({ - model: modelId, - provider: "bytez", - inputs: new Blob([readTestFile("cats.png")], { type: "image/png" }), - }); - expect(output).toEqual( - expect.arrayContaining([ - expect.objectContaining({ - score: expect.any(Number), - label: expect.any(String), - box: expect.objectContaining({ - xmin: expect.any(Number), - ymin: expect.any(Number), - xmax: expect.any(Number), - ymax: expect.any(Number), - }), - }), - ]) - ); + ]); + }); + it("zeroShotClassification", async () => { + expect( + await hf.zeroShotClassification({ + model: "facebook/bart-large-mnli", + inputs: [ + "Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!", + // eslint-disable-next-line @typescript-eslint/no-explicit-any + ] as any, + parameters: { candidate_labels: ["refund", "legal", "faq"] }, + }) + ).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + sequence: + "Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!", + labels: ["refund", "faq", "legal"], + scores: [ + expect.closeTo(0.877787709236145, 5), + expect.closeTo(0.10522633045911789, 5), + expect.closeTo(0.01698593981564045, 5), + ], + }), + ]) + ); + }); + it("sentenceSimilarity", async () => { + expect( + await hf.sentenceSimilarity({ + model: "sentence-transformers/paraphrase-xlm-r-multilingual-v1", + inputs: { + source_sentence: "That is a happy person", + sentences: ["That is a happy dog", "That is a very happy person", "Today is a sunny day"], }, - stream: false, - }, - { - task: "feature-extraction", - modelId: "allenai/specter2_base", - test: async (modelId: string) => { - const output = await client.featureExtraction({ - model: modelId, - provider: "bytez", - inputs: "That is a happy person", - }); - expect(output).toEqual(expect.arrayContaining([expect.any(Number)])); + }) + ).toEqual([expect.any(Number), expect.any(Number), expect.any(Number)]); + }); + it("FeatureExtraction", async () => { + const response = await hf.featureExtraction({ + model: "sentence-transformers/distilbert-base-nli-mean-tokens", + inputs: "That is a happy person", + }); + expect(response).toEqual(expect.arrayContaining([expect.any(Number)])); + }); + it("FeatureExtraction - auto-compatibility sentence similarity", async () => { + const response = await hf.featureExtraction({ + model: "sentence-transformers/paraphrase-xlm-r-multilingual-v1", + inputs: "That is a happy person", + }); + + expect(response.length).toBeGreaterThan(10); + expect(response).toEqual(expect.arrayContaining([expect.any(Number)])); + }); + it("FeatureExtraction - facebook/bart-base", async () => { + const response = await hf.featureExtraction({ + model: "facebook/bart-base", + inputs: "That is a happy person", + }); + // 1x7x768 + expect(response).toEqual([ + [ + expect.arrayContaining([expect.any(Number)]), + expect.arrayContaining([expect.any(Number)]), + expect.arrayContaining([expect.any(Number)]), + expect.arrayContaining([expect.any(Number)]), + expect.arrayContaining([expect.any(Number)]), + expect.arrayContaining([expect.any(Number)]), + expect.arrayContaining([expect.any(Number)]), + ], + ]); + }); + it("FeatureExtraction - facebook/bart-base, list input", async () => { + const response = await hf.featureExtraction({ + model: "facebook/bart-base", + inputs: ["hello", "That is a happy person"], + }); + // Nx1xTx768 + expect(response).toEqual([ + [ + [ + expect.arrayContaining([expect.any(Number)]), + expect.arrayContaining([expect.any(Number)]), + expect.arrayContaining([expect.any(Number)]), + ], + ], + [ + [ + expect.arrayContaining([expect.any(Number)]), + expect.arrayContaining([expect.any(Number)]), + expect.arrayContaining([expect.any(Number)]), + expect.arrayContaining([expect.any(Number)]), + expect.arrayContaining([expect.any(Number)]), + expect.arrayContaining([expect.any(Number)]), + expect.arrayContaining([expect.any(Number)]), + ], + ], + ]); + }); + it("automaticSpeechRecognition", async () => { + expect( + await hf.automaticSpeechRecognition({ + model: "facebook/wav2vec2-large-960h-lv60-self", + data: new Blob([readTestFile("sample1.flac")], { type: "audio/flac" }), + }) + ).toMatchObject({ + text: "GOING ALONG SLUSHY COUNTRY ROADS AND SPEAKING TO DAMP AUDIENCES IN DRAUGHTY SCHOOLROOMS DAY AFTER DAY FOR A FORTNIGHT HE'LL HAVE TO PUT IN AN APPEARANCE AT SOME PLACE OF WORSHIP ON SUNDAY MORNING AND HE CAN COME TO US IMMEDIATELY AFTERWARDS", + }); + }); + it("audioClassification", async () => { + expect( + await hf.audioClassification({ + model: "superb/hubert-large-superb-er", + data: new Blob([readTestFile("sample1.flac")], { type: "audio/flac" }), + }) + ).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + score: expect.any(Number), + label: expect.any(String), + }), + ]) + ); + }); + + it("audioToAudio", async () => { + expect( + await hf.audioToAudio({ + model: "speechbrain/sepformer-wham", + data: new Blob([readTestFile("sample1.flac")], { type: "audio/flac" }), + }) + ).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + label: expect.any(String), + blob: expect.any(String), + "content-type": expect.any(String), + }), + ]) + ); + }); + + it("textToSpeech", async () => { + expect( + await hf.textToSpeech({ + model: "espnet/kan-bayashi_ljspeech_vits", + inputs: "hello there!", + }) + ).toBeInstanceOf(Blob); + }); + + it("imageClassification", async () => { + expect( + await hf.imageClassification({ + data: new Blob([readTestFile("cheetah.png")], { type: "image/png" }), + model: "google/vit-base-patch16-224", + }) + ).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + score: expect.any(Number), + label: expect.any(String), + }), + ]) + ); + }); + + it("zeroShotImageClassification", async () => { + expect( + await hf.zeroShotImageClassification({ + inputs: { image: new Blob([readTestFile("cheetah.png")], { type: "image/png" }) }, + model: "openai/clip-vit-large-patch14-336", + parameters: { + candidate_labels: ["animal", "toy", "car"], }, - stream: false, - }, + }) + ).toEqual([ { - task: "sentence-similarity", - modelId: "embedding-data/distilroberta-base-sentence-transformer", - test: async (modelId: string) => { - const output = await client.sentenceSimilarity({ - model: modelId, - provider: "bytez", - inputs: { - source_sentence: "That is a happy person", - sentences: ["That is a happy dog", "That is a very happy person", "Today is a sunny day"], - }, - }); - expect(output).toEqual([expect.any(Number), expect.any(Number), expect.any(Number)]); - }, - stream: false, + label: "animal", + score: expect.any(Number), }, { - task: "fill-mask", - modelId: "almanach/camembert-base", - test: async (modelId: string) => { - const output = await client.fillMask({ model: modelId, provider: "bytez", inputs: "Hello " }); - expect(output).toEqual( - expect.arrayContaining([ - expect.objectContaining({ - score: expect.any(Number), - token: expect.any(Number), - token_str: expect.any(String), - sequence: expect.any(String), - }), - ]) - ); - }, - stream: false, + label: "car", + score: expect.any(Number), }, { - task: "text-classification", - modelId: "AdamCodd/distilbert-base-uncased-finetuned-sentiment-amazon", - test: async (modelId: string) => { - const output = await client.textClassification({ - model: modelId, - provider: "bytez", - inputs: "I am a special unicorn", - }); - expect(output.every((entry) => entry.label && entry.score)).toBe(true); - }, - stream: false, + label: "toy", + score: expect.any(Number), }, - { - task: "token-classification", - modelId: "2rtl3/mn-xlm-roberta-base-named-entity", - test: async (modelId: string) => { - const output = await client.tokenClassification({ - model: modelId, - provider: "bytez", - inputs: "John went to NYC", - }); - expect(output).toEqual( - expect.arrayContaining([ - expect.objectContaining({ - entity_group: expect.any(String), - score: expect.any(Number), - word: expect.any(String), - start: expect.any(Number), - end: expect.any(Number), - }), - ]) - ); - }, - stream: false, + ]); + }); + + it("objectDetection", async () => { + expect( + await hf.objectDetection({ + data: new Blob([readTestFile("cats.png")], { type: "image/png" }), + model: "facebook/detr-resnet-50", + }) + ).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + score: expect.any(Number), + label: expect.any(String), + box: expect.objectContaining({ + xmin: expect.any(Number), + ymin: expect.any(Number), + xmax: expect.any(Number), + ymax: expect.any(Number), + }), + }), + ]) + ); + }); + it("imageSegmentation", async () => { + expect( + await hf.imageSegmentation({ + inputs: new Blob([readTestFile("cats.png")], { type: "image/png" }), + model: "facebook/detr-resnet-50-panoptic", + }) + ).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + score: expect.any(Number), + label: expect.any(String), + mask: expect.any(String), + }), + ]) + ); + }); + it("imageToImage", async () => { + const num_inference_steps = 25; + + const res = await hf.imageToImage({ + inputs: new Blob([readTestFile("stormtrooper_depth.png")], { type: "image / png" }), + parameters: { + prompt: "elmo's lecture", + num_inference_steps, }, - { - task: "zero-shot-classification", - modelId: "AyoubChLin/DistilBERT_eco_ZeroShot", - test: async (modelId: string) => { - const testInput = "Ninja turtles are cool"; - const testCandidateLabels = ["positive", "negative"]; - const output = await client.zeroShotClassification({ - model: modelId, - provider: "bytez", - inputs: [ - testInput, - // eslint-disable-next-line @typescript-eslint/no-explicit-any - ] as any, - parameters: { candidate_labels: testCandidateLabels }, - }); - expect(output).toEqual( - expect.arrayContaining([ - expect.objectContaining({ - sequence: testInput, - labels: testCandidateLabels, - scores: [expect.closeTo(0.5206031203269958, 5), expect.closeTo(0.479396790266037, 5)], - }), - ]) - ); - }, - stream: false, + model: "lllyasviel/sd-controlnet-depth", + }); + expect(res).toBeInstanceOf(Blob); + }); + it("imageToImage blob data", async () => { + const res = await hf.imageToImage({ + inputs: new Blob([readTestFile("bird_canny.png")], { type: "image / png" }), + model: "lllyasviel/sd-controlnet-canny", + }); + expect(res).toBeInstanceOf(Blob); + }); + it("textToImage", async () => { + const res = await hf.textToImage({ + inputs: + "award winning high resolution photo of a giant tortoise/((ladybird)) hybrid, [trending on artstation]", + model: "stabilityai/stable-diffusion-2", + }); + expect(res).toBeInstanceOf(Blob); + }); + + it("textToImage with parameters", async () => { + const width = 512; + const height = 128; + const num_inference_steps = 10; + + const res = await hf.textToImage({ + inputs: + "award winning high resolution photo of a giant tortoise/((ladybird)) hybrid, [trending on artstation]", + model: "stabilityai/stable-diffusion-2", + parameters: { + negative_prompt: "blurry", + width, + height, + num_inference_steps, }, + }); + expect(res).toBeInstanceOf(Blob); + }); + it("textToImage with json output", async () => { + const res = await hf.textToImage({ + inputs: "a giant tortoise", + model: "stabilityai/stable-diffusion-2", + outputType: "json", + }); + expect(res).toMatchObject({ + output: expect.any(String), + }); + }); + it("imageToText", async () => { + expect( + await hf.imageToText({ + data: new Blob([readTestFile("cheetah.png")], { type: "image/png" }), + model: "nlpconnect/vit-gpt2-image-captioning", + }) + ).toEqual({ + generated_text: "a large brown and white giraffe standing in a field ", + }); + }); + + /// Skipping because the function is deprecated + it.skip("request - openai-community/gpt2", async () => { + expect( + await hf.request({ + model: "openai-community/gpt2", + inputs: "one plus two equals", + }) + ).toMatchObject([ { - task: "audio-classification", - modelId: "aaraki/wav2vec2-base-finetuned-ks", - test: async (modelId: string) => { - const output = await client.audioClassification({ - model: modelId, - provider: "bytez", - data: new Blob([readTestFile("sample1.flac")], { type: "audio/flac" }), - }); - expect(output).toEqual( - expect.arrayContaining([ - expect.objectContaining({ - score: expect.any(Number), - label: expect.any(String), - }), - ]) - ); - }, - stream: false, + generated_text: expect.any(String), }, - { - task: "text-to-speech", - modelId: "facebook/mms-tts-eng", - test: async (modelId: string) => { - const output = await client.textToSpeech({ model: modelId, provider: "bytez", inputs: "Hello" }); - expect(output).toBeInstanceOf(Blob); + ]); + }); + + // Skipped at the moment because takes forever + it.skip("tabularRegression", async () => { + expect( + await hf.tabularRegression({ + model: "scikit-learn/Fish-Weight", + inputs: { + data: { + Height: ["11.52", "12.48", "12.3778"], + Length1: ["23.2", "24", "23.9"], + Length2: ["25.4", "26.3", "26.5"], + Length3: ["30", "31.2", "31.1"], + Species: ["Bream", "Bream", "Bream"], + Width: ["4.02", "4.3056", "4.6961"], + }, }, - stream: false, - }, - { - task: "automatic-speech-recognition", - modelId: "facebook/data2vec-audio-base-960h", - test: async (modelId: string) => { - const output = await client.automaticSpeechRecognition({ - model: modelId, - provider: "bytez", - data: new Blob([readTestFile("sample1.flac")], { type: "audio/flac" }), - }); - expect(output).toMatchObject({ - text: "GOING ALONG SLUSHY COUNTRY ROADS AND SPEAKING TO DAMP AUDIENCES IN DRAUGHTY SCHOOLROOMS DAY AFTER DAY FOR A FORTNIGHT HE'LL HAVE TO PUT IN AN APPEARANCE AT SOME PLACE OF WORSHIP ON SUNDAY MORNING AND HE CAN COME TO US IMMEDIATELY AFTERWARDS", - }); + }) + ).toMatchObject([270.5473526976245, 313.6843425638086, 328.3727133404402]); + }); + + // Skipped at the moment because takes forever + it.skip("tabularClassification", async () => { + expect( + await hf.tabularClassification({ + model: "vvmnnnkv/wine-quality", + inputs: { + data: { + fixed_acidity: ["7.4", "7.8", "10.3"], + volatile_acidity: ["0.7", "0.88", "0.32"], + citric_acid: ["0", "0", "0.45"], + residual_sugar: ["1.9", "2.6", "6.4"], + chlorides: ["0.076", "0.098", "0.073"], + free_sulfur_dioxide: ["11", "25", "5"], + total_sulfur_dioxide: ["34", "67", "13"], + density: ["0.9978", "0.9968", "0.9976"], + pH: ["3.51", "3.2", "3.23"], + sulphates: ["0.56", "0.68", "0.82"], + alcohol: ["9.4", "9.8", "12.6"], + }, }, - stream: false, + }) + ).toMatchObject([5, 5, 7]); + }); + + it("endpoint - makes request to specified endpoint", async () => { + const ep = hf.endpoint("https://router.huggingface.co/hf-inference/models/openai-community/gpt2"); + const { generated_text } = await ep.textGeneration({ + inputs: "one plus one is equal to", + parameters: { + max_new_tokens: 1, }, - ]; + }); + assert.include(generated_text, "two"); + }); - // bootstrap the inference mappings for testing - for (const { task, modelId } of tests) { - HARDCODED_MODEL_INFERENCE_MAPPING.bytez[modelId] = { - provider: "bytez", - hfModelId: modelId, - providerId: modelId, - status: "live", - task, - adapter: undefined, - adapterWeightsPath: undefined, - }; - } + it("endpoint - makes request to specified endpoint - alternative syntax", async () => { + const epClient = new InferenceClient(env.HF_TOKEN, { + endpointUrl: "https://router.huggingface.co/hf-inference/models/openai-community/gpt2", + }); + const { generated_text } = await epClient.textGeneration({ + inputs: "one plus one is equal to", + parameters: { + max_new_tokens: 1, + }, + }); + assert.include(generated_text, "two"); + }); - // run the tests - for (const { task, modelId, test, stream } of tests) { - const testName = `${task} - ${modelId}${stream ? " stream" : ""}`; - it(testName, async () => { - await test(modelId); + it("chatCompletion modelId - OpenAI Specs", async () => { + const res = await hf.chatCompletion({ + model: "mistralai/Mistral-7B-Instruct-v0.2", + messages: [{ role: "user", content: "Complete the this sentence with words one plus one is equal " }], + max_tokens: 500, + temperature: 0.1, + seed: 0, }); - } + if (res.choices && res.choices.length > 0) { + const completion = res.choices[0].message?.content; + expect(completion).toContain("to two"); + } + }); + + it("chatCompletionStream modelId - OpenAI Specs", async () => { + const stream = hf.chatCompletionStream({ + model: "mistralai/Mistral-7B-Instruct-v0.2", + messages: [{ role: "user", content: "Complete the equation 1+1= ,just the answer" }], + max_tokens: 500, + temperature: 0.1, + seed: 0, + }); + let out = ""; + for await (const chunk of stream) { + if (chunk.choices && chunk.choices.length > 0) { + out += chunk.choices[0].delta.content; + } + } + expect(out).toContain("2"); + }); + + it.skip("chatCompletionStream modelId Fail - OpenAI Specs", async () => { + expect( + hf + .chatCompletionStream({ + model: "google/gemma-2b", + messages: [{ role: "user", content: "Complete the equation 1+1= ,just the answer" }], + max_tokens: 500, + temperature: 0.1, + seed: 0, + }) + .next() + ).rejects.toThrowError( + "Server google/gemma-2b does not seem to support chat completion. Error: Template error: template not found" + ); + }); + + it("chatCompletion - OpenAI Specs", async () => { + const ep = hf.endpoint("https://router.huggingface.co/hf-inference/models/mistralai/Mistral-7B-Instruct-v0.2"); + const res = await ep.chatCompletion({ + model: "tgi", + messages: [{ role: "user", content: "Complete the this sentence with words one plus one is equal " }], + max_tokens: 500, + temperature: 0.1, + seed: 0, + }); + if (res.choices && res.choices.length > 0) { + const completion = res.choices[0].message?.content; + expect(completion).toContain("to two"); + } + }); + it("chatCompletionStream - OpenAI Specs", async () => { + const ep = hf.endpoint("https://router.huggingface.co/hf-inference/models/mistralai/Mistral-7B-Instruct-v0.2"); + const stream = ep.chatCompletionStream({ + model: "tgi", + messages: [{ role: "user", content: "Complete the equation 1+1= ,just the answer" }], + max_tokens: 500, + temperature: 0.1, + seed: 0, + }); + let out = ""; + for await (const chunk of stream) { + if (chunk.choices && chunk.choices.length > 0) { + out += chunk.choices[0].delta.content; + } + } + expect(out).toContain("2"); + }); + it("custom mistral - OpenAI Specs", async () => { + const MISTRAL_KEY = env.MISTRAL_KEY; + const hf = new InferenceClient(MISTRAL_KEY); + const ep = hf.endpoint("https://api.mistral.ai"); + const stream = ep.chatCompletionStream({ + model: "mistral-tiny", + messages: [{ role: "user", content: "Complete the equation one + one = , just the answer" }], + }) as AsyncGenerator; + let out = ""; + for await (const chunk of stream) { + if (chunk.choices && chunk.choices.length > 0) { + out += chunk.choices[0].delta.content; + } + } + expect(out).toContain("The answer to one + one is two."); + }); + it("custom openai - OpenAI Specs", async () => { + const OPENAI_KEY = env.OPENAI_KEY; + const hf = new InferenceClient(OPENAI_KEY); + const stream = hf.chatCompletionStream({ + provider: "openai", + model: "openai/gpt-3.5-turbo", + messages: [{ role: "user", content: "Complete the equation one + one =" }], + }) as AsyncGenerator; + let out = ""; + for await (const chunk of stream) { + if (chunk.choices && chunk.choices.length > 0) { + out += chunk.choices[0].delta.content; + } + } + expect(out).toContain("two"); + }); + it("OpenAI client side routing - model should have provider as prefix", async () => { + await expect( + new InferenceClient("dummy_token").chatCompletion({ + model: "gpt-3.5-turbo", // must be "openai/gpt-3.5-turbo" + provider: "openai", + messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }], + }) + ).rejects.toThrowError(`Models from openai must be prefixed by "openai/". Got "gpt-3.5-turbo".`); + }); }, TIMEOUT ); - describe("backward compatibility", () => { - it("works with old HfInference name", async () => { - const hf = new HfInference(env.HF_TOKEN); - expect("fillMask" in hf).toBe(true); - }); - }); - + /** + * Compatibility with third-party Inference Providers + */ describe.concurrent( - "HF Inference", + "Fal AI", () => { - const hf = new InferenceClient(env.HF_TOKEN); - HARDCODED_MODEL_INFERENCE_MAPPING["hf-inference"] = { - "google-bert/bert-base-uncased": { - provider: "hf-inference", - providerId: "google-bert/bert-base-uncased", - hfModelId: "google-bert/bert-base-uncased", - task: "fill-mask", - status: "live", - }, - "google/pegasus-xsum": { - provider: "hf-inference", - providerId: "google/pegasus-xsum", - hfModelId: "google/pegasus-xsum", - task: "summarization", - status: "live", - }, - "deepset/roberta-base-squad2": { - provider: "hf-inference", - providerId: "deepset/roberta-base-squad2", - hfModelId: "deepset/roberta-base-squad2", - task: "question-answering", - status: "live", - }, - "google/tapas-base-finetuned-wtq": { - provider: "hf-inference", - providerId: "google/tapas-base-finetuned-wtq", - hfModelId: "google/tapas-base-finetuned-wtq", - task: "table-question-answering", - status: "live", - }, - "mistralai/Mistral-7B-Instruct-v0.2": { - provider: "hf-inference", - providerId: "mistralai/Mistral-7B-Instruct-v0.2", - hfModelId: "mistralai/Mistral-7B-Instruct-v0.2", - task: "text-generation", - status: "live", - }, - "impira/layoutlm-document-qa": { - provider: "hf-inference", - providerId: "impira/layoutlm-document-qa", - hfModelId: "impira/layoutlm-document-qa", - task: "document-question-answering", - status: "live", - }, - "naver-clova-ix/donut-base-finetuned-docvqa": { - provider: "hf-inference", - providerId: "naver-clova-ix/donut-base-finetuned-docvqa", - hfModelId: "naver-clova-ix/donut-base-finetuned-docvqa", - task: "document-question-answering", - status: "live", - }, - "google/tapas-large-finetuned-wtq": { - provider: "hf-inference", - providerId: "google/tapas-large-finetuned-wtq", - hfModelId: "google/tapas-large-finetuned-wtq", - task: "table-question-answering", - status: "live", - }, - "facebook/detr-resnet-50": { - provider: "hf-inference", - providerId: "facebook/detr-resnet-50", - hfModelId: "facebook/detr-resnet-50", - task: "object-detection", - status: "live", - }, - "facebook/detr-resnet-50-panoptic": { - provider: "hf-inference", - providerId: "facebook/detr-resnet-50-panoptic", - hfModelId: "facebook/detr-resnet-50-panoptic", - task: "image-segmentation", - status: "live", - }, - "facebook/wav2vec2-large-960h-lv60-self": { - provider: "hf-inference", - providerId: "facebook/wav2vec2-large-960h-lv60-self", - hfModelId: "facebook/wav2vec2-large-960h-lv60-self", - task: "automatic-speech-recognition", + const client = new InferenceClient(env.HF_FAL_KEY ?? "dummy"); + + HARDCODED_MODEL_INFERENCE_MAPPING["fal-ai"] = { + "openfree/flux-chatgpt-ghibli-lora": { + provider: "fal-ai", + hfModelId: "openfree/flux-chatgpt-ghibli-lora", + providerId: "fal-ai/flux-lora", status: "live", + task: "text-to-image", + adapter: "lora", + adapterWeightsPath: "flux-chatgpt-ghibli-lora.safetensors", }, - "superb/hubert-large-superb-er": { - provider: "hf-inference", - providerId: "superb/hubert-large-superb-er", - hfModelId: "superb/hubert-large-superb-er", - task: "audio-classification", - status: "live", - }, - "speechbrain/sepformer-wham": { - provider: "hf-inference", - providerId: "speechbrain/sepformer-wham", - hfModelId: "speechbrain/sepformer-wham", - task: "audio-to-audio", - status: "live", - }, - "espnet/kan-bayashi_ljspeech_vits": { - provider: "hf-inference", - providerId: "espnet/kan-bayashi_ljspeech_vits", - hfModelId: "espnet/kan-bayashi_ljspeech_vits", - task: "text-to-speech", - status: "live", - }, - "sentence-transformers/paraphrase-xlm-r-multilingual-v1": { - provider: "hf-inference", - providerId: "sentence-transformers/paraphrase-xlm-r-multilingual-v1", - hfModelId: "sentence-transformers/paraphrase-xlm-r-multilingual-v1", - task: "sentence-similarity", - status: "live", - }, - "sentence-transformers/distilbert-base-nli-mean-tokens": { - provider: "hf-inference", - providerId: "sentence-transformers/distilbert-base-nli-mean-tokens", - hfModelId: "sentence-transformers/distilbert-base-nli-mean-tokens", - task: "feature-extraction", - status: "live", - }, - "facebook/bart-base": { - provider: "hf-inference", - providerId: "facebook/bart-base", - hfModelId: "facebook/bart-base", - task: "feature-extraction", - status: "live", - }, - "facebook/bart-large-mnli": { - provider: "hf-inference", - providerId: "facebook/bart-large-mnli", - hfModelId: "facebook/bart-large-mnli", - task: "zero-shot-classification", - status: "live", - }, - "facebook/bart-large-cnn": { - provider: "hf-inference", - providerId: "facebook/bart-large-cnn", - hfModelId: "facebook/bart-large-cnn", - task: "summarization", - status: "live", - }, - "facebook/bart-large-xsum": { - provider: "hf-inference", - providerId: "facebook/bart-large-xsum", - hfModelId: "facebook/bart-large-xsum", - task: "summarization", - status: "live", - }, - "stabilityai/stable-diffusion-2": { - provider: "hf-inference", - providerId: "stabilityai/stable-diffusion-2", - hfModelId: "stabilityai/stable-diffusion-2", - task: "text-to-image", - status: "live", - }, - "lllyasviel/sd-controlnet-canny": { - provider: "hf-inference", - providerId: "lllyasviel/sd-controlnet-canny", - hfModelId: "lllyasviel/sd-controlnet-canny", - task: "image-to-image", - status: "live", - }, - "lllyasviel/sd-controlnet-depth": { - provider: "hf-inference", - providerId: "lllyasviel/sd-controlnet-depth", - hfModelId: "lllyasviel/sd-controlnet-depth", - task: "image-to-image", - status: "live", - }, - "t5-base": { - provider: "hf-inference", - providerId: "t5-base", - hfModelId: "t5-base", - task: "translation", - status: "live", - }, - "openai/clip-vit-large-patch14-336": { - provider: "hf-inference", - providerId: "openai/clip-vit-large-patch14-336", - hfModelId: "openai/clip-vit-large-patch14-336", - task: "zero-shot-image-classification", - status: "live", - }, - "google/vit-base-patch16-224": { - provider: "hf-inference", - providerId: "google/vit-base-patch16-224", - hfModelId: "google/vit-base-patch16-224", - task: "image-classification", - status: "live", - }, - "dandelin/vilt-b32-finetuned-vqa": { - provider: "hf-inference", - providerId: "dandelin/vilt-b32-finetuned-vqa", - hfModelId: "dandelin/vilt-b32-finetuned-vqa", - task: "visual-question-answering", - status: "live", - }, - "dbmdz/bert-large-cased-finetuned-conll03-english": { - provider: "hf-inference", - providerId: "dbmdz/bert-large-cased-finetuned-conll03-english", - hfModelId: "dbmdz/bert-large-cased-finetuned-conll03-english", - task: "token-classification", - status: "live", - }, - "nlpconnect/vit-gpt2-image-captioning": { - provider: "hf-inference", - providerId: "nlpconnect/vit-gpt2-image-captioning", - hfModelId: "nlpconnect/vit-gpt2-image-captioning", - task: "image-to-text", + "nerijs/pixel-art-xl": { + provider: "fal-ai", + hfModelId: "nerijs/pixel-art-xl", + providerId: "fal-ai/lora", status: "live", + task: "text-to-image", + adapter: "lora", + adapterWeightsPath: "pixel-art-xl.safetensors", }, }; - it("throws error if model does not exist", () => { - expect( - hf.fillMask({ - model: "this-model/does-not-exist-123", - inputs: "[MASK] world!", - }) - ).rejects.toThrowError("Model this-model/does-not-exist-123 does not exist"); - }); - - it("fillMask", async () => { - expect( - await hf.fillMask({ - model: "google-bert/bert-base-uncased", - inputs: "[MASK] world!", - }) - ).toEqual( - expect.arrayContaining([ - expect.objectContaining({ - score: expect.any(Number), - token: expect.any(Number), - token_str: expect.any(String), - sequence: expect.any(String), - }), - ]) - ); + it(`textToImage - black-forest-labs/FLUX.1-schnell`, async () => { + const res = await client.textToImage({ + model: "black-forest-labs/FLUX.1-schnell", + provider: "fal-ai", + inputs: + "Extreme close-up of a single tiger eye, direct frontal view. Detailed iris and pupil. Sharp focus on eye texture and color. Natural lighting to capture authentic eye shine and depth.", + }); + expect(res).toBeInstanceOf(Blob); }); - it.skip("works without model", async () => { - expect( - await hf.fillMask({ - inputs: "[MASK] world!", - }) - ).toEqual( - expect.arrayContaining([ - expect.objectContaining({ - score: expect.any(Number), - token: expect.any(Number), - token_str: expect.any(String), - sequence: expect.any(String), - }), - ]) - ); + /// Skipped: we need a way to pass the base model ID + it(`textToImage - SD LoRAs`, async () => { + const res = await client.textToImage({ + model: "nerijs/pixel-art-xl", + provider: "fal-ai", + inputs: "pixel, a cute corgi", + parameters: { + negative_prompt: "3d render, realistic", + }, + }); + expect(res).toBeInstanceOf(Blob); }); - it("summarization", async () => { - expect( - await hf.summarization({ - model: "google/pegasus-xsum", - inputs: - "The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930.", - parameters: { - max_length: 100, - }, - }) - ).toEqual({ - summary_text: "The Eiffel Tower is one of the most famous buildings in the world.", + it(`textToImage - Flux LoRAs`, async () => { + const res = await client.textToImage({ + model: "openfree/flux-chatgpt-ghibli-lora", + provider: "fal-ai", + inputs: + "Ghibli style sky whale transport ship, its metallic skin adorned with traditional Japanese patterns, gliding through cotton candy clouds at sunrise. Small floating gardens hang from its sides, where workers in futuristic kimonos tend to glowing plants. Rainbow auroras shimmer in the background. [trigger]", }); + expect(res).toBeInstanceOf(Blob); }); - it("questionAnswering", async () => { - expect( - await hf.questionAnswering({ - model: "deepset/roberta-base-squad2", - inputs: { - question: "What is the capital of France?", - context: "The capital of France is Paris.", - }, - }) - ).toMatchObject({ - answer: "Paris", - score: expect.any(Number), - start: expect.any(Number), - end: expect.any(Number), + it(`automaticSpeechRecognition - openai/whisper-large-v3`, async () => { + const res = await client.automaticSpeechRecognition({ + model: "openai/whisper-large-v3", + provider: "fal-ai", + data: new Blob([readTestFile("sample2.wav")], { type: "audio/x-wav" }), + }); + expect(res).toMatchObject({ + text: " he has grave doubts whether sir frederick leighton's work is really greek after all and can discover in it but little of rocky ithaca", }); }); - - it("tableQuestionAnswering", async () => { - expect( - await hf.tableQuestionAnswering({ - model: "google/tapas-base-finetuned-wtq", - inputs: { - question: "How many stars does the transformers repository have?", - table: { - Repository: ["Transformers", "Datasets", "Tokenizers"], - Stars: ["36542", "4512", "3934"], - Contributors: ["651", "77", "34"], - "Programming language": ["Python", "Python", "Rust, Python and NodeJS"], - }, - }, - }) - ).toMatchObject({ - answer: "AVERAGE > 36542", - coordinates: [[0, 1]], - cells: ["36542"], - aggregator: "AVERAGE", + it("imageToVideo - fal-ai", async () => { + const res = await client.imageToVideo({ + model: "fal-ai/ltxv-13b-098-distilled/image-to-video", + provider: "fal-ai", + inputs: new Blob([readTestFile("cats.png")], { type: "image/png" }), + parameters: { + prompt: "The cats are jumping around in a playful manner", + }, }); + expect(res).toBeInstanceOf(Blob); }); + }, + TIMEOUT + ); - it("documentQuestionAnswering", async () => { - expect( - await hf.documentQuestionAnswering({ - model: "impira/layoutlm-document-qa", - inputs: { - question: "Invoice number?", - image: new Blob([readTestFile("invoice.png")], { type: "image/png" }), - }, - }) - ).toMatchObject({ - answer: "us-001", - score: expect.any(Number), - // not sure what start/end refers to in this case - start: expect.any(Number), - end: expect.any(Number), + describe.concurrent( + "Featherless", + () => { + HARDCODED_MODEL_INFERENCE_MAPPING["featherless-ai"] = { + "meta-llama/Llama-3.1-8B": { + provider: "featherless-ai", + providerId: "meta-llama/Meta-Llama-3.1-8B", + hfModelId: "meta-llama/Llama-3.1-8B", + task: "text-generation", + status: "live", + }, + "meta-llama/Llama-3.1-8B-Instruct": { + provider: "featherless-ai", + providerId: "meta-llama/Meta-Llama-3.1-8B-Instruct", + hfModelId: "meta-llama/Llama-3.1-8B-Instruct", + task: "text-generation", + status: "live", + }, + }; + + it("chatCompletion", async () => { + const res = await chatCompletion({ + accessToken: env.HF_FEATHERLESS_KEY ?? "dummy", + model: "meta-llama/Llama-3.1-8B-Instruct", + provider: "featherless-ai", + messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }], + temperature: 0.1, }); + + expect(res).toBeDefined(); + expect(res.choices).toBeDefined(); + expect(res.choices?.length).toBeGreaterThan(0); + + if (res.choices && res.choices.length > 0) { + const completion = res.choices[0].message?.content; + expect(completion).toBeDefined(); + expect(typeof completion).toBe("string"); + expect(completion).toContain("two"); + } }); - // Errors with "Error: If you are using a VisionEncoderDecoderModel, you must provide a feature extractor" - it.skip("documentQuestionAnswering with non-array output", async () => { - expect( - await hf.documentQuestionAnswering({ - model: "naver-clova-ix/donut-base-finetuned-docvqa", - inputs: { - question: "Invoice number?", - image: new Blob([readTestFile("invoice.png")], { type: "image/png" }), - }, - }) - ).toMatchObject({ - answer: "us-001", + it("chatCompletion stream", async () => { + const stream = chatCompletionStream({ + accessToken: env.HF_FEATHERLESS_KEY ?? "dummy", + model: "meta-llama/Llama-3.1-8B-Instruct", + provider: "featherless-ai", + messages: [{ role: "user", content: "Complete the equation 1 + 1 = , just the answer" }], + }) as AsyncGenerator; + let out = ""; + for await (const chunk of stream) { + if (chunk.choices && chunk.choices.length > 0) { + out += chunk.choices[0].delta.content; + } + } + expect(out).toContain("2"); + }); + + it("textGeneration", async () => { + const res = await textGeneration({ + accessToken: env.HF_FEATHERLESS_KEY ?? "dummy", + model: "meta-llama/Llama-3.1-8B", + provider: "featherless-ai", + inputs: "Paris is a city of ", + parameters: { + temperature: 0, + top_p: 0.01, + max_tokens: 10, + }, }); + expect(res).toMatchObject({ generated_text: "2.2 million people, and it is the" }); }); + }, + TIMEOUT + ); - it("visualQuestionAnswering", async () => { - expect( - await hf.visualQuestionAnswering({ - model: "dandelin/vilt-b32-finetuned-vqa", - inputs: { - question: "How many cats are lying down?", - image: new Blob([readTestFile("cats.png")], { type: "image/png" }), - }, - }) - ).toMatchObject({ - answer: "2", - score: expect.any(Number), + describe.concurrent( + "Replicate", + () => { + const client = new InferenceClient(env.HF_REPLICATE_KEY ?? "dummy"); + + it("textToImage canonical - black-forest-labs/FLUX.1-schnell", async () => { + const res = await client.textToImage({ + model: "black-forest-labs/FLUX.1-schnell", + provider: "replicate", + inputs: "black forest gateau cake spelling out the words FLUX SCHNELL, tasty, food photography, dynamic shot", }); + expect(res).toBeInstanceOf(Blob); }); - it("textClassification", async () => { - expect( - await hf.textClassification({ - model: "distilbert-base-uncased-finetuned-sst-2-english", - inputs: "I like you. I love you.", - }) - ).toEqual( - expect.arrayContaining([ - expect.objectContaining({ - label: expect.any(String), - score: expect.any(Number), - }), - ]) - ); + it("textToImage canonical - black-forest-labs/FLUX.1-dev", async () => { + const res = await client.textToImage({ + model: "black-forest-labs/FLUX.1-dev", + provider: "replicate", + inputs: + "A tiny laboratory deep in the Black Forest where squirrels in lab coats experiment with mixing chocolate and pine cones", + }); + expect(res).toBeInstanceOf(Blob); }); - it.skip("textGeneration - gpt2", async () => { - expect( - await hf.textGeneration({ - model: "gpt2", - inputs: "The answer to the universe is", - }) - ).toMatchObject({ - generated_text: expect.any(String), + // Runs black-forest-labs/flux-dev-lora under the hood + // with fofr/flux-80s-cyberpunk as the LoRA weights + it("textToImage - all Flux LoRAs", async () => { + const res = await client.textToImage({ + model: "fofr/flux-80s-cyberpunk", + provider: "replicate", + inputs: "style of 80s cyberpunk, a portrait photo", }); + expect(res).toBeInstanceOf(Blob); }); - it.skip("textGeneration - openai-community/gpt2", async () => { - expect( - await hf.textGeneration({ - model: "openai-community/gpt2", - inputs: "The answer to the universe is", - }) - ).toMatchObject({ - generated_text: expect.any(String), + it("textToImage canonical - stabilityai/stable-diffusion-3.5-large-turbo", async () => { + const res = await client.textToImage({ + model: "stabilityai/stable-diffusion-3.5-large-turbo", + provider: "replicate", + inputs: "A confused rubber duck wearing a tiny wizard hat trying to cast spells with a banana wand", }); + expect(res).toBeInstanceOf(Blob); }); - it("textGenerationStream - meta-llama/Llama-3.2-3B", async () => { - const response = hf.textGenerationStream({ - model: "meta-llama/Llama-3.2-3B", - inputs: "Please answer the following question: complete one two and ____.", - parameters: { - max_new_tokens: 50, - seed: 0, - }, + it("textToImage versioned - ByteDance/SDXL-Lightning", async () => { + const res = await client.textToImage({ + model: "ByteDance/SDXL-Lightning", + provider: "replicate", + inputs: "A grumpy storm cloud wearing sunglasses and throwing tiny lightning bolts like confetti", }); + expect(res).toBeInstanceOf(Blob); + }); - for await (const ret of response) { - expect(ret).toMatchObject({ - details: null, - index: expect.any(Number), - token: { - id: expect.any(Number), - logprob: expect.any(Number), - text: expect.any(String) || null, - special: expect.any(Boolean), - }, - generated_text: ret.generated_text - ? "Please answer the following question: complete one two and ____. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 16. 17" - : null, - }); - } + it("textToImage versioned - ByteDance/Hyper-SD", async () => { + const res = await client.textToImage({ + model: "ByteDance/Hyper-SD", + provider: "replicate", + inputs: "A group of dancing bytes wearing tiny party hats doing the macarena in cyberspace", + }); + expect(res).toBeInstanceOf(Blob); }); - it("textGenerationStream - catch error", async () => { - const response = hf.textGenerationStream({ - model: "meta-llama/Llama-3.2-3B", - inputs: "Write a short story about a robot that becomes sentient and takes over the world.", - parameters: { - max_new_tokens: 10_000, - }, + it("textToImage versioned - playgroundai/playground-v2.5-1024px-aesthetic", async () => { + const res = await client.textToImage({ + model: "playgroundai/playground-v2.5-1024px-aesthetic", + provider: "replicate", + inputs: "A playground where slides turn into rainbows and swings launch kids into cotton candy clouds", }); + expect(res).toBeInstanceOf(Blob); + }); - await expect(response.next()).rejects.toThrow( - "Error forwarded from backend: Input validation error: `inputs` tokens + `max_new_tokens` must be <= 4096. Given: 17 `inputs` tokens and 10000 `max_new_tokens`" - ); + it("textToImage versioned - stabilityai/stable-diffusion-xl-base-1.0", async () => { + const res = await client.textToImage({ + model: "stabilityai/stable-diffusion-xl-base-1.0", + provider: "replicate", + inputs: "An octopus juggling watermelons underwater while wearing scuba gear", + }); + expect(res).toBeInstanceOf(Blob); }); - it.skip("textGenerationStream - Abort", async () => { - const controller = new AbortController(); - const response = hf.textGenerationStream( - { - model: "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5", - inputs: "Write an essay about Sartre's philosophy.", - parameters: { - max_new_tokens: 100, - }, - }, - { signal: controller.signal } - ); - await expect(response.next()).resolves.toBeDefined(); - await expect(response.next()).resolves.toBeDefined(); - controller.abort(); - await expect(response.next()).rejects.toThrow("The operation was aborted"); + it.skip("textToSpeech versioned", async () => { + const res = await client.textToSpeech({ + model: "SWivid/F5-TTS", + provider: "replicate", + inputs: "Hello, how are you?", + }); + expect(res).toBeInstanceOf(Blob); }); - it("tokenClassification", async () => { - expect( - await hf.tokenClassification({ - model: "dbmdz/bert-large-cased-finetuned-conll03-english", - inputs: "My name is Sarah Jessica Parker but you can call me Jessica", - }) - ).toEqual( - expect.arrayContaining([ - expect.objectContaining({ - entity_group: expect.any(String), - score: expect.any(Number), - word: expect.any(String), - start: expect.any(Number), - end: expect.any(Number), - }), - ]) - ); + it.skip("textToSpeech OuteTTS - usually Cold", async () => { + const res = await client.textToSpeech({ + model: "OuteAI/OuteTTS-0.3-500M", + provider: "replicate", + inputs: "OuteTTS is a frontier TTS model for its size of 1 Billion parameters", + }); + + expect(res).toBeInstanceOf(Blob); }); - it("translation", async () => { - expect( - await hf.translation({ - model: "t5-base", - inputs: "My name is Wolfgang and I live in Berlin", - }) - ).toMatchObject({ - translation_text: "Mein Name ist Wolfgang und ich lebe in Berlin", + it("textToSpeech Kokoro", async () => { + const res = await client.textToSpeech({ + model: "hexgrad/Kokoro-82M", + provider: "replicate", + inputs: "Kokoro is a frontier TTS model for its size of 1 Billion parameters", }); - // input is a list - expect( - await hf.translation({ - model: "t5-base", - // eslint-disable-next-line @typescript-eslint/no-explicit-any - inputs: ["My name is Wolfgang and I live in Berlin", "I work as programmer"] as any, - }) - ).toMatchObject([ - { - translation_text: "Mein Name ist Wolfgang und ich lebe in Berlin", - }, - { - translation_text: "Ich arbeite als Programmierer", - }, - ]); - }); - it("zeroShotClassification", async () => { - expect( - await hf.zeroShotClassification({ - model: "facebook/bart-large-mnli", - inputs: [ - "Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!", - // eslint-disable-next-line @typescript-eslint/no-explicit-any - ] as any, - parameters: { candidate_labels: ["refund", "legal", "faq"] }, - }) - ).toEqual( - expect.arrayContaining([ - expect.objectContaining({ - sequence: - "Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!", - labels: ["refund", "faq", "legal"], - scores: [ - expect.closeTo(0.877787709236145, 5), - expect.closeTo(0.10522633045911789, 5), - expect.closeTo(0.01698593981564045, 5), - ], - }), - ]) - ); - }); - it("sentenceSimilarity", async () => { - expect( - await hf.sentenceSimilarity({ - model: "sentence-transformers/paraphrase-xlm-r-multilingual-v1", - inputs: { - source_sentence: "That is a happy person", - sentences: ["That is a happy dog", "That is a very happy person", "Today is a sunny day"], - }, - }) - ).toEqual([expect.any(Number), expect.any(Number), expect.any(Number)]); + + expect(res).toBeInstanceOf(Blob); }); - it("FeatureExtraction", async () => { - const response = await hf.featureExtraction({ - model: "sentence-transformers/distilbert-base-nli-mean-tokens", - inputs: "That is a happy person", + + it("imageToImage - FLUX Kontext Dev", async () => { + const res = await client.imageToImage({ + model: "black-forest-labs/flux-kontext-dev", + provider: "replicate", + inputs: new Blob([readTestFile("stormtrooper_depth.png")], { type: "image/png" }), + parameters: { + prompt: "Change the stormtrooper armor to golden color while keeping the same pose and helmet design", + }, }); - expect(response).toEqual(expect.arrayContaining([expect.any(Number)])); + expect(res).toBeInstanceOf(Blob); }); - it("FeatureExtraction - auto-compatibility sentence similarity", async () => { - const response = await hf.featureExtraction({ - model: "sentence-transformers/paraphrase-xlm-r-multilingual-v1", - inputs: "That is a happy person", - }); + }, + TIMEOUT + ); + describe.concurrent( + "SambaNova", + () => { + const client = new InferenceClient(env.HF_SAMBANOVA_KEY ?? "dummy"); - expect(response.length).toBeGreaterThan(10); - expect(response).toEqual(expect.arrayContaining([expect.any(Number)])); - }); - it("FeatureExtraction - facebook/bart-base", async () => { - const response = await hf.featureExtraction({ - model: "facebook/bart-base", - inputs: "That is a happy person", + it("chatCompletion", async () => { + const res = await client.chatCompletion({ + model: "meta-llama/Llama-3.1-8B-Instruct", + provider: "sambanova", + messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }], }); - // 1x7x768 - expect(response).toEqual([ - [ - expect.arrayContaining([expect.any(Number)]), - expect.arrayContaining([expect.any(Number)]), - expect.arrayContaining([expect.any(Number)]), - expect.arrayContaining([expect.any(Number)]), - expect.arrayContaining([expect.any(Number)]), - expect.arrayContaining([expect.any(Number)]), - expect.arrayContaining([expect.any(Number)]), - ], - ]); + if (res.choices && res.choices.length > 0) { + const completion = res.choices[0].message?.content; + expect(completion).toContain("two"); + } }); - it("FeatureExtraction - facebook/bart-base, list input", async () => { - const response = await hf.featureExtraction({ - model: "facebook/bart-base", - inputs: ["hello", "That is a happy person"], - }); - // Nx1xTx768 - expect(response).toEqual([ - [ - [ - expect.arrayContaining([expect.any(Number)]), - expect.arrayContaining([expect.any(Number)]), - expect.arrayContaining([expect.any(Number)]), - ], - ], - [ - [ - expect.arrayContaining([expect.any(Number)]), - expect.arrayContaining([expect.any(Number)]), - expect.arrayContaining([expect.any(Number)]), - expect.arrayContaining([expect.any(Number)]), - expect.arrayContaining([expect.any(Number)]), - expect.arrayContaining([expect.any(Number)]), - expect.arrayContaining([expect.any(Number)]), - ], - ], - ]); + it("chatCompletion stream", async () => { + const stream = client.chatCompletionStream({ + model: "meta-llama/Llama-3.1-8B-Instruct", + provider: "sambanova", + messages: [{ role: "user", content: "Complete the equation 1 + 1 = , just the answer" }], + }) as AsyncGenerator; + let out = ""; + for await (const chunk of stream) { + if (chunk.choices && chunk.choices.length > 0) { + out += chunk.choices[0].delta.content; + } + } + expect(out).toContain("2"); }); - it("automaticSpeechRecognition", async () => { - expect( - await hf.automaticSpeechRecognition({ - model: "facebook/wav2vec2-large-960h-lv60-self", - data: new Blob([readTestFile("sample1.flac")], { type: "audio/flac" }), - }) - ).toMatchObject({ - text: "GOING ALONG SLUSHY COUNTRY ROADS AND SPEAKING TO DAMP AUDIENCES IN DRAUGHTY SCHOOLROOMS DAY AFTER DAY FOR A FORTNIGHT HE'LL HAVE TO PUT IN AN APPEARANCE AT SOME PLACE OF WORSHIP ON SUNDAY MORNING AND HE CAN COME TO US IMMEDIATELY AFTERWARDS", + it("featureExtraction", async () => { + const res = await client.featureExtraction({ + model: "intfloat/e5-mistral-7b-instruct", + provider: "sambanova", + inputs: "Today is a sunny day and I will get some ice cream.", }); + expect(res).toBeInstanceOf(Array); + expect(res[0]).toBeInstanceOf(Array); }); - it("audioClassification", async () => { - expect( - await hf.audioClassification({ - model: "superb/hubert-large-superb-er", - data: new Blob([readTestFile("sample1.flac")], { type: "audio/flac" }), - }) - ).toEqual( - expect.arrayContaining([ - expect.objectContaining({ - score: expect.any(Number), - label: expect.any(String), - }), - ]) - ); - }); + }, + TIMEOUT + ); - it("audioToAudio", async () => { - expect( - await hf.audioToAudio({ - model: "speechbrain/sepformer-wham", - data: new Blob([readTestFile("sample1.flac")], { type: "audio/flac" }), - }) - ).toEqual( - expect.arrayContaining([ - expect.objectContaining({ - label: expect.any(String), - blob: expect.any(String), - "content-type": expect.any(String), - }), - ]) - ); - }); + describe.concurrent( + "Together", + () => { + const client = new InferenceClient(env.HF_TOGETHER_KEY ?? "dummy"); - it("textToSpeech", async () => { - expect( - await hf.textToSpeech({ - model: "espnet/kan-bayashi_ljspeech_vits", - inputs: "hello there!", - }) - ).toBeInstanceOf(Blob); + it("chatCompletion", async () => { + const res = await client.chatCompletion({ + model: "meta-llama/Llama-3.3-70B-Instruct", + provider: "together", + messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }], + }); + if (res.choices && res.choices.length > 0) { + const completion = res.choices[0].message?.content; + expect(completion).toContain("two"); + } }); - it("imageClassification", async () => { - expect( - await hf.imageClassification({ - data: new Blob([readTestFile("cheetah.png")], { type: "image/png" }), - model: "google/vit-base-patch16-224", - }) - ).toEqual( - expect.arrayContaining([ - expect.objectContaining({ - score: expect.any(Number), - label: expect.any(String), - }), - ]) - ); + it("chatCompletion stream", async () => { + const stream = client.chatCompletionStream({ + model: "meta-llama/Llama-3.3-70B-Instruct", + provider: "together", + messages: [{ role: "user", content: "Complete the equation 1 + 1 = , just the answer" }], + }) as AsyncGenerator; + let out = ""; + for await (const chunk of stream) { + if (chunk.choices && chunk.choices.length > 0) { + out += chunk.choices[0].delta.content; + } + } + expect(out).toContain("2"); }); - it("zeroShotImageClassification", async () => { - expect( - await hf.zeroShotImageClassification({ - inputs: { image: new Blob([readTestFile("cheetah.png")], { type: "image/png" }) }, - model: "openai/clip-vit-large-patch14-336", - parameters: { - candidate_labels: ["animal", "toy", "car"], - }, - }) - ).toEqual([ - { - label: "animal", - score: expect.any(Number), - }, - { - label: "car", - score: expect.any(Number), - }, - { - label: "toy", - score: expect.any(Number), - }, - ]); + it("textToImage", async () => { + const res = await client.textToImage({ + model: "stabilityai/stable-diffusion-xl-base-1.0", + provider: "together", + inputs: "award winning high resolution photo of a giant tortoise", + }); + expect(res).toBeInstanceOf(Blob); }); - it("objectDetection", async () => { - expect( - await hf.objectDetection({ - data: new Blob([readTestFile("cats.png")], { type: "image/png" }), - model: "facebook/detr-resnet-50", - }) - ).toEqual( - expect.arrayContaining([ - expect.objectContaining({ - score: expect.any(Number), - label: expect.any(String), - box: expect.objectContaining({ - xmin: expect.any(Number), - ymin: expect.any(Number), - xmax: expect.any(Number), - ymax: expect.any(Number), - }), - }), - ]) - ); - }); - it("imageSegmentation", async () => { - expect( - await hf.imageSegmentation({ - inputs: new Blob([readTestFile("cats.png")], { type: "image/png" }), - model: "facebook/detr-resnet-50-panoptic", - }) - ).toEqual( - expect.arrayContaining([ - expect.objectContaining({ - score: expect.any(Number), - label: expect.any(String), - mask: expect.any(String), - }), - ]) - ); + it("textGeneration", async () => { + const res = await client.textGeneration({ + model: "mistralai/Mixtral-8x7B-v0.1", + provider: "together", + inputs: "Paris is", + temperature: 0, + max_tokens: 10, + }); + expect(res).toMatchObject({ generated_text: " a city of love, and it’s also" }); }); - it("imageToImage", async () => { - const num_inference_steps = 25; + }, + TIMEOUT + ); - const res = await hf.imageToImage({ - inputs: new Blob([readTestFile("stormtrooper_depth.png")], { type: "image / png" }), - parameters: { - prompt: "elmo's lecture", - num_inference_steps, - }, - model: "lllyasviel/sd-controlnet-depth", + describe.concurrent( + "Nebius", + () => { + const client = new InferenceClient(env.HF_NEBIUS_KEY ?? "dummy"); + + HARDCODED_MODEL_INFERENCE_MAPPING.nebius = { + "meta-llama/Llama-3.1-8B-Instruct": { + provider: "nebius", + hfModelId: "meta-llama/Llama-3.1-8B-Instruct", + providerId: "meta-llama/Meta-Llama-3.1-8B-Instruct", + status: "live", + task: "conversational", + }, + "meta-llama/Llama-3.1-70B-Instruct": { + provider: "nebius", + hfModelId: "meta-llama/Llama-3.1-8B-Instruct", + providerId: "meta-llama/Meta-Llama-3.1-70B-Instruct", + status: "live", + task: "conversational", + }, + "black-forest-labs/FLUX.1-schnell": { + provider: "nebius", + hfModelId: "meta-llama/Llama-3.1-8B-Instruct", + providerId: "black-forest-labs/flux-schnell", + status: "live", + task: "text-to-image", + }, + "BAAI/bge-multilingual-gemma2": { + provider: "nebius", + providerId: "BAAI/bge-multilingual-gemma2", + hfModelId: "BAAI/bge-multilingual-gemma2", + status: "live", + task: "feature-extraction", + }, + "mistralai/Devstral-Small-2505": { + provider: "nebius", + providerId: "mistralai/Devstral-Small-2505", + hfModelId: "mistralai/Devstral-Small-2505", + status: "live", + task: "text-generation", + }, + }; + + it("chatCompletion", async () => { + const res = await client.chatCompletion({ + model: "meta-llama/Llama-3.1-8B-Instruct", + provider: "nebius", + messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }], }); - expect(res).toBeInstanceOf(Blob); + if (res.choices && res.choices.length > 0) { + const completion = res.choices[0].message?.content; + expect(completion).toMatch(/(two|2)/i); + } }); - it("imageToImage blob data", async () => { - const res = await hf.imageToImage({ - inputs: new Blob([readTestFile("bird_canny.png")], { type: "image / png" }), - model: "lllyasviel/sd-controlnet-canny", - }); - expect(res).toBeInstanceOf(Blob); + + it("chatCompletion stream", async () => { + const stream = client.chatCompletionStream({ + model: "meta-llama/Llama-3.1-70B-Instruct", + provider: "nebius", + messages: [{ role: "user", content: "Complete the equation 1 + 1 = , just the answer" }], + }) as AsyncGenerator; + let out = ""; + for await (const chunk of stream) { + if (chunk.choices && chunk.choices.length > 0) { + out += chunk.choices[0].delta.content; + } + } + expect(out).toMatch(/(two|2)/i); }); + it("textToImage", async () => { - const res = await hf.textToImage({ - inputs: - "award winning high resolution photo of a giant tortoise/((ladybird)) hybrid, [trending on artstation]", - model: "stabilityai/stable-diffusion-2", + const res = await client.textToImage({ + model: "black-forest-labs/FLUX.1-schnell", + provider: "nebius", + inputs: "award winning high resolution photo of a giant tortoise", }); expect(res).toBeInstanceOf(Blob); }); - it("textToImage with parameters", async () => { - const width = 512; - const height = 128; - const num_inference_steps = 10; - - const res = await hf.textToImage({ - inputs: - "award winning high resolution photo of a giant tortoise/((ladybird)) hybrid, [trending on artstation]", - model: "stabilityai/stable-diffusion-2", - parameters: { - negative_prompt: "blurry", - width, - height, - num_inference_steps, - }, + it("featureExtraction", async () => { + const res = await client.featureExtraction({ + model: "BAAI/bge-multilingual-gemma2", + inputs: "That is a happy person", }); - expect(res).toBeInstanceOf(Blob); + + expect(res).toBeInstanceOf(Array); + expect(res[0]).toEqual(expect.arrayContaining([expect.any(Number)])); }); - it("textToImage with json output", async () => { - const res = await hf.textToImage({ - inputs: "a giant tortoise", - model: "stabilityai/stable-diffusion-2", - outputType: "json", + + it("textGeneration", async () => { + const res = await client.textGeneration({ + model: "mistralai/Devstral-Small-2505", + provider: "nebius", + inputs: "Once upon a time,", + temperature: 0, + max_tokens: 19, }); expect(res).toMatchObject({ - output: expect.any(String), - }); - }); - it("imageToText", async () => { - expect( - await hf.imageToText({ - data: new Blob([readTestFile("cheetah.png")], { type: "image/png" }), - model: "nlpconnect/vit-gpt2-image-captioning", - }) - ).toEqual({ - generated_text: "a large brown and white giraffe standing in a field ", + generated_text: " in a land far, far away, there lived a king who was very fond of flowers.", }); }); + }, + TIMEOUT + ); - /// Skipping because the function is deprecated - it.skip("request - openai-community/gpt2", async () => { - expect( - await hf.request({ - model: "openai-community/gpt2", - inputs: "one plus two equals", - }) - ).toMatchObject([ - { - generated_text: expect.any(String), - }, - ]); - }); + describe.concurrent( + "Scaleway", + () => { + const client = new InferenceClient(env.HF_SCALEWAY_KEY ?? "dummy"); - // Skipped at the moment because takes forever - it.skip("tabularRegression", async () => { - expect( - await hf.tabularRegression({ - model: "scikit-learn/Fish-Weight", - inputs: { - data: { - Height: ["11.52", "12.48", "12.3778"], - Length1: ["23.2", "24", "23.9"], - Length2: ["25.4", "26.3", "26.5"], - Length3: ["30", "31.2", "31.1"], - Species: ["Bream", "Bream", "Bream"], - Width: ["4.02", "4.3056", "4.6961"], - }, - }, - }) - ).toMatchObject([270.5473526976245, 313.6843425638086, 328.3727133404402]); - }); - - // Skipped at the moment because takes forever - it.skip("tabularClassification", async () => { - expect( - await hf.tabularClassification({ - model: "vvmnnnkv/wine-quality", - inputs: { - data: { - fixed_acidity: ["7.4", "7.8", "10.3"], - volatile_acidity: ["0.7", "0.88", "0.32"], - citric_acid: ["0", "0", "0.45"], - residual_sugar: ["1.9", "2.6", "6.4"], - chlorides: ["0.076", "0.098", "0.073"], - free_sulfur_dioxide: ["11", "25", "5"], - total_sulfur_dioxide: ["34", "67", "13"], - density: ["0.9978", "0.9968", "0.9976"], - pH: ["3.51", "3.2", "3.23"], - sulphates: ["0.56", "0.68", "0.82"], - alcohol: ["9.4", "9.8", "12.6"], - }, - }, - }) - ).toMatchObject([5, 5, 7]); - }); - - it("endpoint - makes request to specified endpoint", async () => { - const ep = hf.endpoint("https://router.huggingface.co/hf-inference/models/openai-community/gpt2"); - const { generated_text } = await ep.textGeneration({ - inputs: "one plus one is equal to", - parameters: { - max_new_tokens: 1, - }, - }); - assert.include(generated_text, "two"); - }); - - it("endpoint - makes request to specified endpoint - alternative syntax", async () => { - const epClient = new InferenceClient(env.HF_TOKEN, { - endpointUrl: "https://router.huggingface.co/hf-inference/models/openai-community/gpt2", - }); - const { generated_text } = await epClient.textGeneration({ - inputs: "one plus one is equal to", - parameters: { - max_new_tokens: 1, - }, - }); - assert.include(generated_text, "two"); - }); + HARDCODED_MODEL_INFERENCE_MAPPING.scaleway = { + "meta-llama/Llama-3.1-8B-Instruct": { + provider: "scaleway", + hfModelId: "meta-llama/Llama-3.1-8B-Instruct", + providerId: "llama-3.1-8b-instruct", + status: "live", + task: "conversational", + }, + "BAAI/bge-multilingual-gemma2": { + provider: "scaleway", + hfModelId: "BAAI/bge-multilingual-gemma2", + providerId: "bge-multilingual-gemma2", + task: "feature-extraction", + status: "live", + }, + "google/gemma-3-27b-it": { + provider: "scaleway", + hfModelId: "google/gemma-3-27b-it", + providerId: "gemma-3-27b-it", + task: "conversational", + status: "live", + }, + }; - it("chatCompletion modelId - OpenAI Specs", async () => { - const res = await hf.chatCompletion({ - model: "mistralai/Mistral-7B-Instruct-v0.2", - messages: [{ role: "user", content: "Complete the this sentence with words one plus one is equal " }], - max_tokens: 500, - temperature: 0.1, - seed: 0, + it("chatCompletion", async () => { + const res = await client.chatCompletion({ + model: "meta-llama/Llama-3.1-8B-Instruct", + provider: "scaleway", + messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }], + tool_choice: "none", }); if (res.choices && res.choices.length > 0) { const completion = res.choices[0].message?.content; - expect(completion).toContain("to two"); + expect(completion).toMatch(/(to )?(two|2)/i); } }); - it("chatCompletionStream modelId - OpenAI Specs", async () => { - const stream = hf.chatCompletionStream({ - model: "mistralai/Mistral-7B-Instruct-v0.2", - messages: [{ role: "user", content: "Complete the equation 1+1= ,just the answer" }], - max_tokens: 500, - temperature: 0.1, - seed: 0, - }); + it("chatCompletion stream", async () => { + const stream = client.chatCompletionStream({ + model: "meta-llama/Llama-3.1-8B-Instruct", + provider: "scaleway", + messages: [{ role: "system", content: "Complete the equation 1 + 1 = , just the answer" }], + }) as AsyncGenerator; let out = ""; for await (const chunk of stream) { if (chunk.choices && chunk.choices.length > 0) { out += chunk.choices[0].delta.content; } } - expect(out).toContain("2"); + expect(out).toMatch(/(two|2)/i); }); - it.skip("chatCompletionStream modelId Fail - OpenAI Specs", async () => { - expect( - hf - .chatCompletionStream({ - model: "google/gemma-2b", - messages: [{ role: "user", content: "Complete the equation 1+1= ,just the answer" }], - max_tokens: 500, - temperature: 0.1, - seed: 0, - }) - .next() - ).rejects.toThrowError( - "Server google/gemma-2b does not seem to support chat completion. Error: Template error: template not found" - ); + it("chatCompletion multimodal", async () => { + const res = await client.chatCompletion({ + model: "google/gemma-3-27b-it", + provider: "scaleway", + messages: [ + { + role: "user", + content: [ + { + type: "image_url", + image_url: { + url: "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg", + }, + }, + { type: "text", text: "What is this?" }, + ], + }, + ], + }); + expect(res.choices).toBeDefined(); + expect(res.choices?.length).toBeGreaterThan(0); + expect(res.choices?.[0].message?.content).toContain("Statue of Liberty"); }); - it("chatCompletion - OpenAI Specs", async () => { - const ep = hf.endpoint("https://router.huggingface.co/hf-inference/models/mistralai/Mistral-7B-Instruct-v0.2"); - const res = await ep.chatCompletion({ - model: "tgi", - messages: [{ role: "user", content: "Complete the this sentence with words one plus one is equal " }], - max_tokens: 500, - temperature: 0.1, - seed: 0, + it("textGeneration", async () => { + const res = await client.textGeneration({ + model: "meta-llama/Llama-3.1-8B-Instruct", + provider: "scaleway", + inputs: "Once upon a time,", + temperature: 0, + max_tokens: 19, + }); + + expect(res).toMatchObject({ + generated_text: + " in a small village nestled in the rolling hills of the countryside, there lived a young girl named", + }); + }); + + it("featureExtraction", async () => { + const res = await client.featureExtraction({ + model: "BAAI/bge-multilingual-gemma2", + provider: "scaleway", + inputs: "That is a happy person", + }); + + expect(res).toBeInstanceOf(Array); + expect(res[0]).toEqual(expect.arrayContaining([expect.any(Number)])); + }); + }, + TIMEOUT + ); + + describe.concurrent("3rd party providers", () => { + it("chatCompletion - fails with unsupported model", async () => { + expect( + chatCompletion({ + model: "black-forest-labs/Flux.1-dev", + provider: "together", + messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }], + accessToken: env.HF_TOGETHER_KEY ?? "dummy", + }) + ).rejects.toThrowError( + "Model black-forest-labs/Flux.1-dev is not supported for task conversational and provider together" + ); + }); + }); + + describe.concurrent( + "Fireworks", + () => { + const client = new InferenceClient(env.HF_FIREWORKS_KEY ?? "dummy"); + + HARDCODED_MODEL_INFERENCE_MAPPING["fireworks-ai"] = { + "deepseek-ai/DeepSeek-R1": { + provider: "fireworks-ai", + hfModelId: "deepseek-ai/DeepSeek-R1", + providerId: "accounts/fireworks/models/deepseek-r1", + status: "live", + task: "conversational", + }, + }; + + it("chatCompletion", async () => { + const res = await client.chatCompletion({ + model: "deepseek-ai/DeepSeek-R1", + provider: "fireworks-ai", + messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }], }); if (res.choices && res.choices.length > 0) { const completion = res.choices[0].message?.content; - expect(completion).toContain("to two"); + expect(completion).toContain("two"); } }); - it("chatCompletionStream - OpenAI Specs", async () => { - const ep = hf.endpoint("https://router.huggingface.co/hf-inference/models/mistralai/Mistral-7B-Instruct-v0.2"); - const stream = ep.chatCompletionStream({ - model: "tgi", - messages: [{ role: "user", content: "Complete the equation 1+1= ,just the answer" }], - max_tokens: 500, - temperature: 0.1, - seed: 0, - }); - let out = ""; + + it("chatCompletion stream", async () => { + const stream = client.chatCompletionStream({ + model: "deepseek-ai/DeepSeek-R1", + provider: "fireworks-ai", + messages: [{ role: "user", content: "Say this is a test" }], + stream: true, + }) as AsyncGenerator; + + let fullResponse = ""; for await (const chunk of stream) { if (chunk.choices && chunk.choices.length > 0) { - out += chunk.choices[0].delta.content; + const content = chunk.choices[0].delta?.content; + if (content) { + fullResponse += content; + } } } - expect(out).toContain("2"); - }); - it("custom mistral - OpenAI Specs", async () => { - const MISTRAL_KEY = env.MISTRAL_KEY; - const hf = new InferenceClient(MISTRAL_KEY); - const ep = hf.endpoint("https://api.mistral.ai"); - const stream = ep.chatCompletionStream({ - model: "mistral-tiny", - messages: [{ role: "user", content: "Complete the equation one + one = , just the answer" }], - }) as AsyncGenerator; - let out = ""; - for await (const chunk of stream) { - if (chunk.choices && chunk.choices.length > 0) { - out += chunk.choices[0].delta.content; - } - } - expect(out).toContain("The answer to one + one is two."); - }); - it("custom openai - OpenAI Specs", async () => { - const OPENAI_KEY = env.OPENAI_KEY; - const hf = new InferenceClient(OPENAI_KEY); - const stream = hf.chatCompletionStream({ - provider: "openai", - model: "openai/gpt-3.5-turbo", - messages: [{ role: "user", content: "Complete the equation one + one =" }], - }) as AsyncGenerator; - let out = ""; - for await (const chunk of stream) { - if (chunk.choices && chunk.choices.length > 0) { - out += chunk.choices[0].delta.content; - } - } - expect(out).toContain("two"); - }); - it("OpenAI client side routing - model should have provider as prefix", async () => { - await expect( - new InferenceClient("dummy_token").chatCompletion({ - model: "gpt-3.5-turbo", // must be "openai/gpt-3.5-turbo" - provider: "openai", - messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }], - }) - ).rejects.toThrowError(`Models from openai must be prefixed by "openai/". Got "gpt-3.5-turbo".`); + + // Verify we got a meaningful response + expect(fullResponse).toBeTruthy(); + expect(fullResponse.length).toBeGreaterThan(0); }); }, TIMEOUT ); - /** - * Compatibility with third-party Inference Providers - */ describe.concurrent( - "Fal AI", + "Hyperbolic", () => { - const client = new InferenceClient(env.HF_FAL_KEY ?? "dummy"); - - HARDCODED_MODEL_INFERENCE_MAPPING["fal-ai"] = { - "openfree/flux-chatgpt-ghibli-lora": { - provider: "fal-ai", - hfModelId: "openfree/flux-chatgpt-ghibli-lora", - providerId: "fal-ai/flux-lora", + HARDCODED_MODEL_INFERENCE_MAPPING["hyperbolic"] = { + "meta-llama/Llama-3.2-3B-Instruct": { + provider: "hyperbolic", + hfModelId: "meta-llama/Llama-3.2-3B-Instruct", + providerId: "meta-llama/Llama-3.2-3B-Instruct", status: "live", - task: "text-to-image", - adapter: "lora", - adapterWeightsPath: "flux-chatgpt-ghibli-lora.safetensors", + task: "conversational", }, - "nerijs/pixel-art-xl": { - provider: "fal-ai", - hfModelId: "nerijs/pixel-art-xl", - providerId: "fal-ai/lora", + "meta-llama/Llama-3.3-70B-Instruct": { + provider: "hyperbolic", + hfModelId: "meta-llama/Llama-3.3-70B-Instruct", + providerId: "meta-llama/Llama-3.3-70B-Instruct", status: "live", - task: "text-to-image", - adapter: "lora", - adapterWeightsPath: "pixel-art-xl.safetensors", + task: "conversational", }, - }; - - it(`textToImage - black-forest-labs/FLUX.1-schnell`, async () => { - const res = await client.textToImage({ - model: "black-forest-labs/FLUX.1-schnell", - provider: "fal-ai", - inputs: - "Extreme close-up of a single tiger eye, direct frontal view. Detailed iris and pupil. Sharp focus on eye texture and color. Natural lighting to capture authentic eye shine and depth.", - }); - expect(res).toBeInstanceOf(Blob); - }); - - /// Skipped: we need a way to pass the base model ID - it(`textToImage - SD LoRAs`, async () => { - const res = await client.textToImage({ - model: "nerijs/pixel-art-xl", - provider: "fal-ai", - inputs: "pixel, a cute corgi", - parameters: { - negative_prompt: "3d render, realistic", - }, - }); - expect(res).toBeInstanceOf(Blob); - }); - - it(`textToImage - Flux LoRAs`, async () => { - const res = await client.textToImage({ - model: "openfree/flux-chatgpt-ghibli-lora", - provider: "fal-ai", - inputs: - "Ghibli style sky whale transport ship, its metallic skin adorned with traditional Japanese patterns, gliding through cotton candy clouds at sunrise. Small floating gardens hang from its sides, where workers in futuristic kimonos tend to glowing plants. Rainbow auroras shimmer in the background. [trigger]", - }); - expect(res).toBeInstanceOf(Blob); - }); - - it(`automaticSpeechRecognition - openai/whisper-large-v3`, async () => { - const res = await client.automaticSpeechRecognition({ - model: "openai/whisper-large-v3", - provider: "fal-ai", - data: new Blob([readTestFile("sample2.wav")], { type: "audio/x-wav" }), - }); - expect(res).toMatchObject({ - text: " he has grave doubts whether sir frederick leighton's work is really greek after all and can discover in it but little of rocky ithaca", - }); - }); - it("imageToVideo - fal-ai", async () => { - const res = await client.imageToVideo({ - model: "fal-ai/ltxv-13b-098-distilled/image-to-video", - provider: "fal-ai", - inputs: new Blob([readTestFile("cats.png")], { type: "image/png" }), - parameters: { - prompt: "The cats are jumping around in a playful manner", - }, - }); - expect(res).toBeInstanceOf(Blob); - }); - }, - TIMEOUT - ); - - describe.concurrent( - "Featherless", - () => { - HARDCODED_MODEL_INFERENCE_MAPPING["featherless-ai"] = { - "meta-llama/Llama-3.1-8B": { - provider: "featherless-ai", - providerId: "meta-llama/Meta-Llama-3.1-8B", - hfModelId: "meta-llama/Llama-3.1-8B", - task: "text-generation", + "stabilityai/stable-diffusion-2": { + provider: "hyperbolic", + hfModelId: "stabilityai/stable-diffusion-2", + providerId: "SD2", status: "live", + task: "text-to-image", }, - "meta-llama/Llama-3.1-8B-Instruct": { - provider: "featherless-ai", - providerId: "meta-llama/Meta-Llama-3.1-8B-Instruct", - hfModelId: "meta-llama/Llama-3.1-8B-Instruct", - task: "text-generation", + "meta-llama/Llama-3.1-405B-FP8": { + provider: "hyperbolic", + hfModelId: "meta-llama/Llama-3.1-405B-FP8", + providerId: "meta-llama/Llama-3.1-405B-FP8", status: "live", + task: "conversational", }, }; - it("chatCompletion", async () => { + it("chatCompletion - hyperbolic", async () => { const res = await chatCompletion({ - accessToken: env.HF_FEATHERLESS_KEY ?? "dummy", - model: "meta-llama/Llama-3.1-8B-Instruct", - provider: "featherless-ai", + accessToken: env.HF_HYPERBOLIC_KEY ?? "dummy", + model: "meta-llama/Llama-3.2-3B-Instruct", + provider: "hyperbolic", messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }], temperature: 0.1, }); @@ -1661,9 +1750,9 @@ describe("InferenceClient", () => { it("chatCompletion stream", async () => { const stream = chatCompletionStream({ - accessToken: env.HF_FEATHERLESS_KEY ?? "dummy", - model: "meta-llama/Llama-3.1-8B-Instruct", - provider: "featherless-ai", + accessToken: env.HF_HYPERBOLIC_KEY ?? "dummy", + model: "meta-llama/Llama-3.3-70B-Instruct", + provider: "hyperbolic", messages: [{ role: "user", content: "Complete the equation 1 + 1 = , just the answer" }], }) as AsyncGenerator; let out = ""; @@ -1675,156 +1764,175 @@ describe("InferenceClient", () => { expect(out).toContain("2"); }); + it("textToImage", async () => { + const res = await textToImage({ + accessToken: env.HF_HYPERBOLIC_KEY ?? "dummy", + model: "stabilityai/stable-diffusion-2", + provider: "hyperbolic", + inputs: "award winning high resolution photo of a giant tortoise", + parameters: { + height: 128, + width: 128, + }, + } satisfies TextToImageArgs); + expect(res).toBeInstanceOf(Blob); + }); + it("textGeneration", async () => { const res = await textGeneration({ - accessToken: env.HF_FEATHERLESS_KEY ?? "dummy", - model: "meta-llama/Llama-3.1-8B", - provider: "featherless-ai", - inputs: "Paris is a city of ", + accessToken: env.HF_HYPERBOLIC_KEY ?? "dummy", + model: "meta-llama/Llama-3.1-405B", + provider: "hyperbolic", + inputs: "Paris is", parameters: { temperature: 0, top_p: 0.01, - max_tokens: 10, + max_new_tokens: 10, }, }); - expect(res).toMatchObject({ generated_text: "2.2 million people, and it is the" }); + expect(res).toMatchObject({ generated_text: "...the capital and most populous city of France," }); }); }, TIMEOUT ); describe.concurrent( - "Replicate", + "Novita", () => { - const client = new InferenceClient(env.HF_REPLICATE_KEY ?? "dummy"); - - it("textToImage canonical - black-forest-labs/FLUX.1-schnell", async () => { - const res = await client.textToImage({ - model: "black-forest-labs/FLUX.1-schnell", - provider: "replicate", - inputs: "black forest gateau cake spelling out the words FLUX SCHNELL, tasty, food photography, dynamic shot", - }); - expect(res).toBeInstanceOf(Blob); - }); + const client = new InferenceClient(env.HF_NOVITA_KEY ?? "dummy"); - it("textToImage canonical - black-forest-labs/FLUX.1-dev", async () => { - const res = await client.textToImage({ - model: "black-forest-labs/FLUX.1-dev", - provider: "replicate", - inputs: - "A tiny laboratory deep in the Black Forest where squirrels in lab coats experiment with mixing chocolate and pine cones", - }); - expect(res).toBeInstanceOf(Blob); - }); + HARDCODED_MODEL_INFERENCE_MAPPING["novita"] = { + "meta-llama/llama-3.1-8b-instruct": { + provider: "novita", + hfModelId: "meta-llama/llama-3.1-8b-instruct", + providerId: "meta-llama/llama-3.1-8b-instruct", + status: "live", + task: "conversational", + }, + "deepseek/deepseek-r1-distill-qwen-14b": { + provider: "novita", + hfModelId: "deepseek/deepseek-r1-distill-qwen-14b", + providerId: "deepseek/deepseek-r1-distill-qwen-14b", + status: "live", + task: "conversational", + }, + }; - // Runs black-forest-labs/flux-dev-lora under the hood - // with fofr/flux-80s-cyberpunk as the LoRA weights - it("textToImage - all Flux LoRAs", async () => { - const res = await client.textToImage({ - model: "fofr/flux-80s-cyberpunk", - provider: "replicate", - inputs: "style of 80s cyberpunk, a portrait photo", + it("chatCompletion", async () => { + const res = await client.chatCompletion({ + model: "meta-llama/llama-3.1-8b-instruct", + provider: "novita", + messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }], }); - expect(res).toBeInstanceOf(Blob); + if (res.choices && res.choices.length > 0) { + const completion = res.choices[0].message?.content; + expect(completion).toContain("two"); + } }); - it("textToImage canonical - stabilityai/stable-diffusion-3.5-large-turbo", async () => { - const res = await client.textToImage({ - model: "stabilityai/stable-diffusion-3.5-large-turbo", - provider: "replicate", - inputs: "A confused rubber duck wearing a tiny wizard hat trying to cast spells with a banana wand", - }); - expect(res).toBeInstanceOf(Blob); - }); - - it("textToImage versioned - ByteDance/SDXL-Lightning", async () => { - const res = await client.textToImage({ - model: "ByteDance/SDXL-Lightning", - provider: "replicate", - inputs: "A grumpy storm cloud wearing sunglasses and throwing tiny lightning bolts like confetti", - }); - expect(res).toBeInstanceOf(Blob); - }); - - it("textToImage versioned - ByteDance/Hyper-SD", async () => { - const res = await client.textToImage({ - model: "ByteDance/Hyper-SD", - provider: "replicate", - inputs: "A group of dancing bytes wearing tiny party hats doing the macarena in cyberspace", - }); - expect(res).toBeInstanceOf(Blob); - }); - - it("textToImage versioned - playgroundai/playground-v2.5-1024px-aesthetic", async () => { - const res = await client.textToImage({ - model: "playgroundai/playground-v2.5-1024px-aesthetic", - provider: "replicate", - inputs: "A playground where slides turn into rainbows and swings launch kids into cotton candy clouds", - }); - expect(res).toBeInstanceOf(Blob); - }); - - it("textToImage versioned - stabilityai/stable-diffusion-xl-base-1.0", async () => { - const res = await client.textToImage({ - model: "stabilityai/stable-diffusion-xl-base-1.0", - provider: "replicate", - inputs: "An octopus juggling watermelons underwater while wearing scuba gear", - }); - expect(res).toBeInstanceOf(Blob); - }); - - it.skip("textToSpeech versioned", async () => { - const res = await client.textToSpeech({ - model: "SWivid/F5-TTS", - provider: "replicate", - inputs: "Hello, how are you?", - }); - expect(res).toBeInstanceOf(Blob); - }); + it("chatCompletion stream", async () => { + const stream = client.chatCompletionStream({ + model: "deepseek/deepseek-r1-distill-qwen-14b", + provider: "novita", + messages: [{ role: "user", content: "Say this is a test" }], + stream: true, + }) as AsyncGenerator; - it.skip("textToSpeech OuteTTS - usually Cold", async () => { - const res = await client.textToSpeech({ - model: "OuteAI/OuteTTS-0.3-500M", - provider: "replicate", - inputs: "OuteTTS is a frontier TTS model for its size of 1 Billion parameters", - }); + let fullResponse = ""; + for await (const chunk of stream) { + if (chunk.choices && chunk.choices.length > 0) { + const content = chunk.choices[0].delta?.content; + if (content) { + fullResponse += content; + } + } + } - expect(res).toBeInstanceOf(Blob); + // Verify we got a meaningful response + expect(fullResponse).toBeTruthy(); + expect(fullResponse.length).toBeGreaterThan(0); }); + }, + TIMEOUT + ); + describe.concurrent( + "Black Forest Labs", + () => { + HARDCODED_MODEL_INFERENCE_MAPPING["black-forest-labs"] = { + "black-forest-labs/FLUX.1-dev": { + provider: "black-forest-labs", + hfModelId: "black-forest-labs/FLUX.1-dev", + providerId: "flux-dev", + status: "live", + task: "text-to-image", + }, + // "black-forest-labs/FLUX.1-schnell": "flux-pro", + }; - it("textToSpeech Kokoro", async () => { - const res = await client.textToSpeech({ - model: "hexgrad/Kokoro-82M", - provider: "replicate", - inputs: "Kokoro is a frontier TTS model for its size of 1 Billion parameters", + it("textToImage", async () => { + const res = await textToImage({ + model: "black-forest-labs/FLUX.1-dev", + provider: "black-forest-labs", + accessToken: env.HF_BLACK_FOREST_LABS_KEY ?? "dummy", + inputs: "A raccoon driving a truck", + parameters: { + height: 256, + width: 256, + num_inference_steps: 4, + seed: 8817, + }, }); - expect(res).toBeInstanceOf(Blob); }); - it("imageToImage - FLUX Kontext Dev", async () => { - const res = await client.imageToImage({ - model: "black-forest-labs/flux-kontext-dev", - provider: "replicate", - inputs: new Blob([readTestFile("stormtrooper_depth.png")], { type: "image/png" }), - parameters: { - prompt: "Change the stormtrooper armor to golden color while keeping the same pose and helmet design", + it("textToImage URL", async () => { + const res = await textToImage( + { + model: "black-forest-labs/FLUX.1-dev", + provider: "black-forest-labs", + accessToken: env.HF_BLACK_FOREST_LABS_KEY ?? "dummy", + inputs: "A raccoon driving a truck", + parameters: { + height: 256, + width: 256, + num_inference_steps: 4, + seed: 8817, + }, }, - }); - expect(res).toBeInstanceOf(Blob); + { outputType: "url" } + ); + expect(res).toBeTypeOf("string"); + expect(isUrl(res)).toBeTruthy(); }); }, TIMEOUT ); describe.concurrent( - "SambaNova", + "Cohere", () => { - const client = new InferenceClient(env.HF_SAMBANOVA_KEY ?? "dummy"); + const client = new InferenceClient(env.HF_COHERE_KEY ?? "dummy"); + + HARDCODED_MODEL_INFERENCE_MAPPING["cohere"] = { + "CohereForAI/c4ai-command-r7b-12-2024": { + provider: "cohere", + hfModelId: "CohereForAI/c4ai-command-r7b-12-2024", + providerId: "command-r7b-12-2024", + status: "live", + task: "conversational", + }, + "CohereForAI/aya-expanse-8b": { + provider: "cohere", + hfModelId: "CohereForAI/aya-expanse-8b", + providerId: "c4ai-aya-expanse-8b", + status: "live", + task: "conversational", + }, + }; it("chatCompletion", async () => { const res = await client.chatCompletion({ - model: "meta-llama/Llama-3.1-8B-Instruct", - provider: "sambanova", + model: "CohereForAI/c4ai-command-r7b-12-2024", + provider: "cohere", messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }], }); if (res.choices && res.choices.length > 0) { @@ -1832,42 +1940,51 @@ describe("InferenceClient", () => { expect(completion).toContain("two"); } }); + it("chatCompletion stream", async () => { const stream = client.chatCompletionStream({ - model: "meta-llama/Llama-3.1-8B-Instruct", - provider: "sambanova", - messages: [{ role: "user", content: "Complete the equation 1 + 1 = , just the answer" }], + model: "CohereForAI/c4ai-command-r7b-12-2024", + provider: "cohere", + messages: [{ role: "user", content: "Say 'this is a test'" }], + stream: true, }) as AsyncGenerator; - let out = ""; + + let fullResponse = ""; for await (const chunk of stream) { if (chunk.choices && chunk.choices.length > 0) { - out += chunk.choices[0].delta.content; + const content = chunk.choices[0].delta?.content; + if (content) { + fullResponse += content; + } } } - expect(out).toContain("2"); - }); - it("featureExtraction", async () => { - const res = await client.featureExtraction({ - model: "intfloat/e5-mistral-7b-instruct", - provider: "sambanova", - inputs: "Today is a sunny day and I will get some ice cream.", - }); - expect(res).toBeInstanceOf(Array); - expect(res[0]).toBeInstanceOf(Array); + + // Verify we got a meaningful response + expect(fullResponse).toBeTruthy(); + expect(fullResponse.length).toBeGreaterThan(0); }); }, TIMEOUT ); - describe.concurrent( - "Together", + "Cerebras", () => { - const client = new InferenceClient(env.HF_TOGETHER_KEY ?? "dummy"); + const client = new InferenceClient(env.HF_CEREBRAS_KEY ?? "dummy"); + + HARDCODED_MODEL_INFERENCE_MAPPING["cerebras"] = { + "meta-llama/llama-3.1-8b-instruct": { + provider: "cerebras", + hfModelId: "meta-llama/llama-3.1-8b-instruct", + providerId: "llama3.1-8b", + status: "live", + task: "conversational", + }, + }; it("chatCompletion", async () => { const res = await client.chatCompletion({ - model: "meta-llama/Llama-3.3-70B-Instruct", - provider: "together", + model: "meta-llama/llama-3.1-8b-instruct", + provider: "cerebras", messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }], }); if (res.choices && res.choices.length > 0) { @@ -1878,282 +1995,154 @@ describe("InferenceClient", () => { it("chatCompletion stream", async () => { const stream = client.chatCompletionStream({ - model: "meta-llama/Llama-3.3-70B-Instruct", - provider: "together", - messages: [{ role: "user", content: "Complete the equation 1 + 1 = , just the answer" }], + model: "meta-llama/llama-3.1-8b-instruct", + provider: "cerebras", + messages: [{ role: "user", content: "Say 'this is a test'" }], + stream: true, }) as AsyncGenerator; - let out = ""; + + let fullResponse = ""; for await (const chunk of stream) { if (chunk.choices && chunk.choices.length > 0) { - out += chunk.choices[0].delta.content; + const content = chunk.choices[0].delta?.content; + if (content) { + fullResponse += content; + } } } - expect(out).toContain("2"); - }); - - it("textToImage", async () => { - const res = await client.textToImage({ - model: "stabilityai/stable-diffusion-xl-base-1.0", - provider: "together", - inputs: "award winning high resolution photo of a giant tortoise", - }); - expect(res).toBeInstanceOf(Blob); - }); - it("textGeneration", async () => { - const res = await client.textGeneration({ - model: "mistralai/Mixtral-8x7B-v0.1", - provider: "together", - inputs: "Paris is", - temperature: 0, - max_tokens: 10, - }); - expect(res).toMatchObject({ generated_text: " a city of love, and it’s also" }); + // Verify we got a meaningful response + expect(fullResponse).toBeTruthy(); + expect(fullResponse.length).toBeGreaterThan(0); }); }, TIMEOUT ); - describe.concurrent( - "Nebius", + "Nscale", () => { - const client = new InferenceClient(env.HF_NEBIUS_KEY ?? "dummy"); + const client = new InferenceClient(env.HF_NSCALE_KEY ?? "dummy"); - HARDCODED_MODEL_INFERENCE_MAPPING.nebius = { + HARDCODED_MODEL_INFERENCE_MAPPING["nscale"] = { "meta-llama/Llama-3.1-8B-Instruct": { - provider: "nebius", - hfModelId: "meta-llama/Llama-3.1-8B-Instruct", - providerId: "meta-llama/Meta-Llama-3.1-8B-Instruct", - status: "live", - task: "conversational", - }, - "meta-llama/Llama-3.1-70B-Instruct": { - provider: "nebius", + provider: "nscale", hfModelId: "meta-llama/Llama-3.1-8B-Instruct", - providerId: "meta-llama/Meta-Llama-3.1-70B-Instruct", + providerId: "nscale", status: "live", task: "conversational", }, "black-forest-labs/FLUX.1-schnell": { - provider: "nebius", - hfModelId: "meta-llama/Llama-3.1-8B-Instruct", - providerId: "black-forest-labs/flux-schnell", + provider: "nscale", + hfModelId: "black-forest-labs/FLUX.1-schnell", + providerId: "flux-schnell", status: "live", task: "text-to-image", }, - "BAAI/bge-multilingual-gemma2": { - provider: "nebius", - providerId: "BAAI/bge-multilingual-gemma2", - hfModelId: "BAAI/bge-multilingual-gemma2", - status: "live", - task: "feature-extraction", - }, - "mistralai/Devstral-Small-2505": { - provider: "nebius", - providerId: "mistralai/Devstral-Small-2505", - hfModelId: "mistralai/Devstral-Small-2505", - status: "live", - task: "text-generation", - }, }; it("chatCompletion", async () => { const res = await client.chatCompletion({ model: "meta-llama/Llama-3.1-8B-Instruct", - provider: "nebius", + provider: "nscale", messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }], }); if (res.choices && res.choices.length > 0) { const completion = res.choices[0].message?.content; - expect(completion).toMatch(/(two|2)/i); + expect(completion).toContain("two"); } }); - it("chatCompletion stream", async () => { const stream = client.chatCompletionStream({ - model: "meta-llama/Llama-3.1-70B-Instruct", - provider: "nebius", - messages: [{ role: "user", content: "Complete the equation 1 + 1 = , just the answer" }], + model: "meta-llama/Llama-3.1-8B-Instruct", + provider: "nscale", + messages: [{ role: "user", content: "Say 'this is a test'" }], + stream: true, }) as AsyncGenerator; - let out = ""; + let fullResponse = ""; for await (const chunk of stream) { if (chunk.choices && chunk.choices.length > 0) { - out += chunk.choices[0].delta.content; + const content = chunk.choices[0].delta?.content; + if (content) { + fullResponse += content; + } } } - expect(out).toMatch(/(two|2)/i); + expect(fullResponse).toBeTruthy(); + expect(fullResponse.length).toBeGreaterThan(0); }); - it("textToImage", async () => { const res = await client.textToImage({ model: "black-forest-labs/FLUX.1-schnell", - provider: "nebius", - inputs: "award winning high resolution photo of a giant tortoise", + provider: "nscale", + inputs: "An astronaut riding a horse", }); expect(res).toBeInstanceOf(Blob); }); - - it("featureExtraction", async () => { - const res = await client.featureExtraction({ - model: "BAAI/bge-multilingual-gemma2", - inputs: "That is a happy person", - }); - - expect(res).toBeInstanceOf(Array); - expect(res[0]).toEqual(expect.arrayContaining([expect.any(Number)])); - }); - - it("textGeneration", async () => { - const res = await client.textGeneration({ - model: "mistralai/Devstral-Small-2505", - provider: "nebius", - inputs: "Once upon a time,", - temperature: 0, - max_tokens: 19, - }); - expect(res).toMatchObject({ - generated_text: " in a land far, far away, there lived a king who was very fond of flowers.", - }); - }); }, TIMEOUT ); - describe.concurrent( - "Scaleway", + "Groq", () => { - const client = new InferenceClient(env.HF_SCALEWAY_KEY ?? "dummy"); + const client = new InferenceClient(env.HF_GROQ_KEY ?? "dummy"); - HARDCODED_MODEL_INFERENCE_MAPPING.scaleway = { - "meta-llama/Llama-3.1-8B-Instruct": { - provider: "scaleway", - hfModelId: "meta-llama/Llama-3.1-8B-Instruct", - providerId: "llama-3.1-8b-instruct", - status: "live", - task: "conversational", - }, - "BAAI/bge-multilingual-gemma2": { - provider: "scaleway", - hfModelId: "BAAI/bge-multilingual-gemma2", - providerId: "bge-multilingual-gemma2", - task: "feature-extraction", + HARDCODED_MODEL_INFERENCE_MAPPING["groq"] = { + "meta-llama/Llama-3.3-70B-Instruct": { + provider: "groq", + hfModelId: "meta-llama/Llama-3.3-70B-Instruct", + providerId: "llama-3.3-70b-versatile", status: "live", - }, - "google/gemma-3-27b-it": { - provider: "scaleway", - hfModelId: "google/gemma-3-27b-it", - providerId: "gemma-3-27b-it", task: "conversational", - status: "live", }, }; it("chatCompletion", async () => { const res = await client.chatCompletion({ - model: "meta-llama/Llama-3.1-8B-Instruct", - provider: "scaleway", + model: "meta-llama/Llama-3.3-70B-Instruct", + provider: "groq", messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }], - tool_choice: "none", }); if (res.choices && res.choices.length > 0) { const completion = res.choices[0].message?.content; - expect(completion).toMatch(/(to )?(two|2)/i); + expect(completion).toContain("two"); } }); it("chatCompletion stream", async () => { const stream = client.chatCompletionStream({ - model: "meta-llama/Llama-3.1-8B-Instruct", - provider: "scaleway", - messages: [{ role: "system", content: "Complete the equation 1 + 1 = , just the answer" }], + model: "meta-llama/Llama-3.3-70B-Instruct", + provider: "groq", + messages: [{ role: "user", content: "Say 'this is a test'" }], + stream: true, }) as AsyncGenerator; - let out = ""; + + let fullResponse = ""; for await (const chunk of stream) { if (chunk.choices && chunk.choices.length > 0) { - out += chunk.choices[0].delta.content; + const content = chunk.choices[0].delta?.content; + if (content) { + fullResponse += content; + } } } - expect(out).toMatch(/(two|2)/i); - }); - - it("chatCompletion multimodal", async () => { - const res = await client.chatCompletion({ - model: "google/gemma-3-27b-it", - provider: "scaleway", - messages: [ - { - role: "user", - content: [ - { - type: "image_url", - image_url: { - url: "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg", - }, - }, - { type: "text", text: "What is this?" }, - ], - }, - ], - }); - expect(res.choices).toBeDefined(); - expect(res.choices?.length).toBeGreaterThan(0); - expect(res.choices?.[0].message?.content).toContain("Statue of Liberty"); - }); - it("textGeneration", async () => { - const res = await client.textGeneration({ - model: "meta-llama/Llama-3.1-8B-Instruct", - provider: "scaleway", - inputs: "Once upon a time,", - temperature: 0, - max_tokens: 19, - }); - - expect(res).toMatchObject({ - generated_text: - " in a small village nestled in the rolling hills of the countryside, there lived a young girl named", - }); - }); - - it("featureExtraction", async () => { - const res = await client.featureExtraction({ - model: "BAAI/bge-multilingual-gemma2", - provider: "scaleway", - inputs: "That is a happy person", - }); - - expect(res).toBeInstanceOf(Array); - expect(res[0]).toEqual(expect.arrayContaining([expect.any(Number)])); + // Verify we got a meaningful response + expect(fullResponse).toBeTruthy(); + expect(fullResponse.length).toBeGreaterThan(0); }); }, TIMEOUT ); - - describe.concurrent("3rd party providers", () => { - it("chatCompletion - fails with unsupported model", async () => { - expect( - chatCompletion({ - model: "black-forest-labs/Flux.1-dev", - provider: "together", - messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }], - accessToken: env.HF_TOGETHER_KEY ?? "dummy", - }) - ).rejects.toThrowError( - "Model black-forest-labs/Flux.1-dev is not supported for task conversational and provider together" - ); - }); - }); - describe.concurrent( - "Fireworks", + "OVHcloud", () => { - const client = new InferenceClient(env.HF_FIREWORKS_KEY ?? "dummy"); + const client = new HfInference(env.HF_OVHCLOUD_KEY ?? "dummy"); - HARDCODED_MODEL_INFERENCE_MAPPING["fireworks-ai"] = { - "deepseek-ai/DeepSeek-R1": { - provider: "fireworks-ai", - hfModelId: "deepseek-ai/DeepSeek-R1", - providerId: "accounts/fireworks/models/deepseek-r1", + HARDCODED_MODEL_INFERENCE_MAPPING["ovhcloud"] = { + "meta-llama/llama-3.1-8b-instruct": { + provider: "ovhcloud", + hfModelId: "meta-llama/llama-3.1-8b-instruct", + providerId: "Llama-3.1-8B-Instruct", status: "live", task: "conversational", }, @@ -2161,22 +2150,29 @@ describe("InferenceClient", () => { it("chatCompletion", async () => { const res = await client.chatCompletion({ - model: "deepseek-ai/DeepSeek-R1", - provider: "fireworks-ai", - messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }], + model: "meta-llama/llama-3.1-8b-instruct", + provider: "ovhcloud", + messages: [{ role: "user", content: "A, B, C, " }], + seed: 42, + temperature: 0, + top_p: 0.01, + max_tokens: 1, }); - if (res.choices && res.choices.length > 0) { - const completion = res.choices[0].message?.content; - expect(completion).toContain("two"); - } + expect(res.choices && res.choices.length > 0); + const completion = res.choices[0].message?.content; + expect(completion).toContain("D"); }); it("chatCompletion stream", async () => { const stream = client.chatCompletionStream({ - model: "deepseek-ai/DeepSeek-R1", - provider: "fireworks-ai", - messages: [{ role: "user", content: "Say this is a test" }], + model: "meta-llama/llama-3.1-8b-instruct", + provider: "ovhcloud", + messages: [{ role: "user", content: "A, B, C, " }], stream: true, + seed: 42, + temperature: 0, + top_p: 0.01, + max_tokens: 1, }) as AsyncGenerator; let fullResponse = ""; @@ -2191,161 +2187,43 @@ describe("InferenceClient", () => { // Verify we got a meaningful response expect(fullResponse).toBeTruthy(); - expect(fullResponse.length).toBeGreaterThan(0); - }); - }, - TIMEOUT - ); - - describe.concurrent( - "Hyperbolic", - () => { - HARDCODED_MODEL_INFERENCE_MAPPING["hyperbolic"] = { - "meta-llama/Llama-3.2-3B-Instruct": { - provider: "hyperbolic", - hfModelId: "meta-llama/Llama-3.2-3B-Instruct", - providerId: "meta-llama/Llama-3.2-3B-Instruct", - status: "live", - task: "conversational", - }, - "meta-llama/Llama-3.3-70B-Instruct": { - provider: "hyperbolic", - hfModelId: "meta-llama/Llama-3.3-70B-Instruct", - providerId: "meta-llama/Llama-3.3-70B-Instruct", - status: "live", - task: "conversational", - }, - "stabilityai/stable-diffusion-2": { - provider: "hyperbolic", - hfModelId: "stabilityai/stable-diffusion-2", - providerId: "SD2", - status: "live", - task: "text-to-image", - }, - "meta-llama/Llama-3.1-405B-FP8": { - provider: "hyperbolic", - hfModelId: "meta-llama/Llama-3.1-405B-FP8", - providerId: "meta-llama/Llama-3.1-405B-FP8", - status: "live", - task: "conversational", - }, - }; - - it("chatCompletion - hyperbolic", async () => { - const res = await chatCompletion({ - accessToken: env.HF_HYPERBOLIC_KEY ?? "dummy", - model: "meta-llama/Llama-3.2-3B-Instruct", - provider: "hyperbolic", - messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }], - temperature: 0.1, - }); - - expect(res).toBeDefined(); - expect(res.choices).toBeDefined(); - expect(res.choices?.length).toBeGreaterThan(0); - - if (res.choices && res.choices.length > 0) { - const completion = res.choices[0].message?.content; - expect(completion).toBeDefined(); - expect(typeof completion).toBe("string"); - expect(completion).toContain("two"); - } - }); - - it("chatCompletion stream", async () => { - const stream = chatCompletionStream({ - accessToken: env.HF_HYPERBOLIC_KEY ?? "dummy", - model: "meta-llama/Llama-3.3-70B-Instruct", - provider: "hyperbolic", - messages: [{ role: "user", content: "Complete the equation 1 + 1 = , just the answer" }], - }) as AsyncGenerator; - let out = ""; - for await (const chunk of stream) { - if (chunk.choices && chunk.choices.length > 0) { - out += chunk.choices[0].delta.content; - } - } - expect(out).toContain("2"); - }); - - it("textToImage", async () => { - const res = await textToImage({ - accessToken: env.HF_HYPERBOLIC_KEY ?? "dummy", - model: "stabilityai/stable-diffusion-2", - provider: "hyperbolic", - inputs: "award winning high resolution photo of a giant tortoise", - parameters: { - height: 128, - width: 128, - }, - } satisfies TextToImageArgs); - expect(res).toBeInstanceOf(Blob); + expect(fullResponse).toContain("D"); }); it("textGeneration", async () => { - const res = await textGeneration({ - accessToken: env.HF_HYPERBOLIC_KEY ?? "dummy", - model: "meta-llama/Llama-3.1-405B", - provider: "hyperbolic", - inputs: "Paris is", + const res = await client.textGeneration({ + model: "meta-llama/llama-3.1-8b-instruct", + provider: "ovhcloud", + inputs: "A B C ", parameters: { + seed: 42, temperature: 0, top_p: 0.01, - max_new_tokens: 10, + max_new_tokens: 1, }, }); - expect(res).toMatchObject({ generated_text: "...the capital and most populous city of France," }); + expect(res.generated_text.length > 0); + expect(res.generated_text).toContain("D"); }); - }, - TIMEOUT - ); - describe.concurrent( - "Novita", - () => { - const client = new InferenceClient(env.HF_NOVITA_KEY ?? "dummy"); - - HARDCODED_MODEL_INFERENCE_MAPPING["novita"] = { - "meta-llama/llama-3.1-8b-instruct": { - provider: "novita", - hfModelId: "meta-llama/llama-3.1-8b-instruct", - providerId: "meta-llama/llama-3.1-8b-instruct", - status: "live", - task: "conversational", - }, - "deepseek/deepseek-r1-distill-qwen-14b": { - provider: "novita", - hfModelId: "deepseek/deepseek-r1-distill-qwen-14b", - providerId: "deepseek/deepseek-r1-distill-qwen-14b", - status: "live", - task: "conversational", - }, - }; - - it("chatCompletion", async () => { - const res = await client.chatCompletion({ + it("textGeneration stream", async () => { + const stream = client.textGenerationStream({ model: "meta-llama/llama-3.1-8b-instruct", - provider: "novita", - messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }], - }); - if (res.choices && res.choices.length > 0) { - const completion = res.choices[0].message?.content; - expect(completion).toContain("two"); - } - }); - - it("chatCompletion stream", async () => { - const stream = client.chatCompletionStream({ - model: "deepseek/deepseek-r1-distill-qwen-14b", - provider: "novita", - messages: [{ role: "user", content: "Say this is a test" }], + provider: "ovhcloud", + inputs: "A B C ", stream: true, + parameters: { + seed: 42, + temperature: 0, + top_p: 0.01, + max_new_tokens: 1, + }, }) as AsyncGenerator; let fullResponse = ""; for await (const chunk of stream) { if (chunk.choices && chunk.choices.length > 0) { - const content = chunk.choices[0].delta?.content; + const content = chunk.choices[0].text; if (content) { fullResponse += content; } @@ -2354,390 +2232,509 @@ describe("InferenceClient", () => { // Verify we got a meaningful response expect(fullResponse).toBeTruthy(); - expect(fullResponse.length).toBeGreaterThan(0); + expect(fullResponse).toContain("D"); }); }, TIMEOUT ); - describe.concurrent( - "Black Forest Labs", + describe.only( + "Bytez", () => { - HARDCODED_MODEL_INFERENCE_MAPPING["black-forest-labs"] = { - "black-forest-labs/FLUX.1-dev": { - provider: "black-forest-labs", - hfModelId: "black-forest-labs/FLUX.1-dev", - providerId: "flux-dev", - status: "live", - task: "text-to-image", - }, - // "black-forest-labs/FLUX.1-schnell": "flux-pro", - }; + const client = new InferenceClient(env.HF_BYTEZ_KEY ?? "dummy"); - it("textToImage", async () => { - const res = await textToImage({ - model: "black-forest-labs/FLUX.1-dev", - provider: "black-forest-labs", - accessToken: env.HF_BLACK_FOREST_LABS_KEY ?? "dummy", - inputs: "A raccoon driving a truck", - parameters: { - height: 256, - width: 256, - num_inference_steps: 4, - seed: 8817, + const tests: { task: WidgetType; modelId: string; test: (modelId: string) => Promise; stream: boolean }[] = + [ + { + task: "text-generation", + modelId: "bigscience/mt0-small", + test: async (modelId: string) => { + const { generated_text } = await client.textGeneration({ + model: modelId, + provider: "bytez", + inputs: "Hello", + }); + expect(typeof generated_text).toBe("string"); + }, + stream: false, }, - }); - expect(res).toBeInstanceOf(Blob); - }); - - it("textToImage URL", async () => { - const res = await textToImage( { - model: "black-forest-labs/FLUX.1-dev", - provider: "black-forest-labs", - accessToken: env.HF_BLACK_FOREST_LABS_KEY ?? "dummy", - inputs: "A raccoon driving a truck", - parameters: { - height: 256, - width: 256, - num_inference_steps: 4, - seed: 8817, + task: "text-generation", + modelId: "bigscience/mt0-small", + test: async (modelId: string) => { + const response = client.textGenerationStream({ + model: modelId, + provider: "bytez", + inputs: "Please answer the following question: complete one two and ____.", + parameters: { + max_new_tokens: 50, + num_beams: 1, + }, + }); + for await (const ret of response) { + expect(ret).toMatchObject({ + details: null, + index: expect.any(Number), + token: { + id: expect.any(Number), + logprob: expect.any(Number), + special: expect.any(Boolean), + }, + generated_text: ret.generated_text ? "afew" : null, + }); + expect(typeof ret.token.text === "string" || ret.token.text === null).toBe(true); + } }, + stream: true, }, - { outputType: "url" } - ); - expect(res).toBeTypeOf("string"); - expect(isUrl(res)).toBeTruthy(); - }); - }, - TIMEOUT - ); - describe.concurrent( - "Cohere", - () => { - const client = new InferenceClient(env.HF_COHERE_KEY ?? "dummy"); - - HARDCODED_MODEL_INFERENCE_MAPPING["cohere"] = { - "CohereForAI/c4ai-command-r7b-12-2024": { - provider: "cohere", - hfModelId: "CohereForAI/c4ai-command-r7b-12-2024", - providerId: "command-r7b-12-2024", - status: "live", - task: "conversational", - }, - "CohereForAI/aya-expanse-8b": { - provider: "cohere", - hfModelId: "CohereForAI/aya-expanse-8b", - providerId: "c4ai-aya-expanse-8b", - status: "live", - task: "conversational", - }, - }; - - it("chatCompletion", async () => { - const res = await client.chatCompletion({ - model: "CohereForAI/c4ai-command-r7b-12-2024", - provider: "cohere", - messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }], - }); - if (res.choices && res.choices.length > 0) { - const completion = res.choices[0].message?.content; - expect(completion).toContain("two"); - } - }); - - it("chatCompletion stream", async () => { - const stream = client.chatCompletionStream({ - model: "CohereForAI/c4ai-command-r7b-12-2024", - provider: "cohere", - messages: [{ role: "user", content: "Say 'this is a test'" }], - stream: true, - }) as AsyncGenerator; - - let fullResponse = ""; - for await (const chunk of stream) { - if (chunk.choices && chunk.choices.length > 0) { - const content = chunk.choices[0].delta?.content; - if (content) { - fullResponse += content; - } - } - } - - // Verify we got a meaningful response - expect(fullResponse).toBeTruthy(); - expect(fullResponse.length).toBeGreaterThan(0); - }); - }, - TIMEOUT - ); - describe.concurrent( - "Cerebras", - () => { - const client = new InferenceClient(env.HF_CEREBRAS_KEY ?? "dummy"); - - HARDCODED_MODEL_INFERENCE_MAPPING["cerebras"] = { - "meta-llama/llama-3.1-8b-instruct": { - provider: "cerebras", - hfModelId: "meta-llama/llama-3.1-8b-instruct", - providerId: "llama3.1-8b", - status: "live", - task: "conversational", - }, - }; - - it("chatCompletion", async () => { - const res = await client.chatCompletion({ - model: "meta-llama/llama-3.1-8b-instruct", - provider: "cerebras", - messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }], - }); - if (res.choices && res.choices.length > 0) { - const completion = res.choices[0].message?.content; - expect(completion).toContain("two"); - } - }); - - it("chatCompletion stream", async () => { - const stream = client.chatCompletionStream({ - model: "meta-llama/llama-3.1-8b-instruct", - provider: "cerebras", - messages: [{ role: "user", content: "Say 'this is a test'" }], - stream: true, - }) as AsyncGenerator; - - let fullResponse = ""; - for await (const chunk of stream) { - if (chunk.choices && chunk.choices.length > 0) { - const content = chunk.choices[0].delta?.content; - if (content) { - fullResponse += content; - } - } - } - - // Verify we got a meaningful response - expect(fullResponse).toBeTruthy(); - expect(fullResponse.length).toBeGreaterThan(0); - }); - }, - TIMEOUT - ); - describe.concurrent( - "Nscale", - () => { - const client = new InferenceClient(env.HF_NSCALE_KEY ?? "dummy"); - - HARDCODED_MODEL_INFERENCE_MAPPING["nscale"] = { - "meta-llama/Llama-3.1-8B-Instruct": { - provider: "nscale", - hfModelId: "meta-llama/Llama-3.1-8B-Instruct", - providerId: "nscale", - status: "live", - task: "conversational", - }, - "black-forest-labs/FLUX.1-schnell": { - provider: "nscale", - hfModelId: "black-forest-labs/FLUX.1-schnell", - providerId: "flux-schnell", - status: "live", - task: "text-to-image", - }, - }; - - it("chatCompletion", async () => { - const res = await client.chatCompletion({ - model: "meta-llama/Llama-3.1-8B-Instruct", - provider: "nscale", - messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }], - }); - if (res.choices && res.choices.length > 0) { - const completion = res.choices[0].message?.content; - expect(completion).toContain("two"); - } - }); - it("chatCompletion stream", async () => { - const stream = client.chatCompletionStream({ - model: "meta-llama/Llama-3.1-8B-Instruct", - provider: "nscale", - messages: [{ role: "user", content: "Say 'this is a test'" }], - stream: true, - }) as AsyncGenerator; - let fullResponse = ""; - for await (const chunk of stream) { - if (chunk.choices && chunk.choices.length > 0) { - const content = chunk.choices[0].delta?.content; - if (content) { - fullResponse += content; - } - } - } - expect(fullResponse).toBeTruthy(); - expect(fullResponse.length).toBeGreaterThan(0); - }); - it("textToImage", async () => { - const res = await client.textToImage({ - model: "black-forest-labs/FLUX.1-schnell", - provider: "nscale", - inputs: "An astronaut riding a horse", - }); - expect(res).toBeInstanceOf(Blob); - }); - }, - TIMEOUT - ); - describe.concurrent( - "Groq", - () => { - const client = new InferenceClient(env.HF_GROQ_KEY ?? "dummy"); - - HARDCODED_MODEL_INFERENCE_MAPPING["groq"] = { - "meta-llama/Llama-3.3-70B-Instruct": { - provider: "groq", - hfModelId: "meta-llama/Llama-3.3-70B-Instruct", - providerId: "llama-3.3-70b-versatile", - status: "live", - task: "conversational", - }, - }; - - it("chatCompletion", async () => { - const res = await client.chatCompletion({ - model: "meta-llama/Llama-3.3-70B-Instruct", - provider: "groq", - messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }], - }); - if (res.choices && res.choices.length > 0) { - const completion = res.choices[0].message?.content; - expect(completion).toContain("two"); - } - }); - - it("chatCompletion stream", async () => { - const stream = client.chatCompletionStream({ - model: "meta-llama/Llama-3.3-70B-Instruct", - provider: "groq", - messages: [{ role: "user", content: "Say 'this is a test'" }], - stream: true, - }) as AsyncGenerator; - - let fullResponse = ""; - for await (const chunk of stream) { - if (chunk.choices && chunk.choices.length > 0) { - const content = chunk.choices[0].delta?.content; - if (content) { - fullResponse += content; - } - } - } - - // Verify we got a meaningful response - expect(fullResponse).toBeTruthy(); - expect(fullResponse.length).toBeGreaterThan(0); - }); - }, - TIMEOUT - ); - describe.concurrent( - "OVHcloud", - () => { - const client = new HfInference(env.HF_OVHCLOUD_KEY ?? "dummy"); - - HARDCODED_MODEL_INFERENCE_MAPPING["ovhcloud"] = { - "meta-llama/llama-3.1-8b-instruct": { - provider: "ovhcloud", - hfModelId: "meta-llama/llama-3.1-8b-instruct", - providerId: "Llama-3.1-8B-Instruct", - status: "live", - task: "conversational", - }, - }; - - it("chatCompletion", async () => { - const res = await client.chatCompletion({ - model: "meta-llama/llama-3.1-8b-instruct", - provider: "ovhcloud", - messages: [{ role: "user", content: "A, B, C, " }], - seed: 42, - temperature: 0, - top_p: 0.01, - max_tokens: 1, - }); - expect(res.choices && res.choices.length > 0); - const completion = res.choices[0].message?.content; - expect(completion).toContain("D"); - }); - - it("chatCompletion stream", async () => { - const stream = client.chatCompletionStream({ - model: "meta-llama/llama-3.1-8b-instruct", - provider: "ovhcloud", - messages: [{ role: "user", content: "A, B, C, " }], - stream: true, - seed: 42, - temperature: 0, - top_p: 0.01, - max_tokens: 1, - }) as AsyncGenerator; - - let fullResponse = ""; - for await (const chunk of stream) { - if (chunk.choices && chunk.choices.length > 0) { - const content = chunk.choices[0].delta?.content; - if (content) { - fullResponse += content; - } - } - } - - // Verify we got a meaningful response - expect(fullResponse).toBeTruthy(); - expect(fullResponse).toContain("D"); - }); - - it("textGeneration", async () => { - const res = await client.textGeneration({ - model: "meta-llama/llama-3.1-8b-instruct", - provider: "ovhcloud", - inputs: "A B C ", - parameters: { - seed: 42, - temperature: 0, - top_p: 0.01, - max_new_tokens: 1, + { + task: "conversational", + modelId: "Qwen/Qwen3-1.7B", + test: async (modelId: string) => { + const { choices } = await client.chatCompletion({ + model: modelId, + provider: "bytez", + messages: [ + { + role: "system", + content: "You are a friendly chatbot who always responds in the style of a pirate", + }, + { role: "user", content: "How many helicopters can a human eat in one sitting?" }, + ], + }); + expect(typeof choices[0].message.role).toBe("string"); + expect(typeof choices[0].message.content).toBe("string"); + }, + stream: false, + }, + { + task: "conversational", + modelId: "Qwen/Qwen3-1.7B", + test: async (modelId: string) => { + const stream = client.chatCompletionStream({ + model: modelId, + provider: "bytez", + messages: [ + { + role: "system", + content: "You are a friendly chatbot who always responds in the style of a pirate", + }, + { role: "user", content: "How many helicopters can a human eat in one sitting?" }, + ], + }); + const chunks = []; + for await (const chunk of stream) { + if (chunk.choices && chunk.choices.length > 0) { + if (chunk.choices[0].finish_reason === "stop") { + break; + } + chunks.push(chunk); + } + } + const out = chunks.map((chunk) => chunk.choices[0].delta.content).join(""); + expect(out).toContain("helicopter"); + }, + stream: true, + }, + { + task: "summarization", + modelId: "ainize/bart-base-cnn", + test: async (modelId: string) => { + const input = + "The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris..."; + const { summary_text } = await client.summarization({ + model: modelId, + provider: "bytez", + inputs: input, + parameters: { max_length: 40 }, + }); + expect(summary_text.length).toBeLessThan(input.length); + }, + stream: false, + }, + { + task: "translation", + modelId: "Areeb123/En-Fr_Translation_Model", + test: async (modelId: string) => { + const { translation_text } = await client.translation({ + model: modelId, + provider: "bytez", + inputs: "Hello", + }); + expect(typeof translation_text).toBe("string"); + expect(translation_text.length).toBeGreaterThan(0); + }, + stream: false, + }, + { + task: "text-to-image", + modelId: "IDKiro/sdxs-512-0.9", + test: async (modelId: string) => { + const res = await client.textToImage({ + model: modelId, + provider: "bytez", + inputs: "A cat in the hat", + }); + expect(res).toBeInstanceOf(Blob); + }, + stream: false, + }, + { + task: "text-to-video", + modelId: "ali-vilab/text-to-video-ms-1.7b", + test: async (modelId: string) => { + const res = await client.textToVideo({ + model: modelId, + provider: "bytez", + inputs: "A cat in the hat", + }); + expect(res).toBeInstanceOf(Blob); + }, + stream: false, + }, + { + task: "image-to-text", + modelId: "captioner/caption-gen", + test: async (modelId: string) => { + const { generated_text } = await client.imageToText({ + model: modelId, + provider: "bytez", + data: new Blob([readTestFile("cheetah.png")], { type: "image/png" }), + }); + expect(typeof generated_text).toBe("string"); + }, + stream: false, + }, + { + task: "question-answering", + modelId: "airesearch/xlm-roberta-base-finetune-qa", + test: async (modelId: string) => { + const { answer, score, start, end } = await client.questionAnswering({ + model: modelId, + provider: "bytez", + inputs: { + question: "Where do I live?", + context: "My name is Merve and I live in İstanbul.", + }, + }); + expect(answer).toBeDefined(); + expect(score).toBeDefined(); + expect(start).toBeDefined(); + expect(end).toBeDefined(); + }, + stream: false, + }, + { + task: "visual-question-answering", + modelId: "aqachun/Vilt_fine_tune_2000", + test: async (modelId: string) => { + const output = await client.visualQuestionAnswering({ + model: modelId, + provider: "bytez", + inputs: { + image: new Blob([readTestFile("cheetah.png")], { type: "image/png" }), + question: "What kind of animal is this?", + }, + }); + expect(output).toMatchObject({ + answer: expect.any(String), + score: expect.any(Number), + }); + }, + stream: false, + }, + { + task: "document-question-answering", + modelId: "cloudqi/CQI_Visual_Question_Awnser_PT_v0", + test: async (modelId: string) => { + const url = "https://templates.invoicehome.com/invoice-template-us-neat-750px.png"; + const response = await fetch(url); + const blob = await response.blob(); + const output = await client.documentQuestionAnswering({ + model: modelId, + provider: "bytez", + inputs: { + // + image: blob, + question: "What's the total cost?", + }, + }); + expect(output).toMatchObject({ + answer: expect.any(String), + score: expect.any(Number), + // not sure what start/end refers to in this case + start: expect.any(Number), + end: expect.any(Number), + }); + }, + stream: false, + }, + { + task: "image-segmentation", + modelId: "apple/deeplabv3-mobilevit-small", + test: async (modelId: string) => { + const output = await client.imageSegmentation({ + model: modelId, + provider: "bytez", + inputs: new Blob([readTestFile("cats.png")], { type: "image/png" }), + }); + expect(output).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + label: expect.any(String), + score: expect.any(Number), + mask: expect.any(String), + }), + ]) + ); + }, + stream: false, + }, + { + task: "image-classification", + modelId: "akahana/vit-base-cats-vs-dogs", + test: async (modelId: string) => { + const output = await client.imageClassification({ + // + model: modelId, + provider: "bytez", + data: new Blob([readTestFile("cheetah.png")], { type: "image/png" }), + }); + expect(output).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + score: expect.any(Number), + label: expect.any(String), + }), + ]) + ); + }, + stream: false, }, - }); - expect(res.generated_text.length > 0); - expect(res.generated_text).toContain("D"); - }); - - it("textGeneration stream", async () => { - const stream = client.textGenerationStream({ - model: "meta-llama/llama-3.1-8b-instruct", - provider: "ovhcloud", - inputs: "A B C ", - stream: true, - parameters: { - seed: 42, - temperature: 0, - top_p: 0.01, - max_new_tokens: 1, + { + task: "zero-shot-image-classification", + modelId: "BilelDJ/clip-hugging-face-finetuned", + test: async (modelId: string) => { + const output = await client.zeroShotImageClassification({ + model: modelId, + provider: "bytez", + inputs: { image: new Blob([readTestFile("cheetah.png")], { type: "image/png" }) }, + parameters: { + candidate_labels: ["animal", "toy", "car"], + }, + }); + expect(output).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + score: expect.any(Number), + label: expect.any(String), + }), + ]) + ); + }, + stream: false, }, - }) as AsyncGenerator; + { + task: "object-detection", + modelId: "aisak-ai/aisak-detect", + test: async (modelId: string) => { + const output = await client.objectDetection({ + model: modelId, + provider: "bytez", + inputs: new Blob([readTestFile("cats.png")], { type: "image/png" }), + }); + expect(output).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + score: expect.any(Number), + label: expect.any(String), + box: expect.objectContaining({ + xmin: expect.any(Number), + ymin: expect.any(Number), + xmax: expect.any(Number), + ymax: expect.any(Number), + }), + }), + ]) + ); + }, + stream: false, + }, + { + task: "feature-extraction", + modelId: "allenai/specter2_base", + test: async (modelId: string) => { + const output = await client.featureExtraction({ + model: modelId, + provider: "bytez", + inputs: "That is a happy person", + }); + expect(output).toEqual(expect.arrayContaining([expect.any(Number)])); + }, + stream: false, + }, + { + task: "sentence-similarity", + modelId: "embedding-data/distilroberta-base-sentence-transformer", + test: async (modelId: string) => { + const output = await client.sentenceSimilarity({ + model: modelId, + provider: "bytez", + inputs: { + source_sentence: "That is a happy person", + sentences: ["That is a happy dog", "That is a very happy person", "Today is a sunny day"], + }, + }); + expect(output).toEqual([expect.any(Number), expect.any(Number), expect.any(Number)]); + }, + stream: false, + }, + { + task: "fill-mask", + modelId: "almanach/camembert-base", + test: async (modelId: string) => { + const output = await client.fillMask({ model: modelId, provider: "bytez", inputs: "Hello " }); + expect(output).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + score: expect.any(Number), + token: expect.any(Number), + token_str: expect.any(String), + sequence: expect.any(String), + }), + ]) + ); + }, + stream: false, + }, + { + task: "text-classification", + modelId: "AdamCodd/distilbert-base-uncased-finetuned-sentiment-amazon", + test: async (modelId: string) => { + const output = await client.textClassification({ + model: modelId, + provider: "bytez", + inputs: "I am a special unicorn", + }); + expect(output.every((entry) => entry.label && entry.score)).toBe(true); + }, + stream: false, + }, + { + task: "token-classification", + modelId: "2rtl3/mn-xlm-roberta-base-named-entity", + test: async (modelId: string) => { + const output = await client.tokenClassification({ + model: modelId, + provider: "bytez", + inputs: "John went to NYC", + }); + expect(output).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + entity_group: expect.any(String), + score: expect.any(Number), + word: expect.any(String), + start: expect.any(Number), + end: expect.any(Number), + }), + ]) + ); + }, + stream: false, + }, + { + task: "zero-shot-classification", + modelId: "AyoubChLin/DistilBERT_eco_ZeroShot", + test: async (modelId: string) => { + const testInput = "Ninja turtles are cool"; + const testCandidateLabels = ["positive", "negative"]; + const output = await client.zeroShotClassification({ + model: modelId, + provider: "bytez", + inputs: [ + testInput, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + ] as any, + parameters: { candidate_labels: testCandidateLabels }, + }); + expect(output).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + sequence: testInput, + labels: testCandidateLabels, + scores: [expect.closeTo(0.5206031203269958, 5), expect.closeTo(0.479396790266037, 5)], + }), + ]) + ); + }, + stream: false, + }, + { + task: "audio-classification", + modelId: "aaraki/wav2vec2-base-finetuned-ks", + test: async (modelId: string) => { + const output = await client.audioClassification({ + model: modelId, + provider: "bytez", + data: new Blob([readTestFile("sample1.flac")], { type: "audio/flac" }), + }); + expect(output).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + score: expect.any(Number), + label: expect.any(String), + }), + ]) + ); + }, + stream: false, + }, + { + task: "text-to-speech", + modelId: "facebook/mms-tts-eng", + test: async (modelId: string) => { + const output = await client.textToSpeech({ model: modelId, provider: "bytez", inputs: "Hello" }); + expect(output).toBeInstanceOf(Blob); + }, + stream: false, + }, + { + task: "automatic-speech-recognition", + modelId: "facebook/data2vec-audio-base-960h", + test: async (modelId: string) => { + const output = await client.automaticSpeechRecognition({ + model: modelId, + provider: "bytez", + data: new Blob([readTestFile("sample1.flac")], { type: "audio/flac" }), + }); + expect(output).toMatchObject({ + text: "GOING ALONG SLUSHY COUNTRY ROADS AND SPEAKING TO DAMP AUDIENCES IN DRAUGHTY SCHOOLROOMS DAY AFTER DAY FOR A FORTNIGHT HE'LL HAVE TO PUT IN AN APPEARANCE AT SOME PLACE OF WORSHIP ON SUNDAY MORNING AND HE CAN COME TO US IMMEDIATELY AFTERWARDS", + }); + }, + stream: false, + }, + ]; - let fullResponse = ""; - for await (const chunk of stream) { - if (chunk.choices && chunk.choices.length > 0) { - const content = chunk.choices[0].text; - if (content) { - fullResponse += content; - } - } - } + // bootstrap the inference mappings for testing + for (const { task, modelId } of tests) { + HARDCODED_MODEL_INFERENCE_MAPPING.bytez[modelId] = { + provider: "bytez", + hfModelId: modelId, + providerId: modelId, + status: "live", + task, + adapter: undefined, + adapterWeightsPath: undefined, + }; + } - // Verify we got a meaningful response - expect(fullResponse).toBeTruthy(); - expect(fullResponse).toContain("D"); - }); + // run the tests + for (const { task, modelId, test, stream } of tests) { + const testName = `${task} - ${modelId}${stream ? " stream" : ""}`; + it(testName, async () => { + await test(modelId); + }); + } }, TIMEOUT ); From 1cd3aa181946c6b2dc9ede8f117a4a5b154b3f99 Mon Sep 17 00:00:00 2001 From: Aaron Vogler Date: Mon, 6 Oct 2025 16:04:11 -0400 Subject: [PATCH 3/6] Prepare for PR. --- packages/inference/src/providers/bytez.ts | 9 +++++++++ packages/inference/src/tasks/cv/imageToText.ts | 2 -- packages/tasks/src/inference-providers.ts | 1 + 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/packages/inference/src/providers/bytez.ts b/packages/inference/src/providers/bytez.ts index 95868ff315..656f572869 100644 --- a/packages/inference/src/providers/bytez.ts +++ b/packages/inference/src/providers/bytez.ts @@ -1,3 +1,12 @@ +/** + * See the registered mapping of HF model ID => Bytez model ID here: + * + * https://huggingface.co/api/partners/bytez-ai/models + * + * Note, HF model IDs are 1-1 with Bytez model IDs. This is a publicly available mapping. + * + **/ + import type { ChatCompletionOutput, SummarizationOutput, diff --git a/packages/inference/src/tasks/cv/imageToText.ts b/packages/inference/src/tasks/cv/imageToText.ts index 2c65389d96..c3a87a0b65 100644 --- a/packages/inference/src/tasks/cv/imageToText.ts +++ b/packages/inference/src/tasks/cv/imageToText.ts @@ -21,7 +21,5 @@ export async function imageToText(args: ImageToTextArgs, options?: Options): Pro task: "image-to-text", }); - // TODO the huggingface impl for this needs to be updated, used to be - // return providerHelper.getResponse(res[0]); return providerHelper.getResponse(res); } diff --git a/packages/tasks/src/inference-providers.ts b/packages/tasks/src/inference-providers.ts index ee08d12943..da88803d54 100644 --- a/packages/tasks/src/inference-providers.ts +++ b/packages/tasks/src/inference-providers.ts @@ -1,6 +1,7 @@ /// This list is for illustration purposes only. /// in the `tasks` sub-package, we do not need actual strong typing of the inference providers. const INFERENCE_PROVIDERS = [ + "bytez", "cerebras", "cohere", "fal-ai", From a0be6e4abb08539f9b28aab3e5444950b85b31bd Mon Sep 17 00:00:00 2001 From: Aaron Vogler Date: Mon, 6 Oct 2025 16:19:34 -0400 Subject: [PATCH 4/6] Change bytez indentifier to 'bytez-ai'. --- .../inference/src/lib/getProviderHelper.ts | 2 +- packages/inference/src/providers/bytez.ts | 2 +- packages/inference/src/providers/consts.ts | 2 +- packages/inference/src/types.ts | 2 +- .../inference/test/InferenceClient.spec.ts | 56 ++++++++++--------- packages/tasks/src/inference-providers.ts | 2 +- 6 files changed, 34 insertions(+), 32 deletions(-) diff --git a/packages/inference/src/lib/getProviderHelper.ts b/packages/inference/src/lib/getProviderHelper.ts index 90aa4043ec..c8fa95e180 100644 --- a/packages/inference/src/lib/getProviderHelper.ts +++ b/packages/inference/src/lib/getProviderHelper.ts @@ -57,7 +57,7 @@ export const PROVIDERS: Record | "conversational"; export const INFERENCE_PROVIDERS = [ "black-forest-labs", - "bytez", + "bytez-ai", "cerebras", "cohere", "fal-ai", diff --git a/packages/inference/test/InferenceClient.spec.ts b/packages/inference/test/InferenceClient.spec.ts index eedcf91c73..6a26c8b1ff 100644 --- a/packages/inference/test/InferenceClient.spec.ts +++ b/packages/inference/test/InferenceClient.spec.ts @@ -352,6 +352,7 @@ describe("InferenceClient", () => { ).toMatchObject({ answer: "us-001", score: expect.any(Number), + // not sure what start/end refers to in this case start: expect.any(Number), end: expect.any(Number), }); @@ -2241,6 +2242,7 @@ describe("InferenceClient", () => { "Bytez", () => { const client = new InferenceClient(env.HF_BYTEZ_KEY ?? "dummy"); + const provider = "bytez-ai"; const tests: { task: WidgetType; modelId: string; test: (modelId: string) => Promise; stream: boolean }[] = [ @@ -2250,7 +2252,7 @@ describe("InferenceClient", () => { test: async (modelId: string) => { const { generated_text } = await client.textGeneration({ model: modelId, - provider: "bytez", + provider, inputs: "Hello", }); expect(typeof generated_text).toBe("string"); @@ -2263,7 +2265,7 @@ describe("InferenceClient", () => { test: async (modelId: string) => { const response = client.textGenerationStream({ model: modelId, - provider: "bytez", + provider, inputs: "Please answer the following question: complete one two and ____.", parameters: { max_new_tokens: 50, @@ -2292,7 +2294,7 @@ describe("InferenceClient", () => { test: async (modelId: string) => { const { choices } = await client.chatCompletion({ model: modelId, - provider: "bytez", + provider, messages: [ { role: "system", @@ -2312,7 +2314,7 @@ describe("InferenceClient", () => { test: async (modelId: string) => { const stream = client.chatCompletionStream({ model: modelId, - provider: "bytez", + provider, messages: [ { role: "system", @@ -2343,7 +2345,7 @@ describe("InferenceClient", () => { "The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris..."; const { summary_text } = await client.summarization({ model: modelId, - provider: "bytez", + provider, inputs: input, parameters: { max_length: 40 }, }); @@ -2357,7 +2359,7 @@ describe("InferenceClient", () => { test: async (modelId: string) => { const { translation_text } = await client.translation({ model: modelId, - provider: "bytez", + provider, inputs: "Hello", }); expect(typeof translation_text).toBe("string"); @@ -2371,7 +2373,7 @@ describe("InferenceClient", () => { test: async (modelId: string) => { const res = await client.textToImage({ model: modelId, - provider: "bytez", + provider, inputs: "A cat in the hat", }); expect(res).toBeInstanceOf(Blob); @@ -2384,7 +2386,7 @@ describe("InferenceClient", () => { test: async (modelId: string) => { const res = await client.textToVideo({ model: modelId, - provider: "bytez", + provider, inputs: "A cat in the hat", }); expect(res).toBeInstanceOf(Blob); @@ -2397,7 +2399,7 @@ describe("InferenceClient", () => { test: async (modelId: string) => { const { generated_text } = await client.imageToText({ model: modelId, - provider: "bytez", + provider, data: new Blob([readTestFile("cheetah.png")], { type: "image/png" }), }); expect(typeof generated_text).toBe("string"); @@ -2410,7 +2412,7 @@ describe("InferenceClient", () => { test: async (modelId: string) => { const { answer, score, start, end } = await client.questionAnswering({ model: modelId, - provider: "bytez", + provider, inputs: { question: "Where do I live?", context: "My name is Merve and I live in İstanbul.", @@ -2429,7 +2431,7 @@ describe("InferenceClient", () => { test: async (modelId: string) => { const output = await client.visualQuestionAnswering({ model: modelId, - provider: "bytez", + provider, inputs: { image: new Blob([readTestFile("cheetah.png")], { type: "image/png" }), question: "What kind of animal is this?", @@ -2451,7 +2453,7 @@ describe("InferenceClient", () => { const blob = await response.blob(); const output = await client.documentQuestionAnswering({ model: modelId, - provider: "bytez", + provider, inputs: { // image: blob, @@ -2474,7 +2476,7 @@ describe("InferenceClient", () => { test: async (modelId: string) => { const output = await client.imageSegmentation({ model: modelId, - provider: "bytez", + provider, inputs: new Blob([readTestFile("cats.png")], { type: "image/png" }), }); expect(output).toEqual( @@ -2496,7 +2498,7 @@ describe("InferenceClient", () => { const output = await client.imageClassification({ // model: modelId, - provider: "bytez", + provider, data: new Blob([readTestFile("cheetah.png")], { type: "image/png" }), }); expect(output).toEqual( @@ -2516,7 +2518,7 @@ describe("InferenceClient", () => { test: async (modelId: string) => { const output = await client.zeroShotImageClassification({ model: modelId, - provider: "bytez", + provider, inputs: { image: new Blob([readTestFile("cheetah.png")], { type: "image/png" }) }, parameters: { candidate_labels: ["animal", "toy", "car"], @@ -2539,7 +2541,7 @@ describe("InferenceClient", () => { test: async (modelId: string) => { const output = await client.objectDetection({ model: modelId, - provider: "bytez", + provider, inputs: new Blob([readTestFile("cats.png")], { type: "image/png" }), }); expect(output).toEqual( @@ -2565,7 +2567,7 @@ describe("InferenceClient", () => { test: async (modelId: string) => { const output = await client.featureExtraction({ model: modelId, - provider: "bytez", + provider, inputs: "That is a happy person", }); expect(output).toEqual(expect.arrayContaining([expect.any(Number)])); @@ -2578,7 +2580,7 @@ describe("InferenceClient", () => { test: async (modelId: string) => { const output = await client.sentenceSimilarity({ model: modelId, - provider: "bytez", + provider, inputs: { source_sentence: "That is a happy person", sentences: ["That is a happy dog", "That is a very happy person", "Today is a sunny day"], @@ -2592,7 +2594,7 @@ describe("InferenceClient", () => { task: "fill-mask", modelId: "almanach/camembert-base", test: async (modelId: string) => { - const output = await client.fillMask({ model: modelId, provider: "bytez", inputs: "Hello " }); + const output = await client.fillMask({ model: modelId, provider, inputs: "Hello " }); expect(output).toEqual( expect.arrayContaining([ expect.objectContaining({ @@ -2612,7 +2614,7 @@ describe("InferenceClient", () => { test: async (modelId: string) => { const output = await client.textClassification({ model: modelId, - provider: "bytez", + provider, inputs: "I am a special unicorn", }); expect(output.every((entry) => entry.label && entry.score)).toBe(true); @@ -2625,7 +2627,7 @@ describe("InferenceClient", () => { test: async (modelId: string) => { const output = await client.tokenClassification({ model: modelId, - provider: "bytez", + provider, inputs: "John went to NYC", }); expect(output).toEqual( @@ -2650,7 +2652,7 @@ describe("InferenceClient", () => { const testCandidateLabels = ["positive", "negative"]; const output = await client.zeroShotClassification({ model: modelId, - provider: "bytez", + provider, inputs: [ testInput, // eslint-disable-next-line @typescript-eslint/no-explicit-any @@ -2675,7 +2677,7 @@ describe("InferenceClient", () => { test: async (modelId: string) => { const output = await client.audioClassification({ model: modelId, - provider: "bytez", + provider, data: new Blob([readTestFile("sample1.flac")], { type: "audio/flac" }), }); expect(output).toEqual( @@ -2693,7 +2695,7 @@ describe("InferenceClient", () => { task: "text-to-speech", modelId: "facebook/mms-tts-eng", test: async (modelId: string) => { - const output = await client.textToSpeech({ model: modelId, provider: "bytez", inputs: "Hello" }); + const output = await client.textToSpeech({ model: modelId, provider, inputs: "Hello" }); expect(output).toBeInstanceOf(Blob); }, stream: false, @@ -2704,7 +2706,7 @@ describe("InferenceClient", () => { test: async (modelId: string) => { const output = await client.automaticSpeechRecognition({ model: modelId, - provider: "bytez", + provider, data: new Blob([readTestFile("sample1.flac")], { type: "audio/flac" }), }); expect(output).toMatchObject({ @@ -2717,8 +2719,8 @@ describe("InferenceClient", () => { // bootstrap the inference mappings for testing for (const { task, modelId } of tests) { - HARDCODED_MODEL_INFERENCE_MAPPING.bytez[modelId] = { - provider: "bytez", + HARDCODED_MODEL_INFERENCE_MAPPING["bytez-ai"][modelId] = { + provider, hfModelId: modelId, providerId: modelId, status: "live", diff --git a/packages/tasks/src/inference-providers.ts b/packages/tasks/src/inference-providers.ts index da88803d54..f653741d1e 100644 --- a/packages/tasks/src/inference-providers.ts +++ b/packages/tasks/src/inference-providers.ts @@ -1,7 +1,7 @@ /// This list is for illustration purposes only. /// in the `tasks` sub-package, we do not need actual strong typing of the inference providers. const INFERENCE_PROVIDERS = [ - "bytez", + "bytez-ai", "cerebras", "cohere", "fal-ai", From 3bf26bec2f53c498d05df7d41a0607335990a621 Mon Sep 17 00:00:00 2001 From: Aaron Vogler Date: Mon, 6 Oct 2025 16:45:03 -0400 Subject: [PATCH 5/6] Prepare for PR. --- packages/inference/src/lib/getProviderHelper.ts | 2 +- packages/inference/src/providers/{bytez.ts => bytez-ai.ts} | 4 ++-- packages/inference/src/providers/hf-inference.ts | 7 ++++--- 3 files changed, 7 insertions(+), 6 deletions(-) rename packages/inference/src/providers/{bytez.ts => bytez-ai.ts} (99%) diff --git a/packages/inference/src/lib/getProviderHelper.ts b/packages/inference/src/lib/getProviderHelper.ts index c8fa95e180..103a69dcff 100644 --- a/packages/inference/src/lib/getProviderHelper.ts +++ b/packages/inference/src/lib/getProviderHelper.ts @@ -1,5 +1,5 @@ import * as BlackForestLabs from "../providers/black-forest-labs.js"; -import * as Bytez from "../providers/bytez.js"; +import * as Bytez from "../providers/bytez-ai.js"; import * as Cerebras from "../providers/cerebras.js"; import * as Cohere from "../providers/cohere.js"; import * as FalAI from "../providers/fal-ai.js"; diff --git a/packages/inference/src/providers/bytez.ts b/packages/inference/src/providers/bytez-ai.ts similarity index 99% rename from packages/inference/src/providers/bytez.ts rename to packages/inference/src/providers/bytez-ai.ts index 8946fd7e1f..82fcadfbf0 100644 --- a/packages/inference/src/providers/bytez.ts +++ b/packages/inference/src/providers/bytez-ai.ts @@ -175,8 +175,8 @@ export interface BytezChatLikeOutput { error: string; } -// const BASE_URL = "https://api.bytez.com" -const BASE_URL = "http://localhost:8080"; +const BASE_URL = "https://api.bytez.com"; +// const BASE_URL = "http://localhost:8080"; abstract class BytezTask extends TaskProviderHelper { constructor(url?: string) { diff --git a/packages/inference/src/providers/hf-inference.ts b/packages/inference/src/providers/hf-inference.ts index 757302f1fd..82e45d1a8a 100644 --- a/packages/inference/src/providers/hf-inference.ts +++ b/packages/inference/src/providers/hf-inference.ts @@ -361,13 +361,14 @@ export class HFInferenceImageSegmentationTask extends HFInferenceTask implements } export class HFInferenceImageToTextTask extends HFInferenceTask implements ImageToTextTaskHelper { - override async getResponse(response: ImageToTextOutput): Promise { - if (typeof response?.generated_text !== "string") { + override async getResponse(response: ImageToTextOutput[]): Promise { + const [first] = response + if (typeof first?.generated_text !== "string") { throw new InferenceClientProviderOutputError( "Received malformed response from HF-Inference image-to-text API: expected {generated_text: string}" ); } - return response; + return first; } } From cd308fff38301b55efad6c622a449456a7aac8f6 Mon Sep 17 00:00:00 2001 From: Aaron Vogler Date: Mon, 6 Oct 2025 16:47:09 -0400 Subject: [PATCH 6/6] Make Bytez tests concurrent. --- packages/inference/test/InferenceClient.spec.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/inference/test/InferenceClient.spec.ts b/packages/inference/test/InferenceClient.spec.ts index 6a26c8b1ff..0db5dd51a1 100644 --- a/packages/inference/test/InferenceClient.spec.ts +++ b/packages/inference/test/InferenceClient.spec.ts @@ -2238,7 +2238,7 @@ describe("InferenceClient", () => { }, TIMEOUT ); - describe.only( + describe.concurrent( "Bytez", () => { const client = new InferenceClient(env.HF_BYTEZ_KEY ?? "dummy");