Skip to content

Commit 5826183

Browse files
authored
Merge branch 'main' into add-docker-model-runner-support
2 parents 83566ea + 4e3da78 commit 5826183

File tree

34 files changed

+649
-118
lines changed

34 files changed

+649
-118
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ You can run our packages with vanilla JS, without any bundler, by using a CDN or
9898

9999
```html
100100
<script type="module">
101-
import { InferenceClient } from 'https://cdn.jsdelivr.net/npm/@huggingface/inference@4.0.6/+esm';
101+
import { InferenceClient } from 'https://cdn.jsdelivr.net/npm/@huggingface/inference@4.2.0/+esm';
102102
import { createRepo, commit, deleteRepo, listFiles } from "https://cdn.jsdelivr.net/npm/@huggingface/[email protected]/+esm";
103103
</script>
104104
```

packages/hub/src/lib/list-datasets.spec.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ describe("listDatasets", () => {
77
const results: DatasetEntry[] = [];
88

99
for await (const entry of listDatasets({ search: { owner: "hf-doc-build" } })) {
10-
if (entry.name === "hf-doc-build/doc-build-dev-test") {
10+
if (entry.name !== "hf-doc-build/doc-build" && entry.name !== "hf-doc-build/doc-build-dev") {
1111
continue;
1212
}
1313
if (typeof entry.downloads === "number") {
@@ -23,7 +23,7 @@ describe("listDatasets", () => {
2323
results.push(entry);
2424
}
2525

26-
expect(results).deep.equal([
26+
expect(results.sort((a, b) => a.id.localeCompare(b.id))).to.deep.equal([
2727
{
2828
id: "6356b19985da6f13863228bd",
2929
name: "hf-doc-build/doc-build",

packages/inference/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "@huggingface/inference",
3-
"version": "4.0.6",
3+
"version": "4.2.0",
44
"packageManager": "[email protected]",
55
"license": "MIT",
66
"author": "Hugging Face and Tim Mikeladze <[email protected]>",

packages/inference/src/index.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,7 @@ export * from "./errors.js";
33
export * from "./types.js";
44
export * from "./tasks/index.js";
55
import * as snippets from "./snippets/index.js";
6+
export * from "./lib/getProviderHelper.js";
7+
export * from "./lib/makeRequestOptions.js";
68

79
export { snippets };

packages/inference/src/lib/getProviderHelper.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
6464
"text-to-image": new FalAI.FalAITextToImageTask(),
6565
"text-to-speech": new FalAI.FalAITextToSpeechTask(),
6666
"text-to-video": new FalAI.FalAITextToVideoTask(),
67+
"image-to-image": new FalAI.FalAIImageToImageTask(),
6768
"automatic-speech-recognition": new FalAI.FalAIAutomaticSpeechRecognitionTask(),
6869
},
6970
"featherless-ai": {
@@ -138,6 +139,7 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
138139
"text-to-image": new Replicate.ReplicateTextToImageTask(),
139140
"text-to-speech": new Replicate.ReplicateTextToSpeechTask(),
140141
"text-to-video": new Replicate.ReplicateTextToVideoTask(),
142+
"image-to-image": new Replicate.ReplicateImageToImageTask(),
141143
},
142144
sambanova: {
143145
conversational: new Sambanova.SambanovaConversationalTask(),

packages/inference/src/package.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
// Generated file from package.json. Issues importing JSON directly when publishing on commonjs/ESM - see https://github.com/microsoft/TypeScript/issues/51783
2-
export const PACKAGE_VERSION = "4.0.6";
2+
export const PACKAGE_VERSION = "4.2.0";
33
export const PACKAGE_NAME = "@huggingface/inference";

packages/inference/src/providers/fal-ai.ts

Lines changed: 144 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@ import { base64FromBytes } from "../utils/base64FromBytes.js";
1818

1919
import type { AutomaticSpeechRecognitionOutput } from "@huggingface/tasks";
2020
import { isUrl } from "../lib/isUrl.js";
21-
import type { BodyParams, HeaderParams, ModelId, RequestArgs, UrlParams } from "../types.js";
21+
import type { BodyParams, HeaderParams, InferenceTask, ModelId, RequestArgs, UrlParams } from "../types.js";
2222
import { delay } from "../utils/delay.js";
2323
import { omit } from "../utils/omit.js";
24+
import type { ImageToImageTaskHelper } from "./providerHelper.js";
2425
import {
2526
type AutomaticSpeechRecognitionTaskHelper,
2627
TaskProviderHelper,
@@ -34,6 +35,7 @@ import {
3435
InferenceClientProviderApiError,
3536
InferenceClientProviderOutputError,
3637
} from "../errors.js";
38+
import type { ImageToImageArgs } from "../tasks/index.js";
3739

3840
export interface FalAiQueueOutput {
3941
request_id: string;
@@ -82,6 +84,75 @@ abstract class FalAITask extends TaskProviderHelper {
8284
}
8385
}
8486

87+
abstract class FalAiQueueTask extends FalAITask {
88+
abstract task: InferenceTask;
89+
90+
async getResponseFromQueueApi(
91+
response: FalAiQueueOutput,
92+
url?: string,
93+
headers?: Record<string, string>
94+
): Promise<unknown> {
95+
if (!url || !headers) {
96+
throw new InferenceClientInputError(`URL and headers are required for ${this.task} task`);
97+
}
98+
const requestId = response.request_id;
99+
if (!requestId) {
100+
throw new InferenceClientProviderOutputError(
101+
`Received malformed response from Fal.ai ${this.task} API: no request ID found in the response`
102+
);
103+
}
104+
let status = response.status;
105+
106+
const parsedUrl = new URL(url);
107+
const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${
108+
parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""
109+
}`;
110+
111+
// extracting the provider model id for status and result urls
112+
// from the response as it might be different from the mapped model in `url`
113+
const modelId = new URL(response.response_url).pathname;
114+
const queryParams = parsedUrl.search;
115+
116+
const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
117+
const resultUrl = `${baseUrl}${modelId}${queryParams}`;
118+
119+
while (status !== "COMPLETED") {
120+
await delay(500);
121+
const statusResponse = await fetch(statusUrl, { headers });
122+
123+
if (!statusResponse.ok) {
124+
throw new InferenceClientProviderApiError(
125+
"Failed to fetch response status from fal-ai API",
126+
{ url: statusUrl, method: "GET" },
127+
{
128+
requestId: statusResponse.headers.get("x-request-id") ?? "",
129+
status: statusResponse.status,
130+
body: await statusResponse.text(),
131+
}
132+
);
133+
}
134+
try {
135+
status = (await statusResponse.json()).status;
136+
} catch (error) {
137+
throw new InferenceClientProviderOutputError(
138+
"Failed to parse status response from fal-ai API: received malformed response"
139+
);
140+
}
141+
}
142+
143+
const resultResponse = await fetch(resultUrl, { headers });
144+
let result: unknown;
145+
try {
146+
result = await resultResponse.json();
147+
} catch (error) {
148+
throw new InferenceClientProviderOutputError(
149+
"Failed to parse result response from fal-ai API: received malformed response"
150+
);
151+
}
152+
return result;
153+
}
154+
}
155+
85156
function buildLoraPath(modelId: ModelId, adapterWeightsPath: string): string {
86157
return `${HF_HUB_URL}/${modelId}/resolve/main/${adapterWeightsPath}`;
87158
}
@@ -130,21 +201,42 @@ export class FalAITextToImageTask extends FalAITask implements TextToImageTaskHe
130201
}
131202
}
132203

133-
export class FalAITextToVideoTask extends FalAITask implements TextToVideoTaskHelper {
204+
export class FalAIImageToImageTask extends FalAiQueueTask implements ImageToImageTaskHelper {
205+
task: InferenceTask;
134206
constructor() {
135207
super("https://queue.fal.run");
208+
this.task = "image-to-image";
136209
}
210+
137211
override makeRoute(params: UrlParams): string {
138212
if (params.authMethod !== "provider-key") {
139213
return `/${params.model}?_subdomain=queue`;
140214
}
141215
return `/${params.model}`;
142216
}
217+
143218
override preparePayload(params: BodyParams): Record<string, unknown> {
219+
const payload = params.args;
220+
if (params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath) {
221+
payload.loras = [
222+
{
223+
path: buildLoraPath(params.mapping.hfModelId, params.mapping.adapterWeightsPath),
224+
scale: 1,
225+
},
226+
];
227+
}
228+
return payload;
229+
}
230+
231+
async preparePayloadAsync(args: ImageToImageArgs): Promise<RequestArgs> {
232+
const mimeType = args.inputs instanceof Blob ? args.inputs.type : "image/png";
144233
return {
145-
...omit(params.args, ["inputs", "parameters"]),
146-
...(params.args.parameters as Record<string, unknown>),
147-
prompt: params.args.inputs,
234+
...omit(args, ["inputs", "parameters"]),
235+
image_url: `data:${mimeType};base64,${base64FromBytes(
236+
new Uint8Array(args.inputs instanceof ArrayBuffer ? args.inputs : await (args.inputs as Blob).arrayBuffer())
237+
)}`,
238+
...args.parameters,
239+
...args,
148240
};
149241
}
150242

@@ -153,63 +245,59 @@ export class FalAITextToVideoTask extends FalAITask implements TextToVideoTaskHe
153245
url?: string,
154246
headers?: Record<string, string>
155247
): Promise<Blob> {
156-
if (!url || !headers) {
157-
throw new InferenceClientInputError("URL and headers are required for text-to-video task");
158-
}
159-
const requestId = response.request_id;
160-
if (!requestId) {
248+
const result = await this.getResponseFromQueueApi(response, url, headers);
249+
250+
if (
251+
typeof result === "object" &&
252+
!!result &&
253+
"images" in result &&
254+
Array.isArray(result.images) &&
255+
result.images.length > 0 &&
256+
typeof result.images[0] === "object" &&
257+
!!result.images[0] &&
258+
"url" in result.images[0] &&
259+
typeof result.images[0].url === "string" &&
260+
isUrl(result.images[0].url)
261+
) {
262+
const urlResponse = await fetch(result.images[0].url);
263+
return await urlResponse.blob();
264+
} else {
161265
throw new InferenceClientProviderOutputError(
162-
"Received malformed response from Fal.ai text-to-video API: no request ID found in the response"
266+
`Received malformed response from Fal.ai image-to-image API: expected { images: Array<{ url: string }> } result format, got instead: ${JSON.stringify(
267+
result
268+
)}`
163269
);
164270
}
165-
let status = response.status;
166-
167-
const parsedUrl = new URL(url);
168-
const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${
169-
parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""
170-
}`;
171-
172-
// extracting the provider model id for status and result urls
173-
// from the response as it might be different from the mapped model in `url`
174-
const modelId = new URL(response.response_url).pathname;
175-
const queryParams = parsedUrl.search;
176-
177-
const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
178-
const resultUrl = `${baseUrl}${modelId}${queryParams}`;
179-
180-
while (status !== "COMPLETED") {
181-
await delay(500);
182-
const statusResponse = await fetch(statusUrl, { headers });
271+
}
272+
}
183273

184-
if (!statusResponse.ok) {
185-
throw new InferenceClientProviderApiError(
186-
"Failed to fetch response status from fal-ai API",
187-
{ url: statusUrl, method: "GET" },
188-
{
189-
requestId: statusResponse.headers.get("x-request-id") ?? "",
190-
status: statusResponse.status,
191-
body: await statusResponse.text(),
192-
}
193-
);
194-
}
195-
try {
196-
status = (await statusResponse.json()).status;
197-
} catch (error) {
198-
throw new InferenceClientProviderOutputError(
199-
"Failed to parse status response from fal-ai API: received malformed response"
200-
);
201-
}
274+
export class FalAITextToVideoTask extends FalAiQueueTask implements TextToVideoTaskHelper {
275+
task: InferenceTask;
276+
constructor() {
277+
super("https://queue.fal.run");
278+
this.task = "text-to-video";
279+
}
280+
override makeRoute(params: UrlParams): string {
281+
if (params.authMethod !== "provider-key") {
282+
return `/${params.model}?_subdomain=queue`;
202283
}
284+
return `/${params.model}`;
285+
}
286+
override preparePayload(params: BodyParams): Record<string, unknown> {
287+
return {
288+
...omit(params.args, ["inputs", "parameters"]),
289+
...(params.args.parameters as Record<string, unknown>),
290+
prompt: params.args.inputs,
291+
};
292+
}
293+
294+
override async getResponse(
295+
response: FalAiQueueOutput,
296+
url?: string,
297+
headers?: Record<string, string>
298+
): Promise<Blob> {
299+
const result = await this.getResponseFromQueueApi(response, url, headers);
203300

204-
const resultResponse = await fetch(resultUrl, { headers });
205-
let result: unknown;
206-
try {
207-
result = await resultResponse.json();
208-
} catch (error) {
209-
throw new InferenceClientProviderOutputError(
210-
"Failed to parse result response from fal-ai API: received malformed response"
211-
);
212-
}
213301
if (
214302
typeof result === "object" &&
215303
!!result &&

packages/inference/src/providers/providerHelper.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,10 @@ export abstract class TaskProviderHelper {
115115
* Prepare the headers for the request
116116
*/
117117
prepareHeaders(params: HeaderParams, isBinary: boolean): Record<string, string> {
118-
const headers: Record<string, string> = { Authorization: `Bearer ${params.accessToken}` };
118+
const headers: Record<string, string> = {};
119+
if (params.authMethod !== "none") {
120+
headers["Authorization"] = `Bearer ${params.accessToken}`;
121+
}
119122
if (!isBinary) {
120123
headers["Content-Type"] = "application/json";
121124
}

0 commit comments

Comments
 (0)