25 changes: 0 additions & 25 deletions packages/inference/README.md

````diff
@@ -572,31 +572,6 @@ await hf.tabularClassification({
 })
 ```
 
-## Custom Calls
-
-For models with custom parameters / outputs.
-
-```typescript
-await hf.request({
-  model: 'my-custom-model',
-  inputs: 'hello world',
-  parameters: {
-    custom_param: 'some magic',
-  }
-})
-
-// Custom streaming call, for models with custom parameters / outputs
-for await (const output of hf.streamingRequest({
-  model: 'my-custom-model',
-  inputs: 'hello world',
-  parameters: {
-    custom_param: 'some magic',
-  }
-})) {
-  ...
-}
-```
-
 You can use any Chat Completion API-compatible provider with the `chatCompletion` method.
 
 ```typescript
````
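The removed README section covered `hf.request()` and `hf.streamingRequest()`, which this PR marks `@deprecated` in `src/tasks/custom/` below in favor of the task-specific helpers. A minimal migration sketch, not part of the PR (the model id is illustrative, and `textGeneration` stands in for whichever task helper matches your model):

```typescript
import { HfInference } from "@huggingface/inference";

const hf = new HfInference("hf_..."); // your access token

// Before (deprecated): await hf.request({ model, inputs, parameters })
// After: call the helper for the task your model actually serves.
const output = await hf.textGeneration({
  model: "my-custom-model", // illustrative model id
  inputs: "hello world",
});

// Before (deprecated): hf.streamingRequest(...)
// After: use the streaming variant of the task helper.
for await (const chunk of hf.textGenerationStream({
  model: "my-custom-model",
  inputs: "hello world",
})) {
  console.log(chunk.token.text);
}
```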
12 changes: 7 additions & 5 deletions packages/inference/src/tasks/audio/audioClassification.ts

```diff
@@ -1,7 +1,7 @@
 import type { AudioClassificationInput, AudioClassificationOutput } from "@huggingface/tasks";
 import { InferenceOutputError } from "../../lib/InferenceOutputError";
 import type { BaseArgs, Options } from "../../types";
-import { request } from "../custom/request";
+import { innerRequest } from "../../utils/request";
 import type { LegacyAudioInput } from "./utils";
 import { preparePayload } from "./utils";
 
@@ -16,10 +16,12 @@ export async function audioClassification(
   options?: Options
 ): Promise<AudioClassificationOutput> {
   const payload = preparePayload(args);
-  const res = await request<AudioClassificationOutput>(payload, {
-    ...options,
-    task: "audio-classification",
-  });
+  const res = (
+    await innerRequest<AudioClassificationOutput>(payload, {
+      ...options,
+      task: "audio-classification",
+    })
+  ).data;
   const isValidOutput =
     Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
   if (!isValidOutput) {
```
12 changes: 7 additions & 5 deletions packages/inference/src/tasks/audio/audioToAudio.ts

```diff
@@ -1,6 +1,6 @@
 import { InferenceOutputError } from "../../lib/InferenceOutputError";
 import type { BaseArgs, Options } from "../../types";
-import { request } from "../custom/request";
+import { innerRequest } from "../../utils/request";
 import type { LegacyAudioInput } from "./utils";
 import { preparePayload } from "./utils";
 
@@ -37,10 +37,12 @@ export interface AudioToAudioOutput {
  */
 export async function audioToAudio(args: AudioToAudioArgs, options?: Options): Promise<AudioToAudioOutput[]> {
   const payload = preparePayload(args);
-  const res = await request<AudioToAudioOutput>(payload, {
-    ...options,
-    task: "audio-to-audio",
-  });
+  const res = (
+    await innerRequest<AudioToAudioOutput>(payload, {
+      ...options,
+      task: "audio-to-audio",
+    })
+  ).data;
 
   return validateOutput(res);
 }
```
14 changes: 8 additions & 6 deletions packages/inference/src/tasks/audio/automaticSpeechRecognition.ts

```diff
@@ -2,10 +2,10 @@ import type { AutomaticSpeechRecognitionInput, AutomaticSpeechRecognitionOutput
 import { InferenceOutputError } from "../../lib/InferenceOutputError";
 import type { BaseArgs, Options, RequestArgs } from "../../types";
 import { base64FromBytes } from "../../utils/base64FromBytes";
-import { request } from "../custom/request";
+import { omit } from "../../utils/omit";
+import { innerRequest } from "../../utils/request";
 import type { LegacyAudioInput } from "./utils";
 import { preparePayload } from "./utils";
-import { omit } from "../../utils/omit";
 
 export type AutomaticSpeechRecognitionArgs = BaseArgs & (AutomaticSpeechRecognitionInput | LegacyAudioInput);
 /**
@@ -17,10 +17,12 @@ export async function automaticSpeechRecognition(
   options?: Options
 ): Promise<AutomaticSpeechRecognitionOutput> {
   const payload = await buildPayload(args);
-  const res = await request<AutomaticSpeechRecognitionOutput>(payload, {
-    ...options,
-    task: "automatic-speech-recognition",
-  });
+  const res = (
+    await innerRequest<AutomaticSpeechRecognitionOutput>(payload, {
+      ...options,
+      task: "automatic-speech-recognition",
+    })
+  ).data;
   const isValidOutput = typeof res?.text === "string";
   if (!isValidOutput) {
     throw new InferenceOutputError("Expected {text: string}");
```
12 changes: 7 additions & 5 deletions packages/inference/src/tasks/audio/textToSpeech.ts

```diff
@@ -2,7 +2,7 @@ import type { TextToSpeechInput } from "@huggingface/tasks";
 import { InferenceOutputError } from "../../lib/InferenceOutputError";
 import type { BaseArgs, Options } from "../../types";
 import { omit } from "../../utils/omit";
-import { request } from "../custom/request";
+import { innerRequest } from "../../utils/request";
 type TextToSpeechArgs = BaseArgs & TextToSpeechInput;
 
 interface OutputUrlTextToSpeechGeneration {
@@ -22,10 +22,12 @@ export async function textToSpeech(args: TextToSpeechArgs, options?: Options): P
         text: args.inputs,
       }
     : args;
-  const res = await request<Blob | OutputUrlTextToSpeechGeneration>(payload, {
-    ...options,
-    task: "text-to-speech",
-  });
+  const res = (
+    await innerRequest<Blob | OutputUrlTextToSpeechGeneration>(payload, {
+      ...options,
+      task: "text-to-speech",
+    })
+  ).data;
   if (res instanceof Blob) {
     return res;
   }
```
36 changes: 4 additions & 32 deletions packages/inference/src/tasks/custom/request.ts

```diff
@@ -1,8 +1,9 @@
 import type { InferenceTask, Options, RequestArgs } from "../../types";
-import { makeRequestOptions } from "../../lib/makeRequestOptions";
+import { innerRequest } from "../../utils/request";
 
 /**
  * Primitive to make custom calls to the inference provider
+ * @deprecated Use specific task functions instead. This function will be removed in a future version.
  */
 export async function request<T>(
   args: RequestArgs,
@@ -13,35 +14,6 @@ export async function request<T>(
     chatCompletion?: boolean;
   }
 ): Promise<T> {
-  const { url, info } = await makeRequestOptions(args, options);
-  const response = await (options?.fetch ?? fetch)(url, info);
-
-  if (options?.retry_on_error !== false && response.status === 503) {
-    return request(args, options);
-  }
-
-  if (!response.ok) {
-    const contentType = response.headers.get("Content-Type");
-    if (["application/json", "application/problem+json"].some((ct) => contentType?.startsWith(ct))) {
-      const output = await response.json();
-      if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
-        throw new Error(
-          `Server ${args.model} does not seem to support chat completion. Error: ${JSON.stringify(output.error)}`
-        );
-      }
-      if (output.error || output.detail) {
-        throw new Error(JSON.stringify(output.error ?? output.detail));
-      } else {
-        throw new Error(output);
-      }
-    }
-    const message = contentType?.startsWith("text/plain;") ? await response.text() : undefined;
-    throw new Error(message ?? "An error occurred while fetching the blob");
-  }
-
-  if (response.headers.get("Content-Type")?.startsWith("application/json")) {
-    return await response.json();
-  }
-
-  return (await response.blob()) as T;
+  const result = await innerRequest<T>(args, options);
+  return result.data;
 }
```
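For orientation: `innerRequest` itself is not shown in this diff, but every call site unwraps a `data` field, so it presumably returns the parsed body (or a `Blob`) wrapped in a result object. A sketch of the signature these call sites imply; the interface name and anything beyond `data` are assumptions:

```typescript
// Sketch only: inferred from the call sites in this PR, not from
// packages/inference/src/utils/request.ts itself.
import type { InferenceTask, Options, RequestArgs } from "../../types";

interface InnerRequestResult<T> {
  // Parsed JSON body, or a Blob for binary responses; the only field
  // the call sites in this diff read. The real type may carry more.
  data: T;
}

export declare function innerRequest<T>(
  args: RequestArgs,
  options?: Options & {
    task?: InferenceTask; // assumed to mirror the old request() options
    chatCompletion?: boolean;
  }
): Promise<InnerRequestResult<T>>;
```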
89 changes: 3 additions & 86 deletions packages/inference/src/tasks/custom/streamingRequest.ts

```diff
@@ -1,10 +1,8 @@
 import type { InferenceTask, Options, RequestArgs } from "../../types";
-import { makeRequestOptions } from "../../lib/makeRequestOptions";
-import type { EventSourceMessage } from "../../vendor/fetch-event-source/parse";
-import { getLines, getMessages } from "../../vendor/fetch-event-source/parse";
-
+import { innerStreamingRequest } from "../../utils/request";
 /**
  * Primitive to make custom inference calls that expect server-sent events, and returns the response through a generator
+ * @deprecated Use specific task functions instead. This function will be removed in a future version.
  */
 export async function* streamingRequest<T>(
   args: RequestArgs,
@@ -15,86 +13,5 @@ export async function* streamingRequest<T>(
     chatCompletion?: boolean;
   }
 ): AsyncGenerator<T> {
-  const { url, info } = await makeRequestOptions({ ...args, stream: true }, options);
-  const response = await (options?.fetch ?? fetch)(url, info);
-
-  if (options?.retry_on_error !== false && response.status === 503) {
-    return yield* streamingRequest(args, options);
-  }
-  if (!response.ok) {
-    if (response.headers.get("Content-Type")?.startsWith("application/json")) {
-      const output = await response.json();
-      if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
-        throw new Error(`Server ${args.model} does not seem to support chat completion. Error: ${output.error}`);
-      }
-      if (typeof output.error === "string") {
-        throw new Error(output.error);
-      }
-      if (output.error && "message" in output.error && typeof output.error.message === "string") {
-        /// OpenAI errors
-        throw new Error(output.error.message);
-      }
-    }
-
-    throw new Error(`Server response contains error: ${response.status}`);
-  }
-  if (!response.headers.get("content-type")?.startsWith("text/event-stream")) {
-    throw new Error(
-      `Server does not support event stream content type, it returned ` + response.headers.get("content-type")
-    );
-  }
-
-  if (!response.body) {
-    return;
-  }
-
-  const reader = response.body.getReader();
-  let events: EventSourceMessage[] = [];
-
-  const onEvent = (event: EventSourceMessage) => {
-    // accumulate events in array
-    events.push(event);
-  };
-
-  const onChunk = getLines(
-    getMessages(
-      () => {},
-      () => {},
-      onEvent
-    )
-  );
-
-  try {
-    while (true) {
-      const { done, value } = await reader.read();
-      if (done) {
-        return;
-      }
-      onChunk(value);
-      for (const event of events) {
-        if (event.data.length > 0) {
-          if (event.data === "[DONE]") {
-            return;
-          }
-          const data = JSON.parse(event.data);
-          if (typeof data === "object" && data !== null && "error" in data) {
-            const errorStr =
-              typeof data.error === "string"
-                ? data.error
-                : typeof data.error === "object" &&
-                    data.error &&
-                    "message" in data.error &&
-                    typeof data.error.message === "string"
-                  ? data.error.message
-                  : JSON.stringify(data.error);
-            throw new Error(`Error forwarded from backend: ` + errorStr);
-          }
-          yield data as T;
-        }
-      }
-      events = [];
-    }
-  } finally {
-    reader.releaseLock();
-  }
+  yield* innerStreamingRequest(args, options);
 }
```
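The rewritten body keeps `streamingRequest`'s public async-generator signature and forwards everything with `yield*`, so existing consumers see identical behavior while the SSE parsing moves into `utils/request.ts`. A self-contained illustration of this deprecate-by-delegation pattern (generic names, not from the PR):

```typescript
// Generic demonstration: not HF code, just the yield* delegation pattern.
async function* inner(): AsyncGenerator<number> {
  yield 1;
  yield 2;
}

/** @deprecated Use inner() directly. */
async function* outer(): AsyncGenerator<number> {
  // yield* forwards every value (and the final return value) unchanged,
  // so callers of outer() cannot tell the body moved elsewhere.
  yield* inner();
}

async function main(): Promise<void> {
  for await (const n of outer()) {
    console.log(n); // logs 1, then 2
  }
}

void main();
```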
12 changes: 7 additions & 5 deletions packages/inference/src/tasks/cv/imageClassification.ts

```diff
@@ -1,7 +1,7 @@
 import type { ImageClassificationInput, ImageClassificationOutput } from "@huggingface/tasks";
 import { InferenceOutputError } from "../../lib/InferenceOutputError";
 import type { BaseArgs, Options } from "../../types";
-import { request } from "../custom/request";
+import { innerRequest } from "../../utils/request";
 import { preparePayload, type LegacyImageInput } from "./utils";
 
 export type ImageClassificationArgs = BaseArgs & (ImageClassificationInput | LegacyImageInput);
@@ -15,10 +15,12 @@ export async function imageClassification(
   options?: Options
 ): Promise<ImageClassificationOutput> {
   const payload = preparePayload(args);
-  const res = await request<ImageClassificationOutput>(payload, {
-    ...options,
-    task: "image-classification",
-  });
+  const res = (
+    await innerRequest<ImageClassificationOutput>(payload, {
+      ...options,
+      task: "image-classification",
+    })
+  ).data;
   const isValidOutput =
     Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
   if (!isValidOutput) {
```
12 changes: 7 additions & 5 deletions packages/inference/src/tasks/cv/imageSegmentation.ts

```diff
@@ -1,7 +1,7 @@
 import type { ImageSegmentationInput, ImageSegmentationOutput } from "@huggingface/tasks";
 import { InferenceOutputError } from "../../lib/InferenceOutputError";
 import type { BaseArgs, Options } from "../../types";
-import { request } from "../custom/request";
+import { innerRequest } from "../../utils/request";
 import { preparePayload, type LegacyImageInput } from "./utils";
 
 export type ImageSegmentationArgs = BaseArgs & (ImageSegmentationInput | LegacyImageInput);
@@ -15,10 +15,12 @@ export async function imageSegmentation(
   options?: Options
 ): Promise<ImageSegmentationOutput> {
   const payload = preparePayload(args);
-  const res = await request<ImageSegmentationOutput>(payload, {
-    ...options,
-    task: "image-segmentation",
-  });
+  const res = (
+    await innerRequest<ImageSegmentationOutput>(payload, {
+      ...options,
+      task: "image-segmentation",
+    })
+  ).data;
   const isValidOutput =
     Array.isArray(res) &&
     res.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number");
```
12 changes: 7 additions & 5 deletions packages/inference/src/tasks/cv/imageToImage.ts

```diff
@@ -2,7 +2,7 @@ import type { ImageToImageInput } from "@huggingface/tasks";
 import { InferenceOutputError } from "../../lib/InferenceOutputError";
 import type { BaseArgs, Options, RequestArgs } from "../../types";
 import { base64FromBytes } from "../../utils/base64FromBytes";
-import { request } from "../custom/request";
+import { innerRequest } from "../../utils/request";
 
 export type ImageToImageArgs = BaseArgs & ImageToImageInput;
 
@@ -26,10 +26,12 @@ export async function imageToImage(args: ImageToImageArgs, options?: Options): P
       ),
     };
   }
-  const res = await request<Blob>(reqArgs, {
-    ...options,
-    task: "image-to-image",
-  });
+  const res = (
+    await innerRequest<Blob>(reqArgs, {
+      ...options,
+      task: "image-to-image",
+    })
+  ).data;
   const isValidOutput = res && res instanceof Blob;
   if (!isValidOutput) {
     throw new InferenceOutputError("Expected Blob");
```
6 changes: 3 additions & 3 deletions packages/inference/src/tasks/cv/imageToText.ts

```diff
@@ -1,7 +1,7 @@
 import type { ImageToTextInput, ImageToTextOutput } from "@huggingface/tasks";
 import { InferenceOutputError } from "../../lib/InferenceOutputError";
 import type { BaseArgs, Options } from "../../types";
-import { request } from "../custom/request";
+import { innerRequest } from "../../utils/request";
 import type { LegacyImageInput } from "./utils";
 import { preparePayload } from "./utils";
 
@@ -12,11 +12,11 @@ export type ImageToTextArgs = BaseArgs & (ImageToTextInput | LegacyImageInput);
 export async function imageToText(args: ImageToTextArgs, options?: Options): Promise<ImageToTextOutput> {
   const payload = preparePayload(args);
   const res = (
-    await request<[ImageToTextOutput]>(payload, {
+    await innerRequest<[ImageToTextOutput]>(payload, {
       ...options,
       task: "image-to-text",
     })
-  )?.[0];
+  ).data?.[0];
 
   if (typeof res?.generated_text !== "string") {
     throw new InferenceOutputError("Expected {generated_text: string}");
```