Skip to content

Commit 64645c3

Browse files
authored
Merge branch 'main' into main
2 parents a3e0e5b + 2e2eeed commit 64645c3

18 files changed

+236
-31
lines changed
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import { it, describe, expect } from "vitest";
2+
3+
import { TEST_HUB_URL, TEST_ACCESS_TOKEN, TEST_USER } from "../test/consts";
4+
import { addCollectionItem } from "./add-collection-item";
5+
import { listCollections } from "./list-collections";
6+
import { collectionInfo } from "./collection-info";
7+
import { deleteCollectionItem } from "./delete-collection-item";
8+
9+
describe("addCollectionItem", () => {
10+
it("should add a item to a collection", async () => {
11+
let slug: string = "";
12+
let itemId: string = "";
13+
14+
try {
15+
for await (const entry of listCollections({
16+
search: { owner: [TEST_USER] },
17+
limit: 1,
18+
hubUrl: TEST_HUB_URL,
19+
})) {
20+
slug = entry.slug;
21+
break;
22+
}
23+
24+
await addCollectionItem({
25+
slug,
26+
item: {
27+
type: "model",
28+
id: "quanghuynt14/TestAddCollectionItem",
29+
},
30+
note: "This is a test item",
31+
accessToken: TEST_ACCESS_TOKEN,
32+
hubUrl: TEST_HUB_URL,
33+
});
34+
35+
const collection = await collectionInfo({
36+
slug,
37+
accessToken: TEST_ACCESS_TOKEN,
38+
hubUrl: TEST_HUB_URL,
39+
});
40+
41+
const item = collection.items.find((item) => item.id === "quanghuynt14/TestAddCollectionItem");
42+
43+
expect(item).toBeDefined();
44+
45+
itemId = item?._id || "";
46+
} finally {
47+
await deleteCollectionItem({
48+
slug,
49+
itemId,
50+
accessToken: TEST_ACCESS_TOKEN,
51+
hubUrl: TEST_HUB_URL,
52+
});
53+
}
54+
});
55+
});
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import { HUB_URL } from "../consts";
2+
import { createApiError } from "../error";
3+
import type { CredentialsParams } from "../types/public";
4+
import { checkCredentials } from "../utils/checkCredentials";
5+
6+
export async function addCollectionItem(
7+
params: {
8+
/**
9+
* The slug of the collection to add the item to.
10+
*/
11+
slug: string;
12+
/**
13+
* The item to add to the collection.
14+
*/
15+
item: {
16+
type: "paper" | "collection" | "space" | "model" | "dataset";
17+
id: string;
18+
};
19+
/**
20+
* A note to attach to the item in the collection. The maximum size for a note is 500 characters.
21+
*/
22+
note?: string;
23+
hubUrl?: string;
24+
/**
25+
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
26+
*/
27+
fetch?: typeof fetch;
28+
} & Partial<CredentialsParams>
29+
): Promise<void> {
30+
const accessToken = checkCredentials(params);
31+
32+
const res = await (params.fetch ?? fetch)(`${params.hubUrl ?? HUB_URL}/api/collections/${params.slug}/items`, {
33+
method: "POST",
34+
body: JSON.stringify({
35+
item: params.item,
36+
note: params.note,
37+
}),
38+
headers: {
39+
Authorization: `Bearer ${accessToken}`,
40+
"Content-Type": "application/json",
41+
},
42+
});
43+
44+
if (!res.ok) {
45+
throw await createApiError(res);
46+
}
47+
}

packages/hub/src/lib/create-collection.spec.ts

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,26 @@
11
import { it, describe, expect } from "vitest";
22

3-
import { TEST_HUB_URL, TEST_ACCESS_TOKEN } from "../test/consts";
3+
import { TEST_HUB_URL, TEST_ACCESS_TOKEN, TEST_USER } from "../test/consts";
44
import { createCollection } from "./create-collection";
5-
import { whoAmI } from "./who-am-i";
65
import { deleteCollection } from "./delete-collection";
76

87
describe("createCollection", () => {
98
it("should create a collection", async () => {
109
let slug: string = "";
1110

1211
try {
13-
const user = await whoAmI({
14-
hubUrl: TEST_HUB_URL,
15-
accessToken: TEST_ACCESS_TOKEN,
16-
});
17-
1812
const result = await createCollection({
1913
collection: {
2014
title: "Test Collection",
21-
namespace: user.name,
15+
namespace: TEST_USER,
2216
description: "This is a test collection",
2317
private: false,
2418
},
2519
accessToken: TEST_ACCESS_TOKEN,
2620
hubUrl: TEST_HUB_URL,
2721
});
2822

29-
expect(result.slug.startsWith(`${user.name}/test-collection`)).toBe(true);
23+
expect(result.slug.startsWith(`${TEST_USER}/test-collection`)).toBe(true);
3024

3125
slug = result.slug;
3226
} finally {
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import { HUB_URL } from "../consts";
2+
import { createApiError } from "../error";
3+
import type { CredentialsParams } from "../types/public";
4+
import { checkCredentials } from "../utils/checkCredentials";
5+
6+
export async function deleteCollectionItem(
7+
params: {
8+
/**
9+
* The slug of the collection to delete the item from.
10+
*/
11+
slug: string;
12+
/**
13+
* The item object id which is different from the repo_id/paper_id provided when adding the item to the collection.
14+
* This should be the _id property of the item.
15+
*/
16+
itemId: string;
17+
hubUrl?: string;
18+
/**
19+
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
20+
*/
21+
fetch?: typeof fetch;
22+
} & Partial<CredentialsParams>
23+
): Promise<void> {
24+
const accessToken = checkCredentials(params);
25+
26+
const res = await (params.fetch ?? fetch)(
27+
`${params.hubUrl ?? HUB_URL}/api/collections/${params.slug}/items/${params.itemId}`,
28+
{
29+
method: "DELETE",
30+
headers: {
31+
Authorization: `Bearer ${accessToken}`,
32+
"Content-Type": "application/json",
33+
},
34+
}
35+
);
36+
37+
if (!res.ok) {
38+
throw await createApiError(res);
39+
}
40+
}

packages/inference/src/providers/black-forest-labs.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ export class BlackForestLabsTextToImageTask extends TaskProviderHelper implement
6666
response: BlackForestLabsResponse,
6767
url?: string,
6868
headers?: HeadersInit,
69-
outputType?: "url" | "blob"
70-
): Promise<string | Blob> {
69+
outputType?: "url" | "blob" | "json"
70+
): Promise<string | Blob | Record<string, unknown>> {
7171
const logger = getLogger();
7272
const urlObj = new URL(response.polling_url);
7373
for (let step = 0; step < 5; step++) {
@@ -95,6 +95,9 @@ export class BlackForestLabsTextToImageTask extends TaskProviderHelper implement
9595
"sample" in payload.result &&
9696
typeof payload.result.sample === "string"
9797
) {
98+
if (outputType === "json") {
99+
return payload.result;
100+
}
98101
if (outputType === "url") {
99102
return payload.result.sample;
100103
}

packages/inference/src/providers/fal-ai.ts

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,12 @@ export class FalAITextToImageTask extends FalAITask implements TextToImageTaskHe
182182
return payload;
183183
}
184184

185-
override async getResponse(response: FalAITextToImageOutput, outputType?: "url" | "blob"): Promise<string | Blob> {
185+
override async getResponse(
186+
response: FalAITextToImageOutput,
187+
url?: string,
188+
headers?: HeadersInit,
189+
outputType?: "url" | "blob" | "json"
190+
): Promise<string | Blob | Record<string, unknown>> {
186191
if (
187192
typeof response === "object" &&
188193
"images" in response &&
@@ -191,6 +196,9 @@ export class FalAITextToImageTask extends FalAITask implements TextToImageTaskHe
191196
"url" in response.images[0] &&
192197
typeof response.images[0].url === "string"
193198
) {
199+
if (outputType === "json") {
200+
return { ...response };
201+
}
194202
if (outputType === "url") {
195203
return response.images[0].url;
196204
}

packages/inference/src/providers/hf-inference.ts

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,14 +127,17 @@ export class HFInferenceTextToImageTask extends HFInferenceTask implements TextT
127127
response: Base64ImageGeneration | OutputUrlImageGeneration,
128128
url?: string,
129129
headers?: HeadersInit,
130-
outputType?: "url" | "blob"
131-
): Promise<string | Blob> {
130+
outputType?: "url" | "blob" | "json"
131+
): Promise<string | Blob | Record<string, unknown>> {
132132
if (!response) {
133133
throw new InferenceClientProviderOutputError(
134134
"Received malformed response from HF-Inference text-to-image API: response is undefined"
135135
);
136136
}
137137
if (typeof response == "object") {
138+
if (outputType === "json") {
139+
return { ...response };
140+
}
138141
if ("data" in response && Array.isArray(response.data) && response.data[0].b64_json) {
139142
const base64Data = response.data[0].b64_json;
140143
if (outputType === "url") {
@@ -153,9 +156,9 @@ export class HFInferenceTextToImageTask extends HFInferenceTask implements TextT
153156
}
154157
}
155158
if (response instanceof Blob) {
156-
if (outputType === "url") {
159+
if (outputType === "url" || outputType === "json") {
157160
const b64 = await response.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
158-
return `data:image/jpeg;base64,${b64}`;
161+
return outputType === "url" ? `data:image/jpeg;base64,${b64}` : { output: `data:image/jpeg;base64,${b64}` };
159162
}
160163
return response;
161164
}

packages/inference/src/providers/hyperbolic.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,15 +105,18 @@ export class HyperbolicTextToImageTask extends TaskProviderHelper implements Tex
105105
response: HyperbolicTextToImageOutput,
106106
url?: string,
107107
headers?: HeadersInit,
108-
outputType?: "url" | "blob"
109-
): Promise<string | Blob> {
108+
outputType?: "url" | "blob" | "json"
109+
): Promise<string | Blob | Record<string, unknown>> {
110110
if (
111111
typeof response === "object" &&
112112
"images" in response &&
113113
Array.isArray(response.images) &&
114114
response.images[0] &&
115115
typeof response.images[0].image === "string"
116116
) {
117+
if (outputType === "json") {
118+
return { ...response };
119+
}
117120
if (outputType === "url") {
118121
return `data:image/jpeg;base64,${response.images[0].image}`;
119122
}

packages/inference/src/providers/nebius.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@ export class NebiusTextToImageTask extends TaskProviderHelper implements TextToI
116116
response: NebiusBase64ImageGeneration,
117117
url?: string,
118118
headers?: HeadersInit,
119-
outputType?: "url" | "blob"
120-
): Promise<string | Blob> {
119+
outputType?: "url" | "blob" | "json"
120+
): Promise<string | Blob | Record<string, unknown>> {
121121
if (
122122
typeof response === "object" &&
123123
"data" in response &&
@@ -126,6 +126,9 @@ export class NebiusTextToImageTask extends TaskProviderHelper implements TextToI
126126
"b64_json" in response.data[0] &&
127127
typeof response.data[0].b64_json === "string"
128128
) {
129+
if (outputType === "json") {
130+
return { ...response };
131+
}
129132
const base64Data = response.data[0].b64_json;
130133
if (outputType === "url") {
131134
return `data:image/jpeg;base64,${base64Data}`;

packages/inference/src/providers/nscale.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ export class NscaleTextToImageTask extends TaskProviderHelper implements TextToI
5757
response: NscaleCloudBase64ImageGeneration,
5858
url?: string,
5959
headers?: HeadersInit,
60-
outputType?: "url" | "blob"
61-
): Promise<string | Blob> {
60+
outputType?: "url" | "blob" | "json"
61+
): Promise<string | Blob | Record<string, unknown>> {
6262
if (
6363
typeof response === "object" &&
6464
"data" in response &&
@@ -67,6 +67,9 @@ export class NscaleTextToImageTask extends TaskProviderHelper implements TextToI
6767
"b64_json" in response.data[0] &&
6868
typeof response.data[0].b64_json === "string"
6969
) {
70+
if (outputType === "json") {
71+
return { ...response };
72+
}
7073
const base64Data = response.data[0].b64_json;
7174
if (outputType === "url") {
7275
return `data:image/jpeg;base64,${base64Data}`;

0 commit comments

Comments
 (0)