Skip to content

Commit ee6740d

Browse files
authored
✨ Query models / datasets / spaces by tags (#498)
Fix #497 ## Query repos by tags E.g., for GGUF models, `listModels({search: {tags: ["gguf"]}})`. For Indonesian models, `listModels({search: {tags: ["id"]}})`. For models stored in the EU, `listModels({search: {tags: ["region:eu"]}})`. And so on. ## Request extra fields The `additionalFields` param, previously only available for `listSpaces`, is now also available for `listModels` and `listDatasets`, and it is strongly typed. So if you want to get the tags of the models you're listing: ```ts for await (const model of listModels({additionalFields: ["tags"]})) { console.log(model.tags); } ``` ## Limit the number of models requested You can specify the number of models to fetch, e.g.: ```ts for await (const model of listModels({limit: 2})) { console.log(model.tags); } ``` This will fetch only two models.
1 parent bea807a commit ee6740d

File tree

8 files changed

+414
-17
lines changed

8 files changed

+414
-17
lines changed

packages/hub/src/lib/list-datasets.ts

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,35 @@ import type { ApiDatasetInfo } from "../types/api/api-dataset";
44
import type { Credentials } from "../types/public";
55
import { checkCredentials } from "../utils/checkCredentials";
66
import { parseLinkHeader } from "../utils/parseLinkHeader";
7+
import { pick } from "../utils/pick";
78

8-
const EXPAND_KEYS = ["private", "downloads", "gated", "likes", "lastModified"] satisfies (keyof ApiDatasetInfo)[];
9+
const EXPAND_KEYS = [
10+
"private",
11+
"downloads",
12+
"gated",
13+
"likes",
14+
"lastModified",
15+
] as const satisfies readonly (keyof ApiDatasetInfo)[];
16+
17+
const EXPANDABLE_KEYS = [
18+
"author",
19+
"cardData",
20+
"citation",
21+
"createdAt",
22+
"disabled",
23+
"description",
24+
"downloads",
25+
"downloadsAllTime",
26+
"gated",
27+
"gitalyUid",
28+
"lastModified",
29+
"likes",
30+
"paperswithcode_id",
31+
"private",
32+
// "siblings",
33+
"sha",
34+
"tags",
35+
] as const satisfies readonly (keyof ApiDatasetInfo)[];
936

1037
export interface DatasetEntry {
1138
id: string;
@@ -17,24 +44,35 @@ export interface DatasetEntry {
1744
updatedAt: Date;
1845
}
1946

20-
export async function* listDatasets(params?: {
47+
export async function* listDatasets<
48+
const T extends Exclude<(typeof EXPANDABLE_KEYS)[number], (typeof EXPAND_KEYS)[number]> = never,
49+
>(params?: {
2150
search?: {
2251
owner?: string;
52+
tags?: string[];
2353
};
2454
credentials?: Credentials;
2555
hubUrl?: string;
56+
additionalFields?: T[];
57+
/**
58+
* Set to limit the number of datasets returned.
59+
*/
60+
limit?: number;
2661
/**
2762
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
2863
*/
2964
fetch?: typeof fetch;
3065
}): AsyncGenerator<DatasetEntry> {
3166
checkCredentials(params?.credentials);
67+
let totalToFetch = params?.limit ?? Infinity;
3268
const search = new URLSearchParams([
3369
...Object.entries({
34-
limit: "500",
70+
limit: String(Math.min(totalToFetch, 500)),
3571
...(params?.search?.owner ? { author: params.search.owner } : undefined),
3672
}),
73+
...(params?.search?.tags?.map((tag) => ["filter", tag]) ?? []),
3774
...EXPAND_KEYS.map((val) => ["expand", val] satisfies [string, string]),
75+
...(params?.additionalFields?.map((val) => ["expand", val] satisfies [string, string]) ?? []),
3876
]).toString();
3977
let url: string | undefined = `${params?.hubUrl || HUB_URL}/api/datasets` + (search ? "?" + search : "");
4078

@@ -54,6 +92,7 @@ export async function* listDatasets(params?: {
5492

5593
for (const item of items) {
5694
yield {
95+
...(params?.additionalFields && pick(item, params.additionalFields)),
5796
id: item._id,
5897
name: item.id,
5998
private: item.private,
@@ -62,10 +101,15 @@ export async function* listDatasets(params?: {
62101
gated: item.gated,
63102
updatedAt: new Date(item.lastModified),
64103
};
104+
totalToFetch--;
105+
if (totalToFetch <= 0) {
106+
return;
107+
}
65108
}
66109

67110
const linkHeader = res.headers.get("Link");
68111

69112
url = linkHeader ? parseLinkHeader(linkHeader).next : undefined;
113+
// Could update limit in url to fetch fewer items if not all items of the next page are needed.
70114
}
71115
}

packages/hub/src/lib/list-models.spec.ts

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,19 @@ describe("listModels", () => {
5252
},
5353
]);
5454
});
55+
56+
it("should list indonesian models with gguf format", async () => {
57+
let count = 0;
58+
for await (const entry of listModels({
59+
search: { tags: ["gguf", "id"] },
60+
additionalFields: ["tags"],
61+
limit: 2,
62+
})) {
63+
count++;
64+
expect(entry.tags).to.include("gguf");
65+
expect(entry.tags).to.include("id");
66+
}
67+
68+
expect(count).to.equal(2);
69+
});
5570
});

packages/hub/src/lib/list-models.ts

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import type { ApiModelInfo } from "../types/api/api-model";
44
import type { Credentials, PipelineType } from "../types/public";
55
import { checkCredentials } from "../utils/checkCredentials";
66
import { parseLinkHeader } from "../utils/parseLinkHeader";
7+
import { pick } from "../utils/pick";
78

89
const EXPAND_KEYS = [
910
"pipeline_tag",
@@ -12,7 +13,31 @@ const EXPAND_KEYS = [
1213
"downloads",
1314
"likes",
1415
"lastModified",
15-
] satisfies (keyof ApiModelInfo)[];
16+
] as const satisfies readonly (keyof ApiModelInfo)[];
17+
18+
const EXPANDABLE_KEYS = [
19+
"author",
20+
"cardData",
21+
"config",
22+
"createdAt",
23+
"disabled",
24+
"downloads",
25+
"downloadsAllTime",
26+
"gated",
27+
"gitalyUid",
28+
"lastModified",
29+
"library_name",
30+
"likes",
31+
"model-index",
32+
"pipeline_tag",
33+
"private",
34+
"safetensors",
35+
"sha",
36+
// "siblings",
37+
"spaces",
38+
"tags",
39+
"transformersInfo",
40+
] as const satisfies readonly (keyof ApiModelInfo)[];
1641

1742
export interface ModelEntry {
1843
id: string;
@@ -25,26 +50,37 @@ export interface ModelEntry {
2550
updatedAt: Date;
2651
}
2752

28-
export async function* listModels(params?: {
53+
export async function* listModels<
54+
const T extends Exclude<(typeof EXPANDABLE_KEYS)[number], (typeof EXPAND_KEYS)[number]> = never,
55+
>(params?: {
2956
search?: {
3057
owner?: string;
3158
task?: PipelineType;
59+
tags?: string[];
3260
};
3361
credentials?: Credentials;
3462
hubUrl?: string;
63+
additionalFields?: T[];
64+
/**
65+
* Set to limit the number of models returned.
66+
*/
67+
limit?: number;
3568
/**
3669
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
3770
*/
3871
fetch?: typeof fetch;
39-
}): AsyncGenerator<ModelEntry> {
72+
}): AsyncGenerator<ModelEntry & Pick<ApiModelInfo, T>> {
4073
checkCredentials(params?.credentials);
74+
let totalToFetch = params?.limit ?? Infinity;
4175
const search = new URLSearchParams([
4276
...Object.entries({
43-
limit: "500",
77+
limit: String(Math.min(totalToFetch, 500)),
4478
...(params?.search?.owner ? { author: params.search.owner } : undefined),
4579
...(params?.search?.task ? { pipeline_tag: params.search.task } : undefined),
4680
}),
81+
...(params?.search?.tags?.map((tag) => ["filter", tag]) ?? []),
4782
...EXPAND_KEYS.map((val) => ["expand", val] satisfies [string, string]),
83+
...(params?.additionalFields?.map((val) => ["expand", val] satisfies [string, string]) ?? []),
4884
]).toString();
4985
let url: string | undefined = `${params?.hubUrl || HUB_URL}/api/models?${search}`;
5086

@@ -64,6 +100,7 @@ export async function* listModels(params?: {
64100

65101
for (const item of items) {
66102
yield {
103+
...(params?.additionalFields && pick(item, params.additionalFields)),
67104
id: item._id,
68105
name: item.id,
69106
private: item.private,
@@ -72,11 +109,17 @@ export async function* listModels(params?: {
72109
gated: item.gated,
73110
likes: item.likes,
74111
updatedAt: new Date(item.lastModified),
75-
};
112+
} as ModelEntry & Pick<ApiModelInfo, T>;
113+
totalToFetch--;
114+
115+
if (totalToFetch <= 0) {
116+
return;
117+
}
76118
}
77119

78120
const linkHeader = res.headers.get("Link");
79121

80122
url = linkHeader ? parseLinkHeader(linkHeader).next : undefined;
123+
// Could update url to reduce the limit if we don't need the whole 500 of the next batch.
81124
}
82125
}

packages/hub/src/lib/list-spaces.ts

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,42 @@ import { checkCredentials } from "../utils/checkCredentials";
66
import { parseLinkHeader } from "../utils/parseLinkHeader";
77
import { pick } from "../utils/pick";
88

9-
const EXPAND_KEYS = ["sdk", "likes", "private", "lastModified"];
9+
const EXPAND_KEYS = ["sdk", "likes", "private", "lastModified"] as const satisfies readonly (keyof ApiSpaceInfo)[];
10+
const EXPANDABLE_KEYS = [
11+
"author",
12+
"cardData",
13+
"datasets",
14+
"disabled",
15+
"gitalyUid",
16+
"lastModified",
17+
"createdAt",
18+
"likes",
19+
"private",
20+
"runtime",
21+
"sdk",
22+
// "siblings",
23+
"sha",
24+
"subdomain",
25+
"tags",
26+
"models",
27+
] as const satisfies readonly (keyof ApiSpaceInfo)[];
1028

11-
export type SpaceEntry = {
29+
export interface SpaceEntry {
1230
id: string;
1331
name: string;
1432
sdk?: SpaceSdk;
1533
likes: number;
1634
private: boolean;
1735
updatedAt: Date;
1836
// Use additionalFields to fetch the fields from ApiSpaceInfo
19-
} & Partial<Omit<ApiSpaceInfo, "updatedAt">>;
37+
}
2038

21-
export async function* listSpaces(params?: {
39+
export async function* listSpaces<
40+
const T extends Exclude<(typeof EXPANDABLE_KEYS)[number], (typeof EXPAND_KEYS)[number]> = never,
41+
>(params?: {
2242
search?: {
2343
owner?: string;
44+
tags?: string[];
2445
};
2546
credentials?: Credentials;
2647
hubUrl?: string;
@@ -31,11 +52,12 @@ export async function* listSpaces(params?: {
3152
/**
3253
* Additional fields to fetch from huggingface.co.
3354
*/
34-
additionalFields?: Array<keyof ApiSpaceInfo>;
55+
additionalFields?: T[];
3556
}): AsyncGenerator<SpaceEntry> {
3657
checkCredentials(params?.credentials);
3758
const search = new URLSearchParams([
3859
...Object.entries({ limit: "500", ...(params?.search?.owner ? { author: params.search.owner } : undefined) }),
60+
...(params?.search?.tags?.map((tag) => ["filter", tag]) ?? []),
3961
...[...EXPAND_KEYS, ...(params?.additionalFields ?? [])].map((val) => ["expand", val] satisfies [string, string]),
4062
]).toString();
4163
let url: string | undefined = `${params?.hubUrl || HUB_URL}/api/spaces?${search}`;

packages/hub/src/types/api/api-dataset.d.ts

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1+
import type { License } from "../public";
2+
13
export interface ApiDatasetInfo {
24
_id: string;
35
id: string;
46
arxivIds?: string[];
57
author?: string;
68
cardExists?: true;
79
cardError?: unknown;
8-
cardData?: unknown;
10+
cardData?: ApiDatasetMetadata;
911
contributors?: Array<{ user: string; _id: string }>;
1012
disabled: boolean;
1113
discussionsDisabled: boolean;
@@ -17,6 +19,9 @@ export interface ApiDatasetInfo {
1719
likesRecent: number;
1820
private: boolean;
1921
updatedAt: string; // date
22+
createdAt: string; // date
23+
tags: string[];
24+
paperswithcode_id?: string;
2025
sha: string;
2126
files?: string[];
2227
citation?: string;
@@ -26,3 +31,59 @@ export interface ApiDatasetInfo {
2631
previewable?: boolean;
2732
doi?: { id: string; commit: string };
2833
}
34+
35+
export interface ApiDatasetMetadata {
36+
licenses?: undefined;
37+
license?: License | License[];
38+
license_name?: string;
39+
license_link?: "LICENSE" | "LICENSE.md" | string;
40+
license_details?: string;
41+
languages?: undefined;
42+
language?: string | string[];
43+
language_bcp47?: string[];
44+
language_details?: string;
45+
tags?: string[];
46+
task_categories?: string[];
47+
task_ids?: string[];
48+
config_names?: string[];
49+
configs?: {
50+
config_name: string;
51+
data_files?:
52+
| string
53+
| string[]
54+
| {
55+
split: string;
56+
path: string | string[];
57+
}[];
58+
data_dir?: string;
59+
}[];
60+
benchmark?: string;
61+
paperswithcode_id?: string | null;
62+
pretty_name?: string;
63+
viewer?: boolean;
64+
viewer_display_urls?: boolean;
65+
thumbnail?: string | null;
66+
description?: string | null;
67+
annotations_creators?: string[];
68+
language_creators?: string[];
69+
multilinguality?: string[];
70+
size_categories?: string[];
71+
source_datasets?: string[];
72+
extra_gated_prompt?: string;
73+
extra_gated_fields?: {
74+
/**
75+
* "text" | "checkbox" | "date_picker" | "country" | "ip_location" | { type: "text" | "checkbox" | "date_picker" | "country" | "ip_location" } | { type: "select", options: Array<string | { label: string; value: string; }> } Property
76+
*/
77+
[x: string]:
78+
| "text"
79+
| "checkbox"
80+
| "date_picker"
81+
| "country"
82+
| "ip_location"
83+
| { type: "text" | "checkbox" | "date_picker" | "country" | "ip_location" }
84+
| { type: "select"; options: Array<string | { label: string; value: string }> };
85+
};
86+
extra_gated_heading?: string;
87+
extra_gated_description?: string;
88+
extra_gated_button_content?: string;
89+
}

0 commit comments

Comments
 (0)