diff --git a/.gitignore b/.gitignore
index 68c13e55e8..eba13e678a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -107,4 +107,5 @@ dist
 .DS_Store
 
 # Generated by doc-internal
-docs
\ No newline at end of file
+docs
+.vscode/launch.json
diff --git a/packages/inference/README.md b/packages/inference/README.md
index db3f64b35f..464f5d462e 100644
--- a/packages/inference/README.md
+++ b/packages/inference/README.md
@@ -68,6 +68,7 @@ Currently, we support the following providers:
 - [Cerebras](https://cerebras.ai/)
 - [Groq](https://groq.com)
 - [Z.ai](https://z.ai/)
+- [Bytez](https://bytez.com)
 
 To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. The default value of the `provider` parameter is "auto", which will select the first of the providers available for the model, sorted by your preferred order in https://hf.co/settings/inference-providers.
diff --git a/packages/inference/src/lib/getProviderHelper.ts b/packages/inference/src/lib/getProviderHelper.ts
index ee39f50342..20899d3f30 100644
--- a/packages/inference/src/lib/getProviderHelper.ts
+++ b/packages/inference/src/lib/getProviderHelper.ts
@@ -1,6 +1,7 @@
 import * as Baseten from "../providers/baseten.js";
 import * as Clarifai from "../providers/clarifai.js";
 import * as BlackForestLabs from "../providers/black-forest-labs.js";
+import * as Bytez from "../providers/bytez-ai.js";
 import * as Cerebras from "../providers/cerebras.js";
 import * as Cohere from "../providers/cohere.js";
 import * as FalAI from "../providers/fal-ai.js";
@@ -63,6 +64,32 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask | "conversational", TaskProviderHelper>>> = {
 	},
+	"bytez-ai": {
+		"audio-classification": new Bytez.BytezAudioClassificationTask(),
+		"automatic-speech-recognition": new Bytez.BytezAutomaticSpeechRecognitionTask(),
+		conversational: new Bytez.BytezConversationalTask(),
+		"document-question-answering": new Bytez.BytezDocumentQuestionAnsweringTask(),
+		"feature-extraction": new Bytez.BytezFeatureExtractionTask(),
+		"fill-mask": new Bytez.BytezFillMaskTask(),
+		"image-classification": new Bytez.BytezImageClassificationTask(),
+		"image-segmentation": new Bytez.BytezImageSegmentationTask(),
+		"image-to-text": new Bytez.BytezImageToTextTask(),
+		"object-detection": new Bytez.BytezObjectDetectionTask(),
+		"question-answering": new Bytez.BytezQuestionAnsweringTask(),
+		"sentence-similarity": new Bytez.BytezSentenceSimilarityTask(),
+		summarization: new Bytez.BytezSummarizationTask(),
+		"text-classification": new Bytez.BytezTextClassificationTask(),
+		"text-generation": new Bytez.BytezTextGenerationTask(),
+		"text-to-audio": new Bytez.BytezTextToAudioTask(),
+		"text-to-image": new Bytez.BytezTextToImageTask(),
+		"text-to-speech": new Bytez.BytezTextToSpeechTask(),
+		"text-to-video": new Bytez.BytezTextToVideoTask(),
+		"token-classification": new Bytez.BytezTokenClassificationTask(),
+		translation: new Bytez.BytezTranslationTask(),
+		"visual-question-answering": new Bytez.BytezVisualQuestionAnsweringTask(),
+		"zero-shot-classification": new Bytez.BytezZeroShotClassificationTask(),
+		"zero-shot-image-classification": new Bytez.BytezZeroShotImageClassificationTask(),
+	},
 	cerebras: {
diff --git a/packages/inference/src/providers/bytez-ai.ts b/packages/inference/src/providers/bytez-ai.ts
new file mode 100644
--- /dev/null
+++ b/packages/inference/src/providers/bytez-ai.ts
+/**
+ * See the registered mapping of HF model ID => Bytez model ID here:
+ *
+ * https://huggingface.co/api/partners/bytez-ai/models
+ *
+ * Note, HF model IDs are 1-1 with Bytez model IDs. This is a publicly available mapping.
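+ *
+ * A minimal usage sketch (illustrative only; the model ID is borrowed from this PR's
+ * test suite, not a recommendation):
+ *
+ *   const client = new InferenceClient(process.env.HF_TOKEN);
+ *   const res = await client.chatCompletion({
+ *     provider: "bytez-ai",
+ *     model: "Qwen/Qwen3-1.7B",
+ *     messages: [{ role: "user", content: "Hello" }],
+ *   });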
+ *
+ */
+
+import type {
+	ChatCompletionOutput,
+	SummarizationOutput,
+	TextGenerationOutput,
+	TranslationOutput,
+	QuestionAnsweringOutput,
+	QuestionAnsweringInput,
+	VisualQuestionAnsweringOutput,
+	VisualQuestionAnsweringInput,
+	DocumentQuestionAnsweringOutput,
+	DocumentQuestionAnsweringInput,
+	QuestionAnsweringOutputElement,
+	ImageSegmentationInput,
+	ImageSegmentationOutput,
+	ZeroShotClassificationInput,
+	ZeroShotImageClassificationOutput,
+	ZeroShotImageClassificationInput,
+	BoundingBox,
+	FeatureExtractionOutput,
+	FeatureExtractionInput,
+	SentenceSimilarityOutput,
+	SentenceSimilarityInput,
+	FillMaskOutput,
+	FillMaskInput,
+	TextClassificationOutput,
+	TokenClassificationOutput,
+	ZeroShotClassificationOutput,
+	TextToSpeechInput,
+	AutomaticSpeechRecognitionInput,
+	AutomaticSpeechRecognitionOutput,
+	ObjectDetectionInput,
+	ObjectDetectionOutput,
+	AudioClassificationInput,
+	AudioClassificationOutput,
+	ImageClassificationInput,
+	ImageClassificationOutput,
+	ImageToTextOutput,
+} from "@huggingface/tasks";
+import type { BaseArgs, BodyParams, HeaderParams, RequestArgs, UrlParams } from "../types.js";
+import type {
+	AudioClassificationTaskHelper,
+	AutomaticSpeechRecognitionTaskHelper,
+	ConversationalTaskHelper,
+	DocumentQuestionAnsweringTaskHelper,
+	FeatureExtractionTaskHelper,
+	FillMaskTaskHelper,
+	ImageClassificationTaskHelper,
+	ImageSegmentationTaskHelper,
+	ImageToTextTaskHelper,
+	ObjectDetectionTaskHelper,
+	QuestionAnsweringTaskHelper,
+	SentenceSimilarityTaskHelper,
+	SummarizationTaskHelper,
+	TextClassificationTaskHelper,
+	TextGenerationTaskHelper,
+	TextToImageTaskHelper,
+	TextToSpeechTaskHelper,
+	TextToVideoTaskHelper,
+	TokenClassificationTaskHelper,
+	TranslationTaskHelper,
+	VisualQuestionAnsweringTaskHelper,
+	ZeroShotClassificationTaskHelper,
+	ZeroShotImageClassificationTaskHelper,
+} from "./providerHelper.js";
+import { TaskProviderHelper } from "./providerHelper.js";
+import { base64FromBytes } from "../utils/base64FromBytes.js";
+import type {
+	AudioClassificationArgs,
+	AutomaticSpeechRecognitionArgs,
+	ImageClassificationArgs,
+	ImageSegmentationArgs,
+	ImageToTextArgs,
+	ObjectDetectionArgs,
+} from "../tasks/index.js";
+
+export interface BytezStringLikeOutput {
+	output: string;
+	error: string;
+}
+
+export interface BytezQuestionAnsweringOutput {
+	output: QuestionAnsweringOutputElement;
+	error: string;
+}
+
+export interface BytezDocumentQuestionAnsweringOutput {
+	output: DocumentQuestionAnsweringOutput;
+	error: string;
+}
+
+export interface BytezImageSegmentationOutput {
+	output: {
+		score: number;
+		label: string;
+		mask_png: string;
+	}[];
+	error: string;
+}
+
+export interface BytezImageClassificationOutput {
+	output: {
+		score: number;
+		label: string;
+	}[];
+	error: string;
+}
+
+export interface BytezZeroShotImageClassificationOutput {
+	output: {
+		score: number;
+		label: string;
+	}[];
+	error: string;
+}
+
+export interface BytezTokenClassificationOutput {
+	output: {
+		index: number;
+		entity: string;
+		score: number;
+		word: string;
+		start: number;
+		end: number;
+	}[];
+	error: string;
+}
+
+export interface BytezZeroShotClassificationOutput {
+	output: ZeroShotClassificationOutput;
+	error: string;
+}
+
+export interface BytezObjectDetectionOutput {
+	output: {
+		score: number;
+		label: string;
+		box: BoundingBox;
+	}[];
+	error: string;
+}
+
+export interface BytezFeatureExtractionOutput {
+	output: FeatureExtractionOutput[];
+	error: string;
+}
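+
+// Illustrative only (an assumption drawn from the interfaces in this file): every Bytez
+// response shares the same { output, error } envelope. A translation response might be
+//   { "output": "Bonjour", "error": "" }
+// with `error` non-empty only on failure; see handleError() below.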
+
+export interface BytezSentenceSimilarityOutput {
+	output: number[][];
+	error: string;
+}
+
+export interface BytezFillMaskOutput {
+	output: FillMaskOutput;
+	error: string;
+}
+
+export interface BytezVisualQuestionAnsweringOutput {
+	output: VisualQuestionAnsweringOutput;
+	error: string;
+}
+
+export interface BytezChatLikeOutput {
+	output: ChatCompletionOutput;
+	error: string;
+}
+
+const BASE_URL = "https://api.bytez.com";
+
+abstract class BytezTask extends TaskProviderHelper {
+	constructor(url?: string) {
+		super("bytez-ai", url || BASE_URL);
+	}
+
+	makeRoute(params: UrlParams): string {
+		return `models/v2/${params.model}`;
+	}
+
+	// we always pass in "application/json"
+	override prepareHeaders(params: HeaderParams, binary: boolean): Record<string, string> {
+		void binary;
+
+		const headers: Record<string, string> = { Authorization: `Key ${params.accessToken}` };
+		headers["Content-Type"] = "application/json";
+		return headers;
+	}
+
+	// we always want this behavior with our API; we only support JSON payloads
+	override makeBody(params: BodyParams): BodyInit {
+		return JSON.stringify(this.preparePayload(params));
+	}
+
+	// converts binary inputs (Blob or ArrayBuffer) to base64 so they can travel in a JSON body
+	async _preparePayloadAsync(args: Record<string, unknown>) {
+		if ("inputs" in args) {
+			const input = args.inputs as Blob;
+			const arrayBuffer = await input.arrayBuffer();
+			const uint8Array = new Uint8Array(arrayBuffer);
+			const base64 = base64FromBytes(uint8Array);
+
+			return {
+				...args,
+				inputs: base64,
+			};
+		} else {
+			// handle LegacyImageInput case
+			const data = args.data as Blob;
+			const arrayBuffer = data instanceof Blob ? await data.arrayBuffer() : data;
+			const uint8Array = new Uint8Array(arrayBuffer);
+			const base64 = base64FromBytes(uint8Array);
+
+			return {
+				...args,
+				inputs: base64,
+			};
+		}
+	}
+
+	handleError(error: string) {
+		if (error) {
+			throw new Error(`There was a problem with the Bytez API: ${error}`);
+		}
+	}
+}
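+
+// Sketch of _preparePayloadAsync() above, with hypothetical values: a binary input such as
+//   { inputs: new Blob([pngBytes], { type: "image/png" }) }
+// resolves to
+//   { inputs: "iVBORw0KGgo..." }
+// which makeBody() then JSON-serializes for the wire.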
"hf://text-generation" : undefined, + }; + } + + override async getResponse( + response: BytezStringLikeOutput, + url?: string, + headers?: HeadersInit + ): Promise { + void url; + void headers; + + const { error, output } = response; + + this.handleError(error); + + return { + generated_text: output, + }; + } +} + +export class BytezConversationalTask extends BytezTask implements ConversationalTaskHelper { + override preparePayload(params: BodyParams): Record { + return { + messages: params.args.messages, + params: params.args.parameters, + stream: params.args.stream, + // HF uses the same schema as OAI compliant spec + complianceFormat: "openai://chat/completions", + }; + } + override async getResponse( + response: BytezChatLikeOutput, + url: string, + headers?: HeadersInit + ): Promise { + void url; + void headers; + + const { error, output } = response; + + this.handleError(error); + + return output; + } +} + +export class BytezSummarizationTask extends BytezTask implements SummarizationTaskHelper { + override preparePayload(params: BodyParams): Record { + return { + text: params.args.inputs, + params: params.args.parameters, + }; + } + + override async getResponse( + response: BytezStringLikeOutput, + url?: string, + headers?: HeadersInit + ): Promise { + void url; + void headers; + + const { error, output } = response; + + this.handleError(error); + + return { + summary_text: output, + }; + } +} + +export class BytezTranslationTask extends BytezTask implements TranslationTaskHelper { + override preparePayload(params: BodyParams): Record { + return { + text: params.args.inputs, + params: params.args.parameters, + }; + } + + override async getResponse( + response: BytezStringLikeOutput, + url?: string, + headers?: HeadersInit + ): Promise { + void url; + void headers; + + const { error, output } = response; + + this.handleError(error); + + return { + translation_text: output, + }; + } +} + +export class BytezTextToImageTask extends BytezTask implements TextToImageTaskHelper { + override preparePayload(params: BodyParams): Record { + return { + text: params.args.inputs, + params: params.args.parameters, + }; + } + + override async getResponse( + response: BytezStringLikeOutput, + url?: string, + headers?: HeadersInit + ): Promise> { + void url; + void headers; + + const { error, output } = response; + + this.handleError(error); + + const cdnResponse = await fetch(output); + + const blob = await cdnResponse.blob(); + + return blob; + } +} + +export class BytezTextToVideoTask extends BytezTask implements TextToVideoTaskHelper { + override preparePayload(params: BodyParams): Record { + return { + text: params.args.inputs, + params: params.args.parameters, + }; + } + + override async getResponse(response: BytezStringLikeOutput, url?: string, headers?: HeadersInit): Promise { + void url; + void headers; + + const { error, output } = response; + + this.handleError(error); + + const cdnResponse = await fetch(output); + + const blob = await cdnResponse.blob(); + + return blob; + } +} + +export class BytezImageToTextTask extends BytezTask implements ImageToTextTaskHelper { + preparePayloadAsync(args: ImageToTextArgs): Promise { + const _args = args as Record; + return this._preparePayloadAsync(_args); + } + + override preparePayload(params: BodyParams): Record { + return { + base64: params.args.inputs, + params: params.args.parameters, + }; + } + + override async getResponse( + response: BytezStringLikeOutput, + url?: string, + headers?: HeadersInit + ): Promise { + void url; + void headers; + + 
+
+export class BytezImageToTextTask extends BytezTask implements ImageToTextTaskHelper {
+	preparePayloadAsync(args: ImageToTextArgs): Promise<RequestArgs> {
+		const _args = args as Record<string, unknown>;
+		return this._preparePayloadAsync(_args);
+	}
+
+	override preparePayload(params: BodyParams): Record<string, unknown> {
+		return {
+			base64: params.args.inputs,
+			params: params.args.parameters,
+		};
+	}
+
+	override async getResponse(
+		response: BytezStringLikeOutput,
+		url?: string,
+		headers?: HeadersInit
+	): Promise<ImageToTextOutput> {
+		void url;
+		void headers;
+
+		const { error, output } = response;
+
+		this.handleError(error);
+
+		// include both key casings for compatibility
+		return {
+			generatedText: output,
+			generated_text: output,
+		};
+	}
+}
+
+export class BytezQuestionAnsweringTask extends BytezTask implements QuestionAnsweringTaskHelper {
+	override preparePayload(params: BodyParams<QuestionAnsweringInput>): Record<string, unknown> {
+		return {
+			question: params.args.inputs.question,
+			context: params.args.inputs.context,
+			params: params.args.parameters,
+		};
+	}
+
+	override async getResponse(
+		response: BytezQuestionAnsweringOutput,
+		url?: string,
+		headers?: HeadersInit
+	): Promise<QuestionAnsweringOutputElement> {
+		void url;
+		void headers;
+
+		const { error, output } = response;
+
+		this.handleError(error);
+
+		const { answer, score, start, end } = output;
+
+		return {
+			answer,
+			score,
+			start,
+			end,
+		};
+	}
+}
+
+export class BytezVisualQuestionAnsweringTask extends BytezTask implements VisualQuestionAnsweringTaskHelper {
+	override preparePayload(params: BodyParams<VisualQuestionAnsweringInput>): Record<string, unknown> {
+		return {
+			base64: params.args.inputs.image,
+			question: params.args.inputs.question,
+			params: params.args.parameters,
+		};
+	}
+
+	override async getResponse(
+		response: BytezVisualQuestionAnsweringOutput,
+		url?: string,
+		headers?: HeadersInit
+	): Promise<VisualQuestionAnsweringOutput[number]> {
+		void url;
+		void headers;
+
+		const { error, output } = response;
+
+		this.handleError(error);
+
+		return output[0];
+	}
+}
+
+export class BytezDocumentQuestionAnsweringTask extends BytezTask implements DocumentQuestionAnsweringTaskHelper {
+	override preparePayload(params: BodyParams<DocumentQuestionAnsweringInput>): Record<string, unknown> {
+		return {
+			base64: params.args.inputs.image,
+			question: params.args.inputs.question,
+			params: params.args.parameters,
+		};
+	}
+
+	override async getResponse(
+		response: BytezDocumentQuestionAnsweringOutput,
+		url?: string,
+		headers?: HeadersInit
+	): Promise<DocumentQuestionAnsweringOutput[number]> {
+		void url;
+		void headers;
+
+		const { error, output } = response;
+
+		this.handleError(error);
+
+		return output[0];
+	}
+}
+
+export class BytezImageSegmentationTask extends BytezTask implements ImageSegmentationTaskHelper {
+	preparePayloadAsync(args: ImageSegmentationArgs): Promise<RequestArgs> {
+		return this._preparePayloadAsync(args);
+	}
+
+	override preparePayload(params: BodyParams<ImageSegmentationInput>): Record<string, unknown> {
+		return {
+			base64: params.args.inputs,
+			params: params.args.parameters,
+		};
+	}
+
+	override async getResponse(
+		response: BytezImageSegmentationOutput,
+		url?: string,
+		headers?: HeadersInit
+	): Promise<ImageSegmentationOutput> {
+		void url;
+		void headers;
+
+		const { error, output } = response;
+
+		this.handleError(error);
+
+		// some models return no score, in which case we default to -1 to indicate that a number was expected but none was produced
+		return output.map(({ label, score, mask_png }) => ({ label, score: score || -1, mask: mask_png }));
+	}
+}
+
+export class BytezImageClassificationTask extends BytezTask implements ImageClassificationTaskHelper {
+	preparePayloadAsync(args: ImageClassificationArgs): Promise<RequestArgs> {
+		const _args = args as Record<string, unknown>;
+		return this._preparePayloadAsync(_args);
+	}
+
+	override preparePayload(params: BodyParams<ImageClassificationInput>): Record<string, unknown> {
+		return {
+			base64: params.args.inputs,
+			params: params.args.parameters,
+		};
+	}
+
+	override async getResponse(
+		response: BytezImageClassificationOutput,
+		url?: string,
+		headers?: HeadersInit
+	): Promise<ImageClassificationOutput> {
+		void url;
+		void headers;
+
+		const { error, output } = response;
+
+		this.handleError(error);
+
+		return output;
+	}
+}
+
+export class BytezZeroShotImageClassificationTask extends BytezTask implements ZeroShotImageClassificationTaskHelper {
+	override preparePayload(
+		params: BodyParams<ZeroShotImageClassificationInput>
+	): Record<string, unknown> | BodyInit {
+		const candidate_labels = params.args.parameters.candidate_labels;
+
+		return {
+			base64: params.args.inputs,
+			candidate_labels,
+			params: { ...params.args.parameters, candidate_labels: undefined },
+		};
+	}
+
+	override async getResponse(
+		response: BytezZeroShotImageClassificationOutput,
+		url?: string,
+		headers?: HeadersInit
+	): Promise<ZeroShotImageClassificationOutput> {
+		void url;
+		void headers;
+
+		const { error, output } = response;
+
+		this.handleError(error);
+
+		return output;
+	}
+}
+
+export class BytezObjectDetectionTask extends BytezTask implements ObjectDetectionTaskHelper {
+	async preparePayloadAsync(args: ObjectDetectionArgs): Promise<RequestArgs> {
+		const _args = args as Record<string, unknown>;
+		return this._preparePayloadAsync(_args);
+	}
+
+	override preparePayload(params: BodyParams<ObjectDetectionInput>): Record<string, unknown> | BodyInit {
+		return {
+			base64: params.args.inputs,
+			params: params.args.parameters,
+		};
+	}
+
+	override async getResponse(
+		response: BytezObjectDetectionOutput,
+		url?: string,
+		headers?: HeadersInit
+	): Promise<ObjectDetectionOutput> {
+		void url;
+		void headers;
+
+		const { error, output } = response;
+
+		this.handleError(error);
+
+		return output;
+	}
+}
+
+// TODO: this flattens the vectors, which may not be the desired behavior; compare how the
+// test model behaves when hitting HF directly vs. through our API.
+export class BytezFeatureExtractionTask extends BytezTask implements FeatureExtractionTaskHelper {
+	override preparePayload(params: BodyParams<FeatureExtractionInput>): Record<string, unknown> {
+		return {
+			text: params.args.inputs,
+			params: params.args.parameters,
+		};
+	}
+
+	override async getResponse(
+		response: BytezFeatureExtractionOutput,
+		url?: string,
+		headers?: HeadersInit
+	): Promise<FeatureExtractionOutput> {
+		void url;
+		void headers;
+
+		const { error, output } = response;
+
+		this.handleError(error);
+
+		return output.flat();
+	}
+}
+
+export class BytezSentenceSimilarityTask extends BytezTask implements SentenceSimilarityTaskHelper {
+	override preparePayload(params: BodyParams<SentenceSimilarityInput>): Record<string, unknown> {
+		return {
+			text: [params.args.inputs.source_sentence, ...params.args.inputs.sentences],
+			params: params.args.parameters,
+		};
+	}
+
+	override async getResponse(
+		response: BytezSentenceSimilarityOutput,
+		url?: string,
+		headers?: HeadersInit
+	): Promise<SentenceSimilarityOutput> {
+		void url;
+		void headers;
+
+		const { error, output } = response;
+
+		this.handleError(error);
+
+		const [sourceSentenceVector, ...sentenceVectors] = output;
+
+		const similarityScores = [];
+
+		for (const sentenceVector of sentenceVectors) {
+			const similarity = this.cosineSimilarity(sourceSentenceVector, sentenceVector);
+			similarityScores.push(similarity);
+		}
+
+		return similarityScores;
+	}
+
+	cosineSimilarity(a: number[], b: number[]): number {
+		if (a.length !== b.length) throw new Error("Vectors must be same length");
+		let dot = 0,
+			normA = 0,
+			normB = 0;
+		for (let i = 0; i < a.length; i++) {
+			dot += a[i] * b[i];
+			normA += a[i] ** 2;
+			normB += b[i] ** 2;
+		}
+		if (normA === 0 || normB === 0) return 0;
+		return dot / (Math.sqrt(normA) * Math.sqrt(normB));
+	}
+}
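+
+// Worked example for cosineSimilarity() above: a = [1, 2] and b = [2, 4] point the same
+// way, so dot = 10, normA = 5, normB = 20, and 10 / (sqrt(5) * sqrt(20)) = 10 / 10 = 1.
+// Orthogonal vectors such as [1, 0] and [0, 1] score 0.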
+
+export class BytezFillMaskTask extends BytezTask implements FillMaskTaskHelper {
+	override preparePayload(params: BodyParams<FillMaskInput>): Record<string, unknown> {
+		return {
+			text: params.args.inputs,
+			params: params.args.parameters,
+		};
+	}
+
+	override async getResponse(
+		response: BytezFillMaskOutput,
+		url?: string,
+		headers?: HeadersInit
+	): Promise<FillMaskOutput> {
+		void url;
+		void headers;
+
+		const { error, output } = response;
+
+		this.handleError(error);
+
+		return output;
+	}
+}
+
+export class BytezTextClassificationTask extends BytezTask implements TextClassificationTaskHelper {
+	override preparePayload(params: BodyParams): Record<string, unknown> {
+		return {
+			text: params.args.inputs,
+			params: params.args.parameters,
+		};
+	}
+
+	override async getResponse(
+		response: BytezZeroShotImageClassificationOutput,
+		url?: string,
+		headers?: HeadersInit
+	): Promise<TextClassificationOutput> {
+		void url;
+		void headers;
+
+		const { error, output } = response;
+
+		this.handleError(error);
+
+		return output;
+	}
+}
+
+export class BytezTokenClassificationTask extends BytezTask implements TokenClassificationTaskHelper {
+	override preparePayload(params: BodyParams): Record<string, unknown> {
+		return {
+			text: params.args.inputs,
+			params: params.args.parameters,
+		};
+	}
+
+	override async getResponse(
+		response: BytezTokenClassificationOutput,
+		url?: string,
+		headers?: HeadersInit
+	): Promise<TokenClassificationOutput> {
+		void url;
+		void headers;
+
+		const { error, output } = response;
+
+		this.handleError(error);
+
+		return output.map(({ end, entity, score, start, word }) => ({
+			entity_group: entity,
+			score,
+			word,
+			start,
+			end,
+		}));
+	}
+}
+
+export class BytezZeroShotClassificationTask extends BytezTask implements ZeroShotClassificationTaskHelper {
+	override preparePayload(params: BodyParams<ZeroShotClassificationInput>): Record<string, unknown> {
+		return {
+			text: params.args.inputs,
+			candidate_labels: params.args.parameters.candidate_labels,
+			params: {
+				...params.args.parameters,
+				candidate_labels: undefined,
+			},
+		};
+	}
+
+	override async getResponse(
+		response: BytezZeroShotClassificationOutput,
+		url?: string,
+		headers?: HeadersInit
+	): Promise<ZeroShotClassificationOutput> {
+		void url;
+		void headers;
+
+		const { error, output } = response;
+
+		this.handleError(error);
+
+		return output;
+	}
+}
+
+export class BytezAudioClassificationTask extends BytezTask implements AudioClassificationTaskHelper {
+	async preparePayloadAsync(args: AudioClassificationArgs): Promise<RequestArgs> {
+		const _args = args as Record<string, unknown>;
+		return this._preparePayloadAsync(_args);
+	}
+
+	override preparePayload(params: BodyParams<AudioClassificationInput>): Record<string, unknown> | BodyInit {
+		return {
+			base64: params.args.inputs,
+			params: params.args.parameters,
+		};
+	}
+
+	override async getResponse(
+		response: BytezZeroShotImageClassificationOutput,
+		url?: string,
+		headers?: HeadersInit
+	): Promise<AudioClassificationOutput> {
+		void url;
+		void headers;
+
+		const { error, output } = response;
+
+		this.handleError(error);
+
+		return output;
+	}
+}
+
+export class BytezTextToSpeechTask extends BytezTask implements TextToSpeechTaskHelper {
+	override preparePayload(params: BodyParams<TextToSpeechInput>): Record<string, unknown> {
+		return {
+			text: params.args.inputs,
+			params: params.args.parameters,
+		};
+	}
+
+	override async getResponse(response: BytezStringLikeOutput, url?: string, headers?: HeadersInit): Promise<Blob> {
+		void url;
+		void headers;
+
+		const { error, output } = response;
+
+		this.handleError(error);
+
+		const byteArray = Buffer.from(output, "base64");
+		return new Blob([byteArray], { type: "audio/wav" });
+	}
+}
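+
+// Note: Buffer above assumes a Node.js runtime — an assumption worth flagging, since a
+// browser build of this client would need a base64 decoder such as atob() instead.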
+
+export class BytezAutomaticSpeechRecognitionTask extends BytezTask implements AutomaticSpeechRecognitionTaskHelper {
+	async preparePayloadAsync(args: AutomaticSpeechRecognitionArgs): Promise<RequestArgs> {
+		const _args = args as Record<string, unknown>;
+		return this._preparePayloadAsync(_args);
+	}
+
+	override preparePayload(params: BodyParams<AutomaticSpeechRecognitionInput>): Record<string, unknown> {
+		return {
+			base64: params.args.inputs,
+			params: params.args.parameters,
+		};
+	}
+
+	override async getResponse(
+		response: BytezStringLikeOutput,
+		url?: string,
+		headers?: HeadersInit
+	): Promise<AutomaticSpeechRecognitionOutput> {
+		void url;
+		void headers;
+
+		const { error, output } = response;
+
+		this.handleError(error);
+
+		return {
+			text: output.trim(),
+		};
+	}
+}
+
+export class BytezTextToAudioTask extends BytezTask implements AutomaticSpeechRecognitionTaskHelper {
+	async preparePayloadAsync(args: AutomaticSpeechRecognitionArgs): Promise<RequestArgs> {
+		const _args = args as Record<string, unknown>;
+		return this._preparePayloadAsync(_args);
+	}
+
+	override preparePayload(params: BodyParams<AutomaticSpeechRecognitionInput>): Record<string, unknown> {
+		return {
+			base64: params.args.inputs,
+			params: params.args.parameters,
+		};
+	}
+
+	override async getResponse(
+		response: BytezStringLikeOutput,
+		url?: string,
+		headers?: HeadersInit
+	): Promise<AutomaticSpeechRecognitionOutput> {
+		void url;
+		void headers;
+
+		const { error, output } = response;
+
+		this.handleError(error);
+
+		return {
+			text: output.trim(),
+		};
+	}
+}
diff --git a/packages/inference/src/providers/consts.ts b/packages/inference/src/providers/consts.ts
index f93d890535..bed40c7234 100644
--- a/packages/inference/src/providers/consts.ts
+++ b/packages/inference/src/providers/consts.ts
@@ -20,6 +20,7 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
 	 */
 	baseten: {},
 	"black-forest-labs": {},
+	"bytez-ai": {},
 	cerebras: {},
 	clarifai: {},
 	cohere: {},
diff --git a/packages/inference/src/providers/hf-inference.ts b/packages/inference/src/providers/hf-inference.ts
index 757302f1fd..82e45d1a8a 100644
--- a/packages/inference/src/providers/hf-inference.ts
+++ b/packages/inference/src/providers/hf-inference.ts
@@ -361,13 +361,14 @@ export class HFInferenceImageSegmentationTask extends HFInferenceTask implements ImageSegmentationTaskHelper {
 }
 
 export class HFInferenceImageToTextTask extends HFInferenceTask implements ImageToTextTaskHelper {
-	override async getResponse(response: ImageToTextOutput): Promise<ImageToTextOutput> {
-		if (typeof response?.generated_text !== "string") {
+	override async getResponse(response: ImageToTextOutput[]): Promise<ImageToTextOutput> {
+		const [first] = response;
+		if (typeof first?.generated_text !== "string") {
 			throw new InferenceClientProviderOutputError(
 				"Received malformed response from HF-Inference image-to-text API: expected {generated_text: string}"
 			);
 		}
-		return response;
+		return first;
 	}
 }
diff --git a/packages/inference/src/providers/providerHelper.ts b/packages/inference/src/providers/providerHelper.ts
index fc1ebc25f8..e86a36814f 100644
--- a/packages/inference/src/providers/providerHelper.ts
+++ b/packages/inference/src/providers/providerHelper.ts
@@ -55,6 +55,12 @@ import type { ImageToImageArgs } from "../tasks/cv/imageToImage.js";
 import type { AutomaticSpeechRecognitionArgs } from "../tasks/audio/automaticSpeechRecognition.js";
 import type { ImageToVideoArgs } from "../tasks/cv/imageToVideo.js";
 import type { ImageSegmentationArgs } from "../tasks/cv/imageSegmentation.js";
+import type {
+	AudioClassificationArgs,
+	ImageClassificationArgs,
+	ImageToTextArgs,
+	ObjectDetectionArgs,
+} from "../tasks/index.js";
 
 /**
  * Base class for task-specific provider helpers
@@ -168,16 +174,19 @@ export interface ImageSegmentationTaskHelper {
 export interface ImageClassificationTaskHelper {
 	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<ImageClassificationOutput>;
 	preparePayload(params: BodyParams<ImageClassificationInput>): Record<string, unknown> | BodyInit;
+	preparePayloadAsync(args: ImageClassificationArgs): Promise<RequestArgs>;
 }
 
 export interface ObjectDetectionTaskHelper {
 	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<ObjectDetectionOutput>;
 	preparePayload(params: BodyParams<ObjectDetectionInput>): Record<string, unknown> | BodyInit;
+	preparePayloadAsync(args: ObjectDetectionArgs): Promise<RequestArgs>;
 }
 
 export interface ImageToTextTaskHelper {
 	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<ImageToTextOutput>;
 	preparePayload(params: BodyParams<ImageToTextInput>): Record<string, unknown> | BodyInit;
+	preparePayloadAsync(args: ImageToTextArgs): Promise<RequestArgs>;
 }
 
 export interface ZeroShotImageClassificationTaskHelper {
@@ -267,6 +276,7 @@ export interface AutomaticSpeechRecognitionTaskHelper {
 export interface AudioClassificationTaskHelper {
 	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<AudioClassificationOutput>;
 	preparePayload(params: BodyParams<AudioClassificationInput>): Record<string, unknown> | BodyInit;
+	preparePayloadAsync(args: AudioClassificationArgs): Promise<RequestArgs>;
 }
 
 // Multimodal Tasks
diff --git a/packages/inference/src/tasks/audio/audioClassification.ts b/packages/inference/src/tasks/audio/audioClassification.ts
index 89737e2597..13da90be0c 100644
--- a/packages/inference/src/tasks/audio/audioClassification.ts
+++ b/packages/inference/src/tasks/audio/audioClassification.ts
@@ -18,7 +18,9 @@ export async function audioClassification(
 ): Promise<AudioClassificationOutput> {
 	const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
 	const providerHelper = getProviderHelper(provider, "audio-classification");
-	const payload = preparePayload(args);
+	const payload = providerHelper.preparePayloadAsync
+		? await providerHelper.preparePayloadAsync(args)
+		: preparePayload(args);
 	const { data: res } = await innerRequest<AudioClassificationOutput>(payload, providerHelper, {
 		...options,
 		task: "audio-classification",
diff --git a/packages/inference/src/tasks/cv/imageClassification.ts b/packages/inference/src/tasks/cv/imageClassification.ts
index 2a39f57e26..1a2d373d41 100644
--- a/packages/inference/src/tasks/cv/imageClassification.ts
+++ b/packages/inference/src/tasks/cv/imageClassification.ts
@@ -17,7 +17,9 @@ export async function imageClassification(
 ): Promise<ImageClassificationOutput> {
 	const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
 	const providerHelper = getProviderHelper(provider, "image-classification");
-	const payload = preparePayload(args);
+	const payload = providerHelper.preparePayloadAsync
+		? await providerHelper.preparePayloadAsync(args)
+		: preparePayload(args);
 	const { data: res } = await innerRequest<ImageClassificationOutput>(payload, providerHelper, {
 		...options,
 		task: "image-classification",
diff --git a/packages/inference/src/tasks/cv/imageToText.ts b/packages/inference/src/tasks/cv/imageToText.ts
index 7784b3630c..c3a87a0b65 100644
--- a/packages/inference/src/tasks/cv/imageToText.ts
+++ b/packages/inference/src/tasks/cv/imageToText.ts
@@ -13,11 +13,13 @@ export type ImageToTextArgs = BaseArgs & (ImageToTextInput | LegacyImageInput);
 export async function imageToText(args: ImageToTextArgs, options?: Options): Promise<ImageToTextOutput> {
 	const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
 	const providerHelper = getProviderHelper(provider, "image-to-text");
-	const payload = preparePayload(args);
+	const payload = providerHelper.preparePayloadAsync
+		? await providerHelper.preparePayloadAsync(args)
+		: preparePayload(args);
 	const { data: res } = await innerRequest<[ImageToTextOutput]>(payload, providerHelper, {
 		...options,
 		task: "image-to-text",
 	});
-	return providerHelper.getResponse(res[0]);
+	return providerHelper.getResponse(res);
 }
diff --git a/packages/inference/src/tasks/cv/objectDetection.ts b/packages/inference/src/tasks/cv/objectDetection.ts
index 7f9741c5aa..a6e747cb11 100644
--- a/packages/inference/src/tasks/cv/objectDetection.ts
+++ b/packages/inference/src/tasks/cv/objectDetection.ts
@@ -14,7 +14,10 @@ export type ObjectDetectionArgs = BaseArgs & (ObjectDetectionInput | LegacyImageInput);
 export async function objectDetection(args: ObjectDetectionArgs, options?: Options): Promise<ObjectDetectionOutput> {
 	const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
 	const providerHelper = getProviderHelper(provider, "object-detection");
-	const payload = preparePayload(args);
+	const payload = providerHelper.preparePayloadAsync
+		? await providerHelper.preparePayloadAsync(args)
+		: preparePayload(args);
+
 	const { data: res } = await innerRequest<ObjectDetectionOutput>(payload, providerHelper, {
 		...options,
 		task: "object-detection",
diff --git a/packages/inference/src/types.ts b/packages/inference/src/types.ts
index 215252efda..7762d1ec67 100644
--- a/packages/inference/src/types.ts
+++ b/packages/inference/src/types.ts
@@ -47,6 +47,7 @@ export type InferenceTask = Exclude<PipelineType, "other"> | "conversational";
 export const INFERENCE_PROVIDERS = [
 	"baseten",
 	"black-forest-labs",
+	"bytez-ai",
 	"cerebras",
 	"clarifai",
 	"cohere",
diff --git a/packages/inference/test/InferenceClient.spec.ts b/packages/inference/test/InferenceClient.spec.ts
index 3ea6b55241..01246a9c4e 100644
--- a/packages/inference/test/InferenceClient.spec.ts
+++ b/packages/inference/test/InferenceClient.spec.ts
@@ -1,6 +1,6 @@
 import { assert, describe, expect, it } from "vitest";
 
-import type { ChatCompletionStreamOutput } from "@huggingface/tasks";
+import type { ChatCompletionStreamOutput, WidgetType } from "@huggingface/tasks";
 
 import type { TextToImageArgs } from "../src/index.js";
 import {
@@ -22,7 +22,7 @@ if (!env.HF_TOKEN) {
 	console.warn("Set HF_TOKEN in the env to run the tests for better rate limits");
 }
 
-describe.skip("InferenceClient", () => {
+describe("InferenceClient", () => {
 	// Individual tests can be ran without providing an api key, however running all tests without an api key will result in rate limiting error.
 
 	describe("backward compatibility", () => {
@@ -2290,7 +2290,6 @@ describe.skip("InferenceClient", () => {
 		},
 		TIMEOUT
 	);
-
 	describe.concurrent(
 		"PublicAI",
 		() => {
@@ -2459,4 +2458,502 @@ describe.skip("InferenceClient", () => {
 		},
 		TIMEOUT
 	);
+
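+	// Bytez provider smoke tests: one representative model per supported task. Bytez model
+	// IDs mirror HF hub IDs 1-1 (see packages/inference/src/providers/bytez-ai.ts).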
"dummy"); + const provider = "bytez-ai"; + + const tests: { task: WidgetType; modelId: string; test: (modelId: string) => Promise; stream: boolean }[] = [ + { + task: "text-generation", + modelId: "bigscience/mt0-small", + test: async (modelId: string) => { + const { generated_text } = await client.textGeneration({ + model: modelId, + provider, + inputs: "Hello", + }); + expect(typeof generated_text).toBe("string"); + }, + stream: false, + }, + { + task: "text-generation", + modelId: "bigscience/mt0-small", + test: async (modelId: string) => { + const response = client.textGenerationStream({ + model: modelId, + provider, + inputs: "Please answer the following question: complete one two and ____.", + parameters: { + max_new_tokens: 50, + num_beams: 1, + }, + }); + for await (const ret of response) { + expect(ret).toMatchObject({ + details: null, + index: expect.any(Number), + token: { + id: expect.any(Number), + logprob: expect.any(Number), + special: expect.any(Boolean), + }, + generated_text: ret.generated_text ? "afew" : null, + }); + expect(typeof ret.token.text === "string" || ret.token.text === null).toBe(true); + } + }, + stream: true, + }, + { + task: "conversational", + modelId: "Qwen/Qwen3-1.7B", + test: async (modelId: string) => { + const { choices } = await client.chatCompletion({ + model: modelId, + provider, + messages: [ + { + role: "system", + content: "You are a friendly chatbot who always responds in the style of a pirate", + }, + { role: "user", content: "How many helicopters can a human eat in one sitting?" }, + ], + }); + expect(typeof choices[0].message.role).toBe("string"); + expect(typeof choices[0].message.content).toBe("string"); + }, + stream: false, + }, + { + task: "conversational", + modelId: "Qwen/Qwen3-1.7B", + test: async (modelId: string) => { + const stream = client.chatCompletionStream({ + model: modelId, + provider, + messages: [ + { + role: "system", + content: "You are a friendly chatbot who always responds in the style of a pirate", + }, + { role: "user", content: "How many helicopters can a human eat in one sitting?" 
+						],
+					});
+					const chunks = [];
+					for await (const chunk of stream) {
+						if (chunk.choices && chunk.choices.length > 0) {
+							if (chunk.choices[0].finish_reason === "stop") {
+								break;
+							}
+							chunks.push(chunk);
+						}
+					}
+					const out = chunks.map((chunk) => chunk.choices[0].delta.content).join("");
+					expect(out).toContain("helicopter");
+				},
+				stream: true,
+			},
+			{
+				task: "summarization",
+				modelId: "ainize/bart-base-cnn",
+				test: async (modelId: string) => {
+					const input =
+						"The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris...";
+					const { summary_text } = await client.summarization({
+						model: modelId,
+						provider,
+						inputs: input,
+						parameters: { max_length: 40 },
+					});
+					expect(summary_text.length).toBeLessThan(input.length);
+				},
+				stream: false,
+			},
+			{
+				task: "translation",
+				modelId: "Areeb123/En-Fr_Translation_Model",
+				test: async (modelId: string) => {
+					const { translation_text } = await client.translation({
+						model: modelId,
+						provider,
+						inputs: "Hello",
+					});
+					expect(typeof translation_text).toBe("string");
+					expect(translation_text.length).toBeGreaterThan(0);
+				},
+				stream: false,
+			},
+			{
+				task: "text-to-image",
+				modelId: "IDKiro/sdxs-512-0.9",
+				test: async (modelId: string) => {
+					const res = await client.textToImage({
+						model: modelId,
+						provider,
+						inputs: "A cat in the hat",
+					});
+					expect(res).toBeInstanceOf(Blob);
+				},
+				stream: false,
+			},
+			{
+				task: "text-to-video",
+				modelId: "ali-vilab/text-to-video-ms-1.7b",
+				test: async (modelId: string) => {
+					const res = await client.textToVideo({
+						model: modelId,
+						provider,
+						inputs: "A cat in the hat",
+					});
+					expect(res).toBeInstanceOf(Blob);
+				},
+				stream: false,
+			},
+			{
+				task: "image-to-text",
+				modelId: "captioner/caption-gen",
+				test: async (modelId: string) => {
+					const { generated_text } = await client.imageToText({
+						model: modelId,
+						provider,
+						data: new Blob([readTestFile("cheetah.png")], { type: "image/png" }),
+					});
+					expect(typeof generated_text).toBe("string");
+				},
+				stream: false,
+			},
+			{
+				task: "question-answering",
+				modelId: "airesearch/xlm-roberta-base-finetune-qa",
+				test: async (modelId: string) => {
+					const { answer, score, start, end } = await client.questionAnswering({
+						model: modelId,
+						provider,
+						inputs: {
+							question: "Where do I live?",
+							context: "My name is Merve and I live in İstanbul.",
+						},
+					});
+					expect(answer).toBeDefined();
+					expect(score).toBeDefined();
+					expect(start).toBeDefined();
+					expect(end).toBeDefined();
+				},
+				stream: false,
+			},
+			{
+				task: "visual-question-answering",
+				modelId: "aqachun/Vilt_fine_tune_2000",
+				test: async (modelId: string) => {
+					const output = await client.visualQuestionAnswering({
+						model: modelId,
+						provider,
+						inputs: {
+							image: new Blob([readTestFile("cheetah.png")], { type: "image/png" }),
+							question: "What kind of animal is this?",
+						},
+					});
+					expect(output).toMatchObject({
+						answer: expect.any(String),
+						score: expect.any(Number),
+					});
+				},
+				stream: false,
+			},
+			{
+				task: "document-question-answering",
+				modelId: "cloudqi/CQI_Visual_Question_Awnser_PT_v0",
+				test: async (modelId: string) => {
+					const url = "https://templates.invoicehome.com/invoice-template-us-neat-750px.png";
+					const response = await fetch(url);
+					const blob = await response.blob();
+					const output = await client.documentQuestionAnswering({
+						model: modelId,
+						provider,
+						inputs: {
+							image: blob,
+							question: "What's the total cost?",
+						},
+					});
+					expect(output).toMatchObject({
+						answer: expect.any(String),
+						score: expect.any(Number),
+						// not sure what start/end refers to in this case
+						start: expect.any(Number),
+						end: expect.any(Number),
+					});
+				},
+				stream: false,
+			},
+			{
+				task: "image-segmentation",
+				modelId: "apple/deeplabv3-mobilevit-small",
+				test: async (modelId: string) => {
+					const output = await client.imageSegmentation({
+						model: modelId,
+						provider,
+						inputs: new Blob([readTestFile("cats.png")], { type: "image/png" }),
+					});
+					expect(output).toEqual(
+						expect.arrayContaining([
+							expect.objectContaining({
+								label: expect.any(String),
+								score: expect.any(Number),
+								mask: expect.any(String),
+							}),
+						])
+					);
+				},
+				stream: false,
+			},
+			{
+				task: "image-classification",
+				modelId: "akahana/vit-base-cats-vs-dogs",
+				test: async (modelId: string) => {
+					const output = await client.imageClassification({
+						model: modelId,
+						provider,
+						data: new Blob([readTestFile("cheetah.png")], { type: "image/png" }),
+					});
+					expect(output).toEqual(
+						expect.arrayContaining([
+							expect.objectContaining({
+								score: expect.any(Number),
+								label: expect.any(String),
+							}),
+						])
+					);
+				},
+				stream: false,
+			},
+			{
+				task: "zero-shot-image-classification",
+				modelId: "BilelDJ/clip-hugging-face-finetuned",
+				test: async (modelId: string) => {
+					const output = await client.zeroShotImageClassification({
+						model: modelId,
+						provider,
+						inputs: { image: new Blob([readTestFile("cheetah.png")], { type: "image/png" }) },
+						parameters: {
+							candidate_labels: ["animal", "toy", "car"],
+						},
+					});
+					expect(output).toEqual(
+						expect.arrayContaining([
+							expect.objectContaining({
+								score: expect.any(Number),
+								label: expect.any(String),
+							}),
+						])
+					);
+				},
+				stream: false,
+			},
+			{
+				task: "object-detection",
+				modelId: "aisak-ai/aisak-detect",
+				test: async (modelId: string) => {
+					const output = await client.objectDetection({
+						model: modelId,
+						provider,
+						inputs: new Blob([readTestFile("cats.png")], { type: "image/png" }),
+					});
+					expect(output).toEqual(
+						expect.arrayContaining([
+							expect.objectContaining({
+								score: expect.any(Number),
+								label: expect.any(String),
+								box: expect.objectContaining({
+									xmin: expect.any(Number),
+									ymin: expect.any(Number),
+									xmax: expect.any(Number),
+									ymax: expect.any(Number),
+								}),
+							}),
+						])
+					);
+				},
+				stream: false,
+			},
+			{
+				task: "feature-extraction",
+				modelId: "allenai/specter2_base",
+				test: async (modelId: string) => {
+					const output = await client.featureExtraction({
+						model: modelId,
+						provider,
+						inputs: "That is a happy person",
+					});
+					expect(output).toEqual(expect.arrayContaining([expect.any(Number)]));
+				},
+				stream: false,
+			},
+			{
+				task: "sentence-similarity",
+				modelId: "embedding-data/distilroberta-base-sentence-transformer",
+				test: async (modelId: string) => {
+					const output = await client.sentenceSimilarity({
+						model: modelId,
+						provider,
+						inputs: {
+							source_sentence: "That is a happy person",
+							sentences: ["That is a happy dog", "That is a very happy person", "Today is a sunny day"],
+						},
+					});
+					expect(output).toEqual([expect.any(Number), expect.any(Number), expect.any(Number)]);
+				},
+				stream: false,
+			},
+			{
+				task: "fill-mask",
+				modelId: "almanach/camembert-base",
+				test: async (modelId: string) => {
+					const output = await client.fillMask({ model: modelId, provider, inputs: "Hello <mask>" });
+					expect(output).toEqual(
+						expect.arrayContaining([
+							expect.objectContaining({
+								score: expect.any(Number),
+								token: expect.any(Number),
+								token_str: expect.any(String),
+								sequence: expect.any(String),
+							}),
+						])
+					);
+				},
+				stream: false,
+			},
+			{
+				task: "text-classification",
+				modelId: "AdamCodd/distilbert-base-uncased-finetuned-sentiment-amazon",
+				test: async (modelId: string) => {
+					const output = await client.textClassification({
+						model: modelId,
+						provider,
+						inputs: "I am a special unicorn",
+					});
+					expect(output.every((entry) => entry.label && entry.score)).toBe(true);
+				},
+				stream: false,
+			},
+			{
+				task: "token-classification",
+				modelId: "2rtl3/mn-xlm-roberta-base-named-entity",
+				test: async (modelId: string) => {
+					const output = await client.tokenClassification({
+						model: modelId,
+						provider,
+						inputs: "John went to NYC",
+					});
+					expect(output).toEqual(
+						expect.arrayContaining([
+							expect.objectContaining({
+								entity_group: expect.any(String),
+								score: expect.any(Number),
+								word: expect.any(String),
+								start: expect.any(Number),
+								end: expect.any(Number),
+							}),
+						])
+					);
+				},
+				stream: false,
+			},
+			{
+				task: "zero-shot-classification",
+				modelId: "AyoubChLin/DistilBERT_eco_ZeroShot",
+				test: async (modelId: string) => {
+					const testInput = "Ninja turtles are cool";
+					const testCandidateLabels = ["positive", "negative"];
+					const output = await client.zeroShotClassification({
+						model: modelId,
+						provider,
+						inputs: [
+							testInput,
+							// eslint-disable-next-line @typescript-eslint/no-explicit-any
+						] as any,
+						parameters: { candidate_labels: testCandidateLabels },
+					});
+					expect(output).toEqual(
+						expect.arrayContaining([
+							expect.objectContaining({
+								sequence: testInput,
+								labels: testCandidateLabels,
+								scores: [expect.closeTo(0.5206031203269958, 5), expect.closeTo(0.479396790266037, 5)],
+							}),
+						])
+					);
+				},
+				stream: false,
+			},
+			{
+				task: "audio-classification",
+				modelId: "aaraki/wav2vec2-base-finetuned-ks",
+				test: async (modelId: string) => {
+					const output = await client.audioClassification({
+						model: modelId,
+						provider,
+						data: new Blob([readTestFile("sample1.flac")], { type: "audio/flac" }),
+					});
+					expect(output).toEqual(
+						expect.arrayContaining([
+							expect.objectContaining({
+								score: expect.any(Number),
+								label: expect.any(String),
+							}),
+						])
+					);
+				},
+				stream: false,
+			},
+			{
+				task: "text-to-speech",
+				modelId: "facebook/mms-tts-eng",
+				test: async (modelId: string) => {
+					const output = await client.textToSpeech({ model: modelId, provider, inputs: "Hello" });
+					expect(output).toBeInstanceOf(Blob);
+				},
+				stream: false,
+			},
+			{
+				task: "automatic-speech-recognition",
+				modelId: "facebook/data2vec-audio-base-960h",
+				test: async (modelId: string) => {
+					const output = await client.automaticSpeechRecognition({
+						model: modelId,
+						provider,
+						data: new Blob([readTestFile("sample1.flac")], { type: "audio/flac" }),
+					});
+					expect(output).toMatchObject({
+						text: "GOING ALONG SLUSHY COUNTRY ROADS AND SPEAKING TO DAMP AUDIENCES IN DRAUGHTY SCHOOLROOMS DAY AFTER DAY FOR A FORTNIGHT HE'LL HAVE TO PUT IN AN APPEARANCE AT SOME PLACE OF WORSHIP ON SUNDAY MORNING AND HE CAN COME TO US IMMEDIATELY AFTERWARDS",
+					});
+				},
+				stream: false,
+			},
+		];
+
+		// bootstrap the inference mappings for testing
+		for (const { task, modelId } of tests) {
+			HARDCODED_MODEL_INFERENCE_MAPPING["bytez-ai"][modelId] = {
+				provider,
+				hfModelId: modelId,
+				providerId: modelId,
+				status: "live",
+				task,
+				adapter: undefined,
+				adapterWeightsPath: undefined,
+			};
+		}
" stream" : ""}`; + it(testName, async () => { + await test(modelId); + }); + } + }); }); diff --git a/packages/tasks/src/inference-providers.ts b/packages/tasks/src/inference-providers.ts index ee08d12943..f653741d1e 100644 --- a/packages/tasks/src/inference-providers.ts +++ b/packages/tasks/src/inference-providers.ts @@ -1,6 +1,7 @@ /// This list is for illustration purposes only. /// in the `tasks` sub-package, we do not need actual strong typing of the inference providers. const INFERENCE_PROVIDERS = [ + "bytez-ai", "cerebras", "cohere", "fal-ai",