diff --git a/packages/inference/README.md b/packages/inference/README.md index c4a2ea4b1d..241ba7a523 100644 --- a/packages/inference/README.md +++ b/packages/inference/README.md @@ -572,31 +572,6 @@ await hf.tabularClassification({ }) ``` -## Custom Calls - -For models with custom parameters / outputs. - -```typescript -await hf.request({ - model: 'my-custom-model', - inputs: 'hello world', - parameters: { - custom_param: 'some magic', - } -}) - -// Custom streaming call, for models with custom parameters / outputs -for await (const output of hf.streamingRequest({ - model: 'my-custom-model', - inputs: 'hello world', - parameters: { - custom_param: 'some magic', - } -})) { - ... -} -``` - You can use any Chat Completion API-compatible provider with the `chatCompletion` method. ```typescript diff --git a/packages/inference/src/tasks/audio/audioClassification.ts b/packages/inference/src/tasks/audio/audioClassification.ts index b17c875051..12dbeb2585 100644 --- a/packages/inference/src/tasks/audio/audioClassification.ts +++ b/packages/inference/src/tasks/audio/audioClassification.ts @@ -1,7 +1,7 @@ import type { AudioClassificationInput, AudioClassificationOutput } from "@huggingface/tasks"; import { InferenceOutputError } from "../../lib/InferenceOutputError"; import type { BaseArgs, Options } from "../../types"; -import { request } from "../custom/request"; +import { innerRequest } from "../../utils/request"; import type { LegacyAudioInput } from "./utils"; import { preparePayload } from "./utils"; @@ -16,7 +16,7 @@ export async function audioClassification( options?: Options ): Promise { const payload = preparePayload(args); - const res = await request(payload, { + const { data: res } = await innerRequest(payload, { ...options, task: "audio-classification", }); diff --git a/packages/inference/src/tasks/audio/audioToAudio.ts b/packages/inference/src/tasks/audio/audioToAudio.ts index 84e4e79641..258873ecb7 100644 --- a/packages/inference/src/tasks/audio/audioToAudio.ts +++ b/packages/inference/src/tasks/audio/audioToAudio.ts @@ -1,6 +1,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError"; import type { BaseArgs, Options } from "../../types"; -import { request } from "../custom/request"; +import { innerRequest } from "../../utils/request"; import type { LegacyAudioInput } from "./utils"; import { preparePayload } from "./utils"; @@ -37,7 +37,7 @@ export interface AudioToAudioOutput { */ export async function audioToAudio(args: AudioToAudioArgs, options?: Options): Promise { const payload = preparePayload(args); - const res = await request(payload, { + const { data: res } = await innerRequest(payload, { ...options, task: "audio-to-audio", }); diff --git a/packages/inference/src/tasks/audio/automaticSpeechRecognition.ts b/packages/inference/src/tasks/audio/automaticSpeechRecognition.ts index 672ce0c5ba..2a792d0c00 100644 --- a/packages/inference/src/tasks/audio/automaticSpeechRecognition.ts +++ b/packages/inference/src/tasks/audio/automaticSpeechRecognition.ts @@ -2,10 +2,10 @@ import type { AutomaticSpeechRecognitionInput, AutomaticSpeechRecognitionOutput import { InferenceOutputError } from "../../lib/InferenceOutputError"; import type { BaseArgs, Options, RequestArgs } from "../../types"; import { base64FromBytes } from "../../utils/base64FromBytes"; -import { request } from "../custom/request"; +import { omit } from "../../utils/omit"; +import { innerRequest } from "../../utils/request"; import type { LegacyAudioInput } from "./utils"; import { preparePayload } from 
"./utils"; -import { omit } from "../../utils/omit"; export type AutomaticSpeechRecognitionArgs = BaseArgs & (AutomaticSpeechRecognitionInput | LegacyAudioInput); /** @@ -17,7 +17,7 @@ export async function automaticSpeechRecognition( options?: Options ): Promise { const payload = await buildPayload(args); - const res = await request(payload, { + const { data: res } = await innerRequest(payload, { ...options, task: "automatic-speech-recognition", }); diff --git a/packages/inference/src/tasks/audio/textToSpeech.ts b/packages/inference/src/tasks/audio/textToSpeech.ts index 72d053ffad..2839322bde 100644 --- a/packages/inference/src/tasks/audio/textToSpeech.ts +++ b/packages/inference/src/tasks/audio/textToSpeech.ts @@ -2,7 +2,7 @@ import type { TextToSpeechInput } from "@huggingface/tasks"; import { InferenceOutputError } from "../../lib/InferenceOutputError"; import type { BaseArgs, Options } from "../../types"; import { omit } from "../../utils/omit"; -import { request } from "../custom/request"; +import { innerRequest } from "../../utils/request"; type TextToSpeechArgs = BaseArgs & TextToSpeechInput; interface OutputUrlTextToSpeechGeneration { @@ -22,7 +22,7 @@ export async function textToSpeech(args: TextToSpeechArgs, options?: Options): P text: args.inputs, } : args; - const res = await request(payload, { + const { data: res } = await innerRequest(payload, { ...options, task: "text-to-speech", }); diff --git a/packages/inference/src/tasks/custom/request.ts b/packages/inference/src/tasks/custom/request.ts index 847d5d73f7..ad3dbd268e 100644 --- a/packages/inference/src/tasks/custom/request.ts +++ b/packages/inference/src/tasks/custom/request.ts @@ -1,8 +1,9 @@ import type { InferenceTask, Options, RequestArgs } from "../../types"; -import { makeRequestOptions } from "../../lib/makeRequestOptions"; +import { innerRequest } from "../../utils/request"; /** * Primitive to make custom calls to the inference provider + * @deprecated Use specific task functions instead. This function will be removed in a future version. */ export async function request( args: RequestArgs, @@ -13,35 +14,9 @@ export async function request( chatCompletion?: boolean; } ): Promise { - const { url, info } = await makeRequestOptions(args, options); - const response = await (options?.fetch ?? fetch)(url, info); - - if (options?.retry_on_error !== false && response.status === 503) { - return request(args, options); - } - - if (!response.ok) { - const contentType = response.headers.get("Content-Type"); - if (["application/json", "application/problem+json"].some((ct) => contentType?.startsWith(ct))) { - const output = await response.json(); - if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) { - throw new Error( - `Server ${args.model} does not seem to support chat completion. Error: ${JSON.stringify(output.error)}` - ); - } - if (output.error || output.detail) { - throw new Error(JSON.stringify(output.error ?? output.detail)); - } else { - throw new Error(output); - } - } - const message = contentType?.startsWith("text/plain;") ? await response.text() : undefined; - throw new Error(message ?? "An error occurred while fetching the blob"); - } - - if (response.headers.get("Content-Type")?.startsWith("application/json")) { - return await response.json(); - } - - return (await response.blob()) as T; + console.warn( + "The request method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead." 
+ ); + const result = await innerRequest(args, options); + return result.data; } diff --git a/packages/inference/src/tasks/custom/streamingRequest.ts b/packages/inference/src/tasks/custom/streamingRequest.ts index b1716ba003..de32eeea26 100644 --- a/packages/inference/src/tasks/custom/streamingRequest.ts +++ b/packages/inference/src/tasks/custom/streamingRequest.ts @@ -1,10 +1,8 @@ import type { InferenceTask, Options, RequestArgs } from "../../types"; -import { makeRequestOptions } from "../../lib/makeRequestOptions"; -import type { EventSourceMessage } from "../../vendor/fetch-event-source/parse"; -import { getLines, getMessages } from "../../vendor/fetch-event-source/parse"; - +import { innerStreamingRequest } from "../../utils/request"; /** * Primitive to make custom inference calls that expect server-sent events, and returns the response through a generator + * @deprecated Use specific task functions instead. This function will be removed in a future version. */ export async function* streamingRequest( args: RequestArgs, @@ -15,86 +13,8 @@ export async function* streamingRequest( chatCompletion?: boolean; } ): AsyncGenerator { - const { url, info } = await makeRequestOptions({ ...args, stream: true }, options); - const response = await (options?.fetch ?? fetch)(url, info); - - if (options?.retry_on_error !== false && response.status === 503) { - return yield* streamingRequest(args, options); - } - if (!response.ok) { - if (response.headers.get("Content-Type")?.startsWith("application/json")) { - const output = await response.json(); - if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) { - throw new Error(`Server ${args.model} does not seem to support chat completion. Error: ${output.error}`); - } - if (typeof output.error === "string") { - throw new Error(output.error); - } - if (output.error && "message" in output.error && typeof output.error.message === "string") { - /// OpenAI errors - throw new Error(output.error.message); - } - } - - throw new Error(`Server response contains error: ${response.status}`); - } - if (!response.headers.get("content-type")?.startsWith("text/event-stream")) { - throw new Error( - `Server does not support event stream content type, it returned ` + response.headers.get("content-type") - ); - } - - if (!response.body) { - return; - } - - const reader = response.body.getReader(); - let events: EventSourceMessage[] = []; - - const onEvent = (event: EventSourceMessage) => { - // accumulate events in array - events.push(event); - }; - - const onChunk = getLines( - getMessages( - () => {}, - () => {}, - onEvent - ) + console.warn( + "The streamingRequest method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead." ); - - try { - while (true) { - const { done, value } = await reader.read(); - if (done) { - return; - } - onChunk(value); - for (const event of events) { - if (event.data.length > 0) { - if (event.data === "[DONE]") { - return; - } - const data = JSON.parse(event.data); - if (typeof data === "object" && data !== null && "error" in data) { - const errorStr = - typeof data.error === "string" - ? data.error - : typeof data.error === "object" && - data.error && - "message" in data.error && - typeof data.error.message === "string" - ? 
data.error.message - : JSON.stringify(data.error); - throw new Error(`Error forwarded from backend: ` + errorStr); - } - yield data as T; - } - } - events = []; - } - } finally { - reader.releaseLock(); - } + yield* innerStreamingRequest(args, options); } diff --git a/packages/inference/src/tasks/cv/imageClassification.ts b/packages/inference/src/tasks/cv/imageClassification.ts index 4f7a6e6b04..e68661d7c0 100644 --- a/packages/inference/src/tasks/cv/imageClassification.ts +++ b/packages/inference/src/tasks/cv/imageClassification.ts @@ -1,7 +1,7 @@ import type { ImageClassificationInput, ImageClassificationOutput } from "@huggingface/tasks"; import { InferenceOutputError } from "../../lib/InferenceOutputError"; import type { BaseArgs, Options } from "../../types"; -import { request } from "../custom/request"; +import { innerRequest } from "../../utils/request"; import { preparePayload, type LegacyImageInput } from "./utils"; export type ImageClassificationArgs = BaseArgs & (ImageClassificationInput | LegacyImageInput); @@ -15,7 +15,7 @@ export async function imageClassification( options?: Options ): Promise { const payload = preparePayload(args); - const res = await request(payload, { + const { data: res } = await innerRequest(payload, { ...options, task: "image-classification", }); diff --git a/packages/inference/src/tasks/cv/imageSegmentation.ts b/packages/inference/src/tasks/cv/imageSegmentation.ts index abbc808bf4..e541520786 100644 --- a/packages/inference/src/tasks/cv/imageSegmentation.ts +++ b/packages/inference/src/tasks/cv/imageSegmentation.ts @@ -1,7 +1,7 @@ import type { ImageSegmentationInput, ImageSegmentationOutput } from "@huggingface/tasks"; import { InferenceOutputError } from "../../lib/InferenceOutputError"; import type { BaseArgs, Options } from "../../types"; -import { request } from "../custom/request"; +import { innerRequest } from "../../utils/request"; import { preparePayload, type LegacyImageInput } from "./utils"; export type ImageSegmentationArgs = BaseArgs & (ImageSegmentationInput | LegacyImageInput); @@ -15,7 +15,7 @@ export async function imageSegmentation( options?: Options ): Promise { const payload = preparePayload(args); - const res = await request(payload, { + const { data: res } = await innerRequest(payload, { ...options, task: "image-segmentation", }); diff --git a/packages/inference/src/tasks/cv/imageToImage.ts b/packages/inference/src/tasks/cv/imageToImage.ts index 41a098f797..37df537efb 100644 --- a/packages/inference/src/tasks/cv/imageToImage.ts +++ b/packages/inference/src/tasks/cv/imageToImage.ts @@ -2,7 +2,7 @@ import type { ImageToImageInput } from "@huggingface/tasks"; import { InferenceOutputError } from "../../lib/InferenceOutputError"; import type { BaseArgs, Options, RequestArgs } from "../../types"; import { base64FromBytes } from "../../utils/base64FromBytes"; -import { request } from "../custom/request"; +import { innerRequest } from "../../utils/request"; export type ImageToImageArgs = BaseArgs & ImageToImageInput; @@ -26,7 +26,7 @@ export async function imageToImage(args: ImageToImageArgs, options?: Options): P ), }; } - const res = await request(reqArgs, { + const { data: res } = await innerRequest(reqArgs, { ...options, task: "image-to-image", }); diff --git a/packages/inference/src/tasks/cv/imageToText.ts b/packages/inference/src/tasks/cv/imageToText.ts index 52cefd1fc4..ff359bac56 100644 --- a/packages/inference/src/tasks/cv/imageToText.ts +++ b/packages/inference/src/tasks/cv/imageToText.ts @@ -1,7 +1,7 @@ import type { 
ImageToTextInput, ImageToTextOutput } from "@huggingface/tasks"; import { InferenceOutputError } from "../../lib/InferenceOutputError"; import type { BaseArgs, Options } from "../../types"; -import { request } from "../custom/request"; +import { innerRequest } from "../../utils/request"; import type { LegacyImageInput } from "./utils"; import { preparePayload } from "./utils"; @@ -11,16 +11,14 @@ export type ImageToTextArgs = BaseArgs & (ImageToTextInput | LegacyImageInput); */ export async function imageToText(args: ImageToTextArgs, options?: Options): Promise { const payload = preparePayload(args); - const res = ( - await request<[ImageToTextOutput]>(payload, { - ...options, - task: "image-to-text", - }) - )?.[0]; + const { data: res } = await innerRequest<[ImageToTextOutput]>(payload, { + ...options, + task: "image-to-text", + }); - if (typeof res?.generated_text !== "string") { + if (typeof res?.[0]?.generated_text !== "string") { throw new InferenceOutputError("Expected {generated_text: string}"); } - return res; + return res?.[0]; } diff --git a/packages/inference/src/tasks/cv/objectDetection.ts b/packages/inference/src/tasks/cv/objectDetection.ts index e66372e1d8..948b59e4b3 100644 --- a/packages/inference/src/tasks/cv/objectDetection.ts +++ b/packages/inference/src/tasks/cv/objectDetection.ts @@ -1,7 +1,7 @@ -import { request } from "../custom/request"; -import type { BaseArgs, Options } from "../../types"; -import { InferenceOutputError } from "../../lib/InferenceOutputError"; import type { ObjectDetectionInput, ObjectDetectionOutput } from "@huggingface/tasks"; +import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import type { BaseArgs, Options } from "../../types"; +import { innerRequest } from "../../utils/request"; import { preparePayload, type LegacyImageInput } from "./utils"; export type ObjectDetectionArgs = BaseArgs & (ObjectDetectionInput | LegacyImageInput); @@ -12,7 +12,7 @@ export type ObjectDetectionArgs = BaseArgs & (ObjectDetectionInput | LegacyImage */ export async function objectDetection(args: ObjectDetectionArgs, options?: Options): Promise { const payload = preparePayload(args); - const res = await request(payload, { + const { data: res } = await innerRequest(payload, { ...options, task: "object-detection", }); diff --git a/packages/inference/src/tasks/cv/textToImage.ts b/packages/inference/src/tasks/cv/textToImage.ts index 3d5f421e1f..ee577c51aa 100644 --- a/packages/inference/src/tasks/cv/textToImage.ts +++ b/packages/inference/src/tasks/cv/textToImage.ts @@ -1,9 +1,9 @@ import type { TextToImageInput, TextToImageOutput } from "@huggingface/tasks"; import { InferenceOutputError } from "../../lib/InferenceOutputError"; import type { BaseArgs, InferenceProvider, Options } from "../../types"; -import { omit } from "../../utils/omit"; -import { request } from "../custom/request"; import { delay } from "../../utils/delay"; +import { omit } from "../../utils/omit"; +import { innerRequest } from "../../utils/request"; export type TextToImageArgs = BaseArgs & TextToImageInput; @@ -65,7 +65,7 @@ export async function textToImage(args: TextToImageArgs, options?: TextToImageOp ...getResponseFormatArg(args.provider), prompt: args.inputs, }; - const res = await request< + const { data: res } = await innerRequest< | TextToImageOutput | Base64ImageGeneration | OutputUrlImageGeneration diff --git a/packages/inference/src/tasks/cv/textToVideo.ts b/packages/inference/src/tasks/cv/textToVideo.ts index fddea60eb7..5609dd0140 100644 --- 
a/packages/inference/src/tasks/cv/textToVideo.ts +++ b/packages/inference/src/tasks/cv/textToVideo.ts @@ -1,12 +1,11 @@ -import type { BaseArgs, InferenceProvider, Options } from "../../types"; import type { TextToVideoInput } from "@huggingface/tasks"; -import { request } from "../custom/request"; -import { omit } from "../../utils/omit"; -import { isUrl } from "../../lib/isUrl"; import { InferenceOutputError } from "../../lib/InferenceOutputError"; -import { typedInclude } from "../../utils/typedInclude"; -import { makeRequestOptions } from "../../lib/makeRequestOptions"; +import { isUrl } from "../../lib/isUrl"; import { pollFalResponse, type FalAiQueueOutput } from "../../providers/fal-ai"; +import type { BaseArgs, InferenceProvider, Options } from "../../types"; +import { omit } from "../../utils/omit"; +import { innerRequest } from "../../utils/request"; +import { typedInclude } from "../../utils/typedInclude"; export type TextToVideoArgs = BaseArgs & TextToVideoInput; @@ -35,37 +34,41 @@ export async function textToVideo(args: TextToVideoArgs, options?: Options): Pro args.provider === "fal-ai" || args.provider === "replicate" || args.provider === "novita" ? { ...omit(args, ["inputs", "parameters"]), ...args.parameters, prompt: args.inputs } : args; - const res = await request(payload, { + const { data, requestContext } = await innerRequest(payload, { ...options, task: "text-to-video", }); + if (args.provider === "fal-ai") { - const { url, info } = await makeRequestOptions(args, { ...options, task: "text-to-video" }); - return await pollFalResponse(res as FalAiQueueOutput, url, info.headers as Record); + return await pollFalResponse( + data as FalAiQueueOutput, + requestContext.url, + requestContext.info.headers as Record + ); } else if (args.provider === "novita") { const isValidOutput = - typeof res === "object" && - !!res && - "video" in res && - typeof res.video === "object" && - !!res.video && - "video_url" in res.video && - typeof res.video.video_url === "string" && - isUrl(res.video.video_url); + typeof data === "object" && + !!data && + "video" in data && + typeof data.video === "object" && + !!data.video && + "video_url" in data.video && + typeof data.video.video_url === "string" && + isUrl(data.video.video_url); if (!isValidOutput) { throw new InferenceOutputError("Expected { video: { video_url: string } }"); } - const urlResponse = await fetch((res as NovitaOutput).video.video_url); + const urlResponse = await fetch((data as NovitaOutput).video.video_url); return await urlResponse.blob(); } else { /// TODO: Replicate: handle the case where the generation request "times out" / is async (ie output is null) /// https://replicate.com/docs/topics/predictions/create-a-prediction const isValidOutput = - typeof res === "object" && !!res && "output" in res && typeof res.output === "string" && isUrl(res.output); + typeof data === "object" && !!data && "output" in data && typeof data.output === "string" && isUrl(data.output); if (!isValidOutput) { throw new InferenceOutputError("Expected { output: string }"); } - const urlResponse = await fetch(res.output); + const urlResponse = await fetch(data.output); return await urlResponse.blob(); } } diff --git a/packages/inference/src/tasks/cv/zeroShotImageClassification.ts b/packages/inference/src/tasks/cv/zeroShotImageClassification.ts index ca80c9c37f..e4cd7f4da8 100644 --- a/packages/inference/src/tasks/cv/zeroShotImageClassification.ts +++ b/packages/inference/src/tasks/cv/zeroShotImageClassification.ts @@ -1,9 +1,8 @@ +import type { 
ZeroShotImageClassificationInput, ZeroShotImageClassificationOutput } from "@huggingface/tasks"; import { InferenceOutputError } from "../../lib/InferenceOutputError"; -import type { BaseArgs, Options } from "../../types"; -import { request } from "../custom/request"; -import type { RequestArgs } from "../../types"; +import type { BaseArgs, Options, RequestArgs } from "../../types"; import { base64FromBytes } from "../../utils/base64FromBytes"; -import type { ZeroShotImageClassificationInput, ZeroShotImageClassificationOutput } from "@huggingface/tasks"; +import { innerRequest } from "../../utils/request"; /** * @deprecated @@ -46,7 +45,7 @@ export async function zeroShotImageClassification( options?: Options ): Promise { const payload = await preparePayload(args); - const res = await request(payload, { + const { data: res } = await innerRequest(payload, { ...options, task: "zero-shot-image-classification", }); diff --git a/packages/inference/src/tasks/multimodal/documentQuestionAnswering.ts b/packages/inference/src/tasks/multimodal/documentQuestionAnswering.ts index b0fe1af3fa..24f99bcb02 100644 --- a/packages/inference/src/tasks/multimodal/documentQuestionAnswering.ts +++ b/packages/inference/src/tasks/multimodal/documentQuestionAnswering.ts @@ -1,14 +1,13 @@ -import { InferenceOutputError } from "../../lib/InferenceOutputError"; -import type { BaseArgs, Options } from "../../types"; -import { request } from "../custom/request"; -import type { RequestArgs } from "../../types"; -import { toArray } from "../../utils/toArray"; -import { base64FromBytes } from "../../utils/base64FromBytes"; import type { DocumentQuestionAnsweringInput, DocumentQuestionAnsweringInputData, DocumentQuestionAnsweringOutput, } from "@huggingface/tasks"; +import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import type { BaseArgs, Options, RequestArgs } from "../../types"; +import { base64FromBytes } from "../../utils/base64FromBytes"; +import { innerRequest } from "../../utils/request"; +import { toArray } from "../../utils/toArray"; /// Override the type to properly set inputs.image as Blob export type DocumentQuestionAnsweringArgs = BaseArgs & @@ -29,16 +28,17 @@ export async function documentQuestionAnswering( image: base64FromBytes(new Uint8Array(await args.inputs.image.arrayBuffer())), }, } as RequestArgs; - const res = toArray( - await request(reqArgs, { + const { data: res } = await innerRequest( + reqArgs, + { ...options, task: "document-question-answering", - }) + } ); - + const output = toArray(res); const isValidOutput = - Array.isArray(res) && - res.every( + Array.isArray(output) && + output.every( (elem) => typeof elem === "object" && !!elem && @@ -51,5 +51,5 @@ export async function documentQuestionAnswering( throw new InferenceOutputError("Expected Array<{answer: string, end?: number, score?: number, start?: number}>"); } - return res[0]; + return output[0]; } diff --git a/packages/inference/src/tasks/multimodal/visualQuestionAnswering.ts b/packages/inference/src/tasks/multimodal/visualQuestionAnswering.ts index 536fab8340..d6d2a5379b 100644 --- a/packages/inference/src/tasks/multimodal/visualQuestionAnswering.ts +++ b/packages/inference/src/tasks/multimodal/visualQuestionAnswering.ts @@ -6,7 +6,7 @@ import type { import { InferenceOutputError } from "../../lib/InferenceOutputError"; import type { BaseArgs, Options, RequestArgs } from "../../types"; import { base64FromBytes } from "../../utils/base64FromBytes"; -import { request } from "../custom/request"; +import { innerRequest 
} from "../../utils/request"; /// Override the type to properly set inputs.image as Blob export type VisualQuestionAnsweringArgs = BaseArgs & @@ -27,10 +27,12 @@ export async function visualQuestionAnswering( image: base64FromBytes(new Uint8Array(await args.inputs.image.arrayBuffer())), }, } as RequestArgs; - const res = await request(reqArgs, { + + const { data: res } = await innerRequest(reqArgs, { ...options, task: "visual-question-answering", }); + const isValidOutput = Array.isArray(res) && res.every( diff --git a/packages/inference/src/tasks/nlp/chatCompletion.ts b/packages/inference/src/tasks/nlp/chatCompletion.ts index baa91c005d..e2eb99d4ee 100644 --- a/packages/inference/src/tasks/nlp/chatCompletion.ts +++ b/packages/inference/src/tasks/nlp/chatCompletion.ts @@ -1,7 +1,7 @@ +import type { ChatCompletionInput, ChatCompletionOutput } from "@huggingface/tasks"; import { InferenceOutputError } from "../../lib/InferenceOutputError"; import type { BaseArgs, Options } from "../../types"; -import { request } from "../custom/request"; -import type { ChatCompletionInput, ChatCompletionOutput } from "@huggingface/tasks"; +import { innerRequest } from "../../utils/request"; /** * Use the chat completion endpoint to generate a response to a prompt, using OpenAI message completion API no stream @@ -10,12 +10,11 @@ export async function chatCompletion( args: BaseArgs & ChatCompletionInput, options?: Options ): Promise { - const res = await request(args, { + const { data: res } = await innerRequest(args, { ...options, task: "text-generation", chatCompletion: true, }); - const isValidOutput = typeof res === "object" && Array.isArray(res?.choices) && diff --git a/packages/inference/src/tasks/nlp/chatCompletionStream.ts b/packages/inference/src/tasks/nlp/chatCompletionStream.ts index 6984ce75ec..0008a7b915 100644 --- a/packages/inference/src/tasks/nlp/chatCompletionStream.ts +++ b/packages/inference/src/tasks/nlp/chatCompletionStream.ts @@ -1,6 +1,6 @@ -import type { BaseArgs, Options } from "../../types"; -import { streamingRequest } from "../custom/streamingRequest"; import type { ChatCompletionInput, ChatCompletionStreamOutput } from "@huggingface/tasks"; +import type { BaseArgs, Options } from "../../types"; +import { innerStreamingRequest } from "../../utils/request"; /** * Use to continue text from a prompt. 
Same as `textGeneration` but returns generator that can be read one token at a time @@ -9,7 +9,7 @@ export async function* chatCompletionStream( args: BaseArgs & ChatCompletionInput, options?: Options ): AsyncGenerator { - yield* streamingRequest(args, { + yield* innerStreamingRequest(args, { ...options, task: "text-generation", chatCompletion: true, diff --git a/packages/inference/src/tasks/nlp/featureExtraction.ts b/packages/inference/src/tasks/nlp/featureExtraction.ts index 25a6695a2c..451e8898f5 100644 --- a/packages/inference/src/tasks/nlp/featureExtraction.ts +++ b/packages/inference/src/tasks/nlp/featureExtraction.ts @@ -1,7 +1,7 @@ import type { FeatureExtractionInput } from "@huggingface/tasks"; import { InferenceOutputError } from "../../lib/InferenceOutputError"; import type { BaseArgs, Options } from "../../types"; -import { request } from "../custom/request"; +import { innerRequest } from "../../utils/request"; export type FeatureExtractionArgs = BaseArgs & FeatureExtractionInput; @@ -17,7 +17,7 @@ export async function featureExtraction( args: FeatureExtractionArgs, options?: Options ): Promise { - const res = await request(args, { + const { data: res } = await innerRequest(args, { ...options, task: "feature-extraction", }); diff --git a/packages/inference/src/tasks/nlp/fillMask.ts b/packages/inference/src/tasks/nlp/fillMask.ts index 9a30b056e3..59d1a59421 100644 --- a/packages/inference/src/tasks/nlp/fillMask.ts +++ b/packages/inference/src/tasks/nlp/fillMask.ts @@ -1,7 +1,7 @@ import type { FillMaskInput, FillMaskOutput } from "@huggingface/tasks"; import { InferenceOutputError } from "../../lib/InferenceOutputError"; import type { BaseArgs, Options } from "../../types"; -import { request } from "../custom/request"; +import { innerRequest } from "../../utils/request"; export type FillMaskArgs = BaseArgs & FillMaskInput; @@ -9,7 +9,7 @@ export type FillMaskArgs = BaseArgs & FillMaskInput; * Tries to fill in a hole with a missing word (token to be precise). That’s the base task for BERT models. */ export async function fillMask(args: FillMaskArgs, options?: Options): Promise { - const res = await request(args, { + const { data: res } = await innerRequest(args, { ...options, task: "fill-mask", }); diff --git a/packages/inference/src/tasks/nlp/questionAnswering.ts b/packages/inference/src/tasks/nlp/questionAnswering.ts index 4141c193a2..d6d12f6bad 100644 --- a/packages/inference/src/tasks/nlp/questionAnswering.ts +++ b/packages/inference/src/tasks/nlp/questionAnswering.ts @@ -1,7 +1,7 @@ import type { QuestionAnsweringInput, QuestionAnsweringOutput } from "@huggingface/tasks"; import { InferenceOutputError } from "../../lib/InferenceOutputError"; import type { BaseArgs, Options } from "../../types"; -import { request } from "../custom/request"; +import { innerRequest } from "../../utils/request"; export type QuestionAnsweringArgs = BaseArgs & QuestionAnsweringInput; @@ -12,10 +12,11 @@ export async function questionAnswering( args: QuestionAnsweringArgs, options?: Options ): Promise { - const res = await request(args, { + const { data: res } = await innerRequest(args, { ...options, task: "question-answering", }); + const isValidOutput = Array.isArray(res) ? 
res.every( (elem) => diff --git a/packages/inference/src/tasks/nlp/sentenceSimilarity.ts b/packages/inference/src/tasks/nlp/sentenceSimilarity.ts index a2d365b4fb..dea29aec67 100644 --- a/packages/inference/src/tasks/nlp/sentenceSimilarity.ts +++ b/packages/inference/src/tasks/nlp/sentenceSimilarity.ts @@ -1,7 +1,7 @@ import type { SentenceSimilarityInput, SentenceSimilarityOutput } from "@huggingface/tasks"; import { InferenceOutputError } from "../../lib/InferenceOutputError"; import type { BaseArgs, Options } from "../../types"; -import { request } from "../custom/request"; +import { innerRequest } from "../../utils/request"; export type SentenceSimilarityArgs = BaseArgs & SentenceSimilarityInput; @@ -12,7 +12,7 @@ export async function sentenceSimilarity( args: SentenceSimilarityArgs, options?: Options ): Promise { - const res = await request(args, { + const { data: res } = await innerRequest(args, { ...options, task: "sentence-similarity", }); diff --git a/packages/inference/src/tasks/nlp/summarization.ts b/packages/inference/src/tasks/nlp/summarization.ts index bf7439ba73..e4806de5d0 100644 --- a/packages/inference/src/tasks/nlp/summarization.ts +++ b/packages/inference/src/tasks/nlp/summarization.ts @@ -1,7 +1,7 @@ import type { SummarizationInput, SummarizationOutput } from "@huggingface/tasks"; import { InferenceOutputError } from "../../lib/InferenceOutputError"; import type { BaseArgs, Options } from "../../types"; -import { request } from "../custom/request"; +import { innerRequest } from "../../utils/request"; export type SummarizationArgs = BaseArgs & SummarizationInput; @@ -9,7 +9,7 @@ export type SummarizationArgs = BaseArgs & SummarizationInput; * This task is well known to summarize longer text into shorter text. Be careful, some models have a maximum length of input. That means that the summary cannot handle full books for instance. Be careful when choosing your model. 
*/ export async function summarization(args: SummarizationArgs, options?: Options): Promise { - const res = await request(args, { + const { data: res } = await innerRequest(args, { ...options, task: "summarization", }); diff --git a/packages/inference/src/tasks/nlp/tableQuestionAnswering.ts b/packages/inference/src/tasks/nlp/tableQuestionAnswering.ts index 0ac08c5897..1b89a1bfd4 100644 --- a/packages/inference/src/tasks/nlp/tableQuestionAnswering.ts +++ b/packages/inference/src/tasks/nlp/tableQuestionAnswering.ts @@ -1,7 +1,7 @@ import type { TableQuestionAnsweringInput, TableQuestionAnsweringOutput } from "@huggingface/tasks"; import { InferenceOutputError } from "../../lib/InferenceOutputError"; import type { BaseArgs, Options } from "../../types"; -import { request } from "../custom/request"; +import { innerRequest } from "../../utils/request"; export type TableQuestionAnsweringArgs = BaseArgs & TableQuestionAnsweringInput; @@ -12,7 +12,7 @@ export async function tableQuestionAnswering( args: TableQuestionAnsweringArgs, options?: Options ): Promise { - const res = await request(args, { + const { data: res } = await innerRequest(args, { ...options, task: "table-question-answering", }); diff --git a/packages/inference/src/tasks/nlp/textClassification.ts b/packages/inference/src/tasks/nlp/textClassification.ts index 7c99ddeece..975386e6ca 100644 --- a/packages/inference/src/tasks/nlp/textClassification.ts +++ b/packages/inference/src/tasks/nlp/textClassification.ts @@ -1,7 +1,7 @@ import type { TextClassificationInput, TextClassificationOutput } from "@huggingface/tasks"; import { InferenceOutputError } from "../../lib/InferenceOutputError"; import type { BaseArgs, Options } from "../../types"; -import { request } from "../custom/request"; +import { innerRequest } from "../../utils/request"; export type TextClassificationArgs = BaseArgs & TextClassificationInput; @@ -12,16 +12,15 @@ export async function textClassification( args: TextClassificationArgs, options?: Options ): Promise { - const res = ( - await request(args, { - ...options, - task: "text-classification", - }) - )?.[0]; + const { data: res } = await innerRequest(args, { + ...options, + task: "text-classification", + }); + const output = res?.[0]; const isValidOutput = - Array.isArray(res) && res.every((x) => typeof x?.label === "string" && typeof x.score === "number"); + Array.isArray(output) && output.every((x) => typeof x?.label === "string" && typeof x.score === "number"); if (!isValidOutput) { throw new InferenceOutputError("Expected Array<{label: string, score: number}>"); } - return res; + return output; } diff --git a/packages/inference/src/tasks/nlp/textGeneration.ts b/packages/inference/src/tasks/nlp/textGeneration.ts index 1afd1f17cd..989222a2ac 100644 --- a/packages/inference/src/tasks/nlp/textGeneration.ts +++ b/packages/inference/src/tasks/nlp/textGeneration.ts @@ -6,9 +6,9 @@ import type { } from "@huggingface/tasks"; import { InferenceOutputError } from "../../lib/InferenceOutputError"; import type { BaseArgs, Options } from "../../types"; -import { toArray } from "../../utils/toArray"; -import { request } from "../custom/request"; import { omit } from "../../utils/omit"; +import { innerRequest } from "../../utils/request"; +import { toArray } from "../../utils/toArray"; export type { TextGenerationInput, TextGenerationOutput }; @@ -37,7 +37,7 @@ export async function textGeneration( ): Promise { if (args.provider === "together") { args.prompt = args.inputs; - const raw = await request(args, { + const { data: raw 
} = await innerRequest(args, { ...options, task: "text-generation", }); @@ -61,10 +61,12 @@ export async function textGeneration( : undefined), ...omit(args, ["inputs", "parameters"]), }; - const raw = await request(payload, { - ...options, - task: "text-generation", - }); + const raw = ( + await innerRequest(payload, { + ...options, + task: "text-generation", + }) + ).data; const isValidOutput = typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string"; if (!isValidOutput) { @@ -75,18 +77,16 @@ export async function textGeneration( generated_text: completion.message.content, }; } else { - const res = toArray( - await request(args, { - ...options, - task: "text-generation", - }) - ); - + const { data: res } = await innerRequest(args, { + ...options, + task: "text-generation", + }); + const output = toArray(res); const isValidOutput = - Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string"); + Array.isArray(output) && output.every((x) => "generated_text" in x && typeof x?.generated_text === "string"); if (!isValidOutput) { throw new InferenceOutputError("Expected Array<{generated_text: string}>"); } - return (res as TextGenerationOutput[])?.[0]; + return (output as TextGenerationOutput[])?.[0]; } } diff --git a/packages/inference/src/tasks/nlp/textGenerationStream.ts b/packages/inference/src/tasks/nlp/textGenerationStream.ts index 5c6c76add7..de6d84e72c 100644 --- a/packages/inference/src/tasks/nlp/textGenerationStream.ts +++ b/packages/inference/src/tasks/nlp/textGenerationStream.ts @@ -1,6 +1,6 @@ import type { TextGenerationInput } from "@huggingface/tasks"; import type { BaseArgs, Options } from "../../types"; -import { streamingRequest } from "../custom/streamingRequest"; +import { innerStreamingRequest } from "../../utils/request"; export interface TextGenerationStreamToken { /** Token ID from the model tokenizer */ @@ -89,7 +89,7 @@ export async function* textGenerationStream( args: BaseArgs & TextGenerationInput, options?: Options ): AsyncGenerator { - yield* streamingRequest(args, { + yield* innerStreamingRequest(args, { ...options, task: "text-generation", }); diff --git a/packages/inference/src/tasks/nlp/tokenClassification.ts b/packages/inference/src/tasks/nlp/tokenClassification.ts index 46d53ffcbd..061cead753 100644 --- a/packages/inference/src/tasks/nlp/tokenClassification.ts +++ b/packages/inference/src/tasks/nlp/tokenClassification.ts @@ -1,8 +1,8 @@ import type { TokenClassificationInput, TokenClassificationOutput } from "@huggingface/tasks"; import { InferenceOutputError } from "../../lib/InferenceOutputError"; import type { BaseArgs, Options } from "../../types"; +import { innerRequest } from "../../utils/request"; import { toArray } from "../../utils/toArray"; -import { request } from "../custom/request"; export type TokenClassificationArgs = BaseArgs & TokenClassificationInput; @@ -13,15 +13,14 @@ export async function tokenClassification( args: TokenClassificationArgs, options?: Options ): Promise { - const res = toArray( - await request(args, { - ...options, - task: "token-classification", - }) - ); + const { data: res } = await innerRequest(args, { + ...options, + task: "token-classification", + }); + const output = toArray(res); const isValidOutput = - Array.isArray(res) && - res.every( + Array.isArray(output) && + output.every( (x) => typeof x.end === "number" && typeof x.entity_group === "string" && @@ -34,5 +33,5 @@ export async function tokenClassification( 
"Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>" ); } - return res; + return output; } diff --git a/packages/inference/src/tasks/nlp/translation.ts b/packages/inference/src/tasks/nlp/translation.ts index a05b228eaa..bbd0d1ea29 100644 --- a/packages/inference/src/tasks/nlp/translation.ts +++ b/packages/inference/src/tasks/nlp/translation.ts @@ -1,14 +1,14 @@ import type { TranslationInput, TranslationOutput } from "@huggingface/tasks"; import { InferenceOutputError } from "../../lib/InferenceOutputError"; import type { BaseArgs, Options } from "../../types"; -import { request } from "../custom/request"; +import { innerRequest } from "../../utils/request"; export type TranslationArgs = BaseArgs & TranslationInput; /** * This task is well known to translate text from one language to another. Recommended model: Helsinki-NLP/opus-mt-ru-en. */ export async function translation(args: TranslationArgs, options?: Options): Promise { - const res = await request(args, { + const { data: res } = await innerRequest(args, { ...options, task: "translation", }); diff --git a/packages/inference/src/tasks/nlp/zeroShotClassification.ts b/packages/inference/src/tasks/nlp/zeroShotClassification.ts index 769315ac22..6ef92efad5 100644 --- a/packages/inference/src/tasks/nlp/zeroShotClassification.ts +++ b/packages/inference/src/tasks/nlp/zeroShotClassification.ts @@ -1,8 +1,8 @@ import type { ZeroShotClassificationInput, ZeroShotClassificationOutput } from "@huggingface/tasks"; import { InferenceOutputError } from "../../lib/InferenceOutputError"; import type { BaseArgs, Options } from "../../types"; +import { innerRequest } from "../../utils/request"; import { toArray } from "../../utils/toArray"; -import { request } from "../custom/request"; export type ZeroShotClassificationArgs = BaseArgs & ZeroShotClassificationInput; @@ -13,15 +13,14 @@ export async function zeroShotClassification( args: ZeroShotClassificationArgs, options?: Options ): Promise { - const res = toArray( - await request(args, { - ...options, - task: "zero-shot-classification", - }) - ); + const { data: res } = await innerRequest(args, { + ...options, + task: "zero-shot-classification", + }); + const output = toArray(res); const isValidOutput = - Array.isArray(res) && - res.every( + Array.isArray(output) && + output.every( (x) => Array.isArray(x.labels) && x.labels.every((_label) => typeof _label === "string") && @@ -32,5 +31,5 @@ export async function zeroShotClassification( if (!isValidOutput) { throw new InferenceOutputError("Expected Array<{labels: string[], scores: number[], sequence: string}>"); } - return res; + return output; } diff --git a/packages/inference/src/tasks/tabular/tabularClassification.ts b/packages/inference/src/tasks/tabular/tabularClassification.ts index aa5f0c72d2..b253159439 100644 --- a/packages/inference/src/tasks/tabular/tabularClassification.ts +++ b/packages/inference/src/tasks/tabular/tabularClassification.ts @@ -1,6 +1,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError"; import type { BaseArgs, Options } from "../../types"; -import { request } from "../custom/request"; +import { innerRequest } from "../../utils/request"; export type TabularClassificationArgs = BaseArgs & { inputs: { @@ -25,7 +25,7 @@ export async function tabularClassification( args: TabularClassificationArgs, options?: Options ): Promise { - const res = await request(args, { + const { data: res } = await innerRequest(args, { ...options, task: "tabular-classification", }); 
diff --git a/packages/inference/src/tasks/tabular/tabularRegression.ts b/packages/inference/src/tasks/tabular/tabularRegression.ts index 102f1dc3c8..3b0e19b573 100644 --- a/packages/inference/src/tasks/tabular/tabularRegression.ts +++ b/packages/inference/src/tasks/tabular/tabularRegression.ts @@ -1,6 +1,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError"; import type { BaseArgs, Options } from "../../types"; -import { request } from "../custom/request"; +import { innerRequest } from "../../utils/request"; export type TabularRegressionArgs = BaseArgs & { inputs: { @@ -25,7 +25,7 @@ export async function tabularRegression( args: TabularRegressionArgs, options?: Options ): Promise { - const res = await request(args, { + const { data: res } = await innerRequest(args, { ...options, task: "tabular-regression", }); diff --git a/packages/inference/src/utils/request.ts b/packages/inference/src/utils/request.ts new file mode 100644 index 0000000000..a25130de74 --- /dev/null +++ b/packages/inference/src/utils/request.ts @@ -0,0 +1,157 @@ +import { makeRequestOptions } from "../lib/makeRequestOptions"; +import type { InferenceTask, Options, RequestArgs } from "../types"; +import type { EventSourceMessage } from "../vendor/fetch-event-source/parse"; +import { getLines, getMessages } from "../vendor/fetch-event-source/parse"; + +export interface ResponseWrapper { + data: T; + requestContext: { + url: string; + info: RequestInit; + }; +} + +/** + * Primitive to make custom calls to the inference provider + */ +export async function innerRequest( + args: RequestArgs, + options?: Options & { + /** In most cases (unless we pass a endpointUrl) we know the task */ + task?: InferenceTask; + /** Is chat completion compatible */ + chatCompletion?: boolean; + } +): Promise> { + const { url, info } = await makeRequestOptions(args, options); + const response = await (options?.fetch ?? fetch)(url, info); + + const requestContext: ResponseWrapper["requestContext"] = { url, info }; + + if (options?.retry_on_error !== false && response.status === 503) { + return innerRequest(args, options); + } + + if (!response.ok) { + const contentType = response.headers.get("Content-Type"); + if (["application/json", "application/problem+json"].some((ct) => contentType?.startsWith(ct))) { + const output = await response.json(); + if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) { + throw new Error( + `Server ${args.model} does not seem to support chat completion. Error: ${JSON.stringify(output.error)}` + ); + } + if (output.error || output.detail) { + throw new Error(JSON.stringify(output.error ?? output.detail)); + } else { + throw new Error(output); + } + } + const message = contentType?.startsWith("text/plain;") ? await response.text() : undefined; + throw new Error(message ?? 
"An error occurred while fetching the blob"); + } + + if (response.headers.get("Content-Type")?.startsWith("application/json")) { + const data = (await response.json()) as T; + return { data, requestContext }; + } + + const blob = (await response.blob()) as T; + return { data: blob as unknown as T, requestContext }; +} + +/** + * Primitive to make custom inference calls that expect server-sent events, and returns the response through a generator + */ +export async function* innerStreamingRequest( + args: RequestArgs, + options?: Options & { + /** In most cases (unless we pass a endpointUrl) we know the task */ + task?: InferenceTask; + /** Is chat completion compatible */ + chatCompletion?: boolean; + } +): AsyncGenerator { + const { url, info } = await makeRequestOptions({ ...args, stream: true }, options); + const response = await (options?.fetch ?? fetch)(url, info); + + if (options?.retry_on_error !== false && response.status === 503) { + return yield* innerStreamingRequest(args, options); + } + if (!response.ok) { + if (response.headers.get("Content-Type")?.startsWith("application/json")) { + const output = await response.json(); + if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) { + throw new Error(`Server ${args.model} does not seem to support chat completion. Error: ${output.error}`); + } + if (typeof output.error === "string") { + throw new Error(output.error); + } + if (output.error && "message" in output.error && typeof output.error.message === "string") { + /// OpenAI errors + throw new Error(output.error.message); + } + } + + throw new Error(`Server response contains error: ${response.status}`); + } + if (!response.headers.get("content-type")?.startsWith("text/event-stream")) { + throw new Error( + `Server does not support event stream content type, it returned ` + response.headers.get("content-type") + ); + } + + if (!response.body) { + return; + } + + const reader = response.body.getReader(); + let events: EventSourceMessage[] = []; + + const onEvent = (event: EventSourceMessage) => { + // accumulate events in array + events.push(event); + }; + + const onChunk = getLines( + getMessages( + () => {}, + () => {}, + onEvent + ) + ); + + try { + while (true) { + const { done, value } = await reader.read(); + if (done) { + return; + } + onChunk(value); + for (const event of events) { + if (event.data.length > 0) { + if (event.data === "[DONE]") { + return; + } + const data = JSON.parse(event.data); + if (typeof data === "object" && data !== null && "error" in data) { + const errorStr = + typeof data.error === "string" + ? data.error + : typeof data.error === "object" && + data.error && + "message" in data.error && + typeof data.error.message === "string" + ? data.error.message + : JSON.stringify(data.error); + throw new Error(`Error forwarded from backend: ` + errorStr); + } + yield data as T; + } + } + events = []; + } + } finally { + reader.releaseLock(); + } +}