From 6613966214d84b6de4230cd90530297b66105bbb Mon Sep 17 00:00:00 2001 From: axel7083 <42176370+axel7083@users.noreply.github.com> Date: Mon, 18 Nov 2024 15:50:10 +0100 Subject: [PATCH 1/8] feat(hub): adding snapshot download method --- packages/hub/package.json | 1 + packages/hub/src/lib/index.ts | 1 + .../hub/src/lib/snapshot-download.spec.ts | 204 ++++++++++++++++++ packages/hub/src/lib/snapshot-download.ts | 117 ++++++++++ packages/hub/vitest-browser.config.mts | 8 +- 5 files changed, 330 insertions(+), 1 deletion(-) create mode 100644 packages/hub/src/lib/snapshot-download.spec.ts create mode 100644 packages/hub/src/lib/snapshot-download.ts diff --git a/packages/hub/package.json b/packages/hub/package.json index 04486821d5..9e84de2c1a 100644 --- a/packages/hub/package.json +++ b/packages/hub/package.json @@ -22,6 +22,7 @@ "./src/utils/FileBlob.ts": false, "./src/lib/cache-management.ts": false, "./src/lib/download-file-to-cache-dir.ts": false, + "./src/lib/snapshot-download.ts": false, "./dist/index.js": "./dist/browser/index.js", "./dist/index.mjs": "./dist/browser/index.mjs" }, diff --git a/packages/hub/src/lib/index.ts b/packages/hub/src/lib/index.ts index c2a2fbe06c..24e239bdcb 100644 --- a/packages/hub/src/lib/index.ts +++ b/packages/hub/src/lib/index.ts @@ -21,6 +21,7 @@ export * from "./oauth-handle-redirect"; export * from "./oauth-login-url"; export * from "./parse-safetensors-metadata"; export * from "./paths-info"; +export * from "./snapshot-download"; export * from "./space-info"; export * from "./upload-file"; export * from "./upload-files"; diff --git a/packages/hub/src/lib/snapshot-download.spec.ts b/packages/hub/src/lib/snapshot-download.spec.ts new file mode 100644 index 0000000000..ab6d73d748 --- /dev/null +++ b/packages/hub/src/lib/snapshot-download.spec.ts @@ -0,0 +1,204 @@ +import { expect, test, describe, vi, beforeEach } from "vitest"; +import { dirname, join } from "node:path"; +import { mkdir, writeFile } from "node:fs/promises"; +import { getHFHubCachePath } from "./cache-management"; +import { downloadFileToCacheDir } from "./download-file-to-cache-dir"; +import { snapshotDownload } from "./snapshot-download"; +import type { ListFileEntry } from "./list-files"; +import { listFiles } from "./list-files"; +import { modelInfo } from "./model-info"; +import type { ModelEntry } from "./list-models"; +import type { ApiModelInfo } from "../types/api/api-model"; +import { datasetInfo } from "./dataset-info"; +import type { DatasetEntry } from "./list-datasets"; +import type { ApiDatasetInfo } from "../types/api/api-dataset"; +import { spaceInfo } from "./space-info"; +import type { SpaceEntry } from "./list-spaces"; +import type { ApiSpaceInfo } from "../types/api/api-space"; + +vi.mock("node:fs/promises", () => ({ + writeFile: vi.fn(), + mkdir: vi.fn(), +})); + +vi.mock("./space-info", () => ({ + spaceInfo: vi.fn(), +})); + +vi.mock("./dataset-info", () => ({ + datasetInfo: vi.fn(), +})); + +vi.mock("./model-info", () => ({ + modelInfo: vi.fn(), +})); + +vi.mock("./list-files", () => ({ + listFiles: vi.fn(), +})); + +vi.mock("./download-file-to-cache-dir", () => ({ + downloadFileToCacheDir: vi.fn(), +})); + +const DUMMY_SHA = "dummy-sha"; + +// utility method to transform an array of ListFileEntry to an AsyncGenerator +async function* toAsyncGenerator(content: ListFileEntry[]): AsyncGenerator { + for (const entry of content) { + yield Promise.resolve(entry); + } +} + +beforeEach(() => { + vi.resetAllMocks(); + vi.mocked(listFiles).mockReturnValue(toAsyncGenerator([])); + + // mock repo info + vi.mocked(modelInfo).mockResolvedValue({ + sha: DUMMY_SHA, + } as ModelEntry & ApiModelInfo); + vi.mocked(datasetInfo).mockResolvedValue({ + sha: DUMMY_SHA, + } as DatasetEntry & ApiDatasetInfo); + vi.mocked(spaceInfo).mockResolvedValue({ + sha: DUMMY_SHA, + } as SpaceEntry & ApiSpaceInfo); +}); + +describe("snapshotDownload", () => { + test("empty AsyncGenerator should not call downloadFileToCacheDir", async () => { + await snapshotDownload({ + repo: { + name: "foo/bar", + type: "space", + }, + }); + + expect(downloadFileToCacheDir).not.toHaveBeenCalled(); + }); + + test("repo type model should use modelInfo", async () => { + await snapshotDownload({ + repo: { + name: "foo/bar", + type: "model", + }, + }); + expect(modelInfo).toHaveBeenCalledOnce(); + expect(modelInfo).toHaveBeenCalledWith({ + name: "foo/bar", + additionalFields: ["sha"], + revision: "main", + }); + }); + + test("repo type dataset should use datasetInfo", async () => { + await snapshotDownload({ + repo: { + name: "foo/bar", + type: "dataset", + }, + }); + expect(datasetInfo).toHaveBeenCalledOnce(); + expect(datasetInfo).toHaveBeenCalledWith({ + name: "foo/bar", + additionalFields: ["sha"], + revision: "main", + }); + }); + + test("repo type space should use spaceInfo", async () => { + await snapshotDownload({ + repo: { + name: "foo/bar", + type: "space", + }, + }); + expect(spaceInfo).toHaveBeenCalledOnce(); + expect(spaceInfo).toHaveBeenCalledWith({ + name: "foo/bar", + additionalFields: ["sha"], + revision: "main", + }); + }); + + test("commitHash should be saved to ref folder", async () => { + await snapshotDownload({ + repo: { + name: "foo/bar", + type: "space", + }, + revision: "dummy-revision", + }); + + // cross-platform testing + const expectedPath = join(getHFHubCachePath(), "spaces--foo--bar", "refs", "dummy-revision"); + expect(mkdir).toHaveBeenCalledWith(dirname(expectedPath), { recursive: true }); + expect(writeFile).toHaveBeenCalledWith(expectedPath, DUMMY_SHA); + }); + + test("directory ListFileEntry should mkdir it", async () => { + vi.mocked(listFiles).mockReturnValue( + toAsyncGenerator([ + { + oid: "dummy-etag", + type: "directory", + path: "potatoes", + size: 0, + lastCommit: { + date: new Date().toISOString(), + id: DUMMY_SHA, + title: "feat: best commit", + }, + }, + ]) + ); + + await snapshotDownload({ + repo: { + name: "foo/bar", + type: "space", + }, + }); + + // cross-platform testing + const expectedPath = join(getHFHubCachePath(), "spaces--foo--bar", "snapshots", DUMMY_SHA, "potatoes"); + expect(mkdir).toHaveBeenCalledWith(expectedPath, { recursive: true }); + }); + + test("files in ListFileEntry should download them", async () => { + const entries: ListFileEntry[] = Array.from({ length: 10 }, (_, i) => ({ + oid: `dummy-etag-${i}`, + type: "file", + path: `file-${i}.txt`, + size: i, + lastCommit: { + date: new Date().toISOString(), + id: DUMMY_SHA, + title: "feat: best commit", + }, + })); + vi.mocked(listFiles).mockReturnValue(toAsyncGenerator(entries)); + + await snapshotDownload({ + repo: { + name: "foo/bar", + type: "space", + }, + }); + + for (const entry of entries) { + expect(downloadFileToCacheDir).toHaveBeenCalledWith( + expect.objectContaining({ + repo: { + name: "foo/bar", + type: "space", + }, + path: entry.path, + revision: DUMMY_SHA, + }) + ); + } + }); +}); diff --git a/packages/hub/src/lib/snapshot-download.ts b/packages/hub/src/lib/snapshot-download.ts new file mode 100644 index 0000000000..7f675b5cd7 --- /dev/null +++ b/packages/hub/src/lib/snapshot-download.ts @@ -0,0 +1,117 @@ +import type { RepoDesignation } from "../types/public"; +import { listFiles } from "./list-files"; +import { getHFHubCachePath, getRepoFolderName } from "./cache-management"; +import { spaceInfo } from "./space-info"; +import { datasetInfo } from "./dataset-info"; +import { modelInfo } from "./model-info"; +import { toRepoId } from "../utils/toRepoId"; +import { join, dirname } from "node:path"; +import { mkdir, writeFile } from "node:fs/promises"; +import { downloadFileToCacheDir } from "./download-file-to-cache-dir"; + +export const DEFAULT_REVISION = "main"; + +export async function snapshotDownload(params: { + repo: RepoDesignation; + cacheDir?: string; + /** + * An optional Git revision id which can be a branch name, a tag, or a commit hash. + * + * @default "main" + */ + revision?: string; + hubUrl?: string; + /** + * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. + */ + fetch?: typeof fetch; +}): Promise { + let cacheDir: string; + if (params.cacheDir) { + cacheDir = params.cacheDir; + } else { + cacheDir = getHFHubCachePath(); + } + + let revision: string; + if (params.revision) { + revision = params.revision; + } else { + revision = DEFAULT_REVISION; + } + + const repoId = toRepoId(params.repo); + + // get repository revision value (sha) + let repoInfo: { sha: string }; + switch (repoId.type) { + case "space": + repoInfo = await spaceInfo({ + name: repoId.name, + additionalFields: ["sha"], + revision: revision, + }); + break; + case "dataset": + repoInfo = await datasetInfo({ + name: repoId.name, + additionalFields: ["sha"], + revision: revision, + }); + break; + case "model": + repoInfo = await modelInfo({ + name: repoId.name, + additionalFields: ["sha"], + revision: revision, + }); + break; + default: + throw new Error(`invalid repository type ${repoId.type}`); + } + + const commitHash: string = repoInfo.sha; + + // get storage folder + const storageFolder = join(cacheDir, getRepoFolderName(repoId)); + const snapshotFolder = join(storageFolder, "snapshots", commitHash); + + // if passed revision is not identical to commit_hash + // then revision has to be a branch name or tag name. + // In that case store a ref. + if (revision !== commitHash) { + const refPath = join(storageFolder, "refs", revision); + await mkdir(dirname(refPath), { recursive: true }); + await writeFile(refPath, commitHash); + } + + const cursor = listFiles({ + repo: params.repo, + recursive: true, + revision: repoInfo.sha, + hubUrl: params.hubUrl, + fetch: params.fetch, + }); + + for await (const entry of cursor) { + switch (entry.type) { + case "file": + await downloadFileToCacheDir({ + repo: params.repo, + path: entry.path, + revision: commitHash, + hubUrl: params.hubUrl, + fetch: params.fetch, + cacheDir: cacheDir, + }); + break; + case "directory": + await mkdir(join(snapshotFolder, entry.path), { recursive: true }); + break; + default: + throw new Error(`unknown entry type: ${entry.type}`); + } + } + + return snapshotFolder; +} diff --git a/packages/hub/vitest-browser.config.mts b/packages/hub/vitest-browser.config.mts index 60fcbfbfcf..db22fb67cf 100644 --- a/packages/hub/vitest-browser.config.mts +++ b/packages/hub/vitest-browser.config.mts @@ -2,6 +2,12 @@ import { configDefaults, defineConfig } from "vitest/config"; export default defineConfig({ test: { - exclude: [...configDefaults.exclude, "src/utils/FileBlob.spec.ts", "src/lib/cache-management.spec.ts", "src/lib/download-file-to-cache-dir.spec.ts"], + exclude: [ + ...configDefaults.exclude, + "src/utils/FileBlob.spec.ts", + "src/lib/cache-management.spec.ts", + "src/lib/download-file-to-cache-dir.spec.ts", + "src/lib/snapshot-download.spec.ts", + ], }, }); From b63f16f9583b7b93e6f37518f4bc70fd0947eff5 Mon Sep 17 00:00:00 2001 From: axel7083 <42176370+axel7083@users.noreply.github.com> Date: Mon, 18 Nov 2024 16:07:02 +0100 Subject: [PATCH 2/8] fix: params propagation --- .../hub/src/lib/snapshot-download.spec.ts | 59 +++++++++++++++++++ packages/hub/src/lib/snapshot-download.ts | 44 +++++++------- 2 files changed, 82 insertions(+), 21 deletions(-) diff --git a/packages/hub/src/lib/snapshot-download.spec.ts b/packages/hub/src/lib/snapshot-download.spec.ts index ab6d73d748..23f0927646 100644 --- a/packages/hub/src/lib/snapshot-download.spec.ts +++ b/packages/hub/src/lib/snapshot-download.spec.ts @@ -201,4 +201,63 @@ describe("snapshotDownload", () => { ); } }); + + test("custom params should be propagated", async () => { + // fetch mock + const fetchMock: typeof fetch = vi.fn(); + const hubMock = "https://foor.bar"; + const accessTokenMock = 'dummy-access-token'; + + vi.mocked(listFiles).mockReturnValue( + toAsyncGenerator([ + { + oid: `dummy-etag`, + type: "file", + path: `file.txt`, + size: 10, + lastCommit: { + date: new Date().toISOString(), + id: DUMMY_SHA, + title: "feat: best commit", + }, + }, + ]) + ); + + await snapshotDownload({ + repo: { + name: "foo/bar", + type: "space", + }, + hubUrl: hubMock, + fetch: fetchMock, + accessToken: accessTokenMock, + }); + + expect(spaceInfo).toHaveBeenCalledWith( + expect.objectContaining({ + fetch: fetchMock, + hubUrl: hubMock, + accessToken: accessTokenMock, + }) + ); + + // list files should receive custom fetch + expect(listFiles).toHaveBeenCalledWith( + expect.objectContaining({ + fetch: fetchMock, + hubUrl: hubMock, + accessToken: accessTokenMock, + }) + ); + + // download file to cache should receive custom fetch + expect(downloadFileToCacheDir).toHaveBeenCalledWith( + expect.objectContaining({ + fetch: fetchMock, + hubUrl: hubMock, + accessToken: accessTokenMock, + }) + ); + }); }); diff --git a/packages/hub/src/lib/snapshot-download.ts b/packages/hub/src/lib/snapshot-download.ts index 7f675b5cd7..f09f37e93f 100644 --- a/packages/hub/src/lib/snapshot-download.ts +++ b/packages/hub/src/lib/snapshot-download.ts @@ -1,4 +1,4 @@ -import type { RepoDesignation } from "../types/public"; +import type { CredentialsParams, RepoDesignation } from "../types/public"; import { listFiles } from "./list-files"; import { getHFHubCachePath, getRepoFolderName } from "./cache-management"; import { spaceInfo } from "./space-info"; @@ -11,21 +11,23 @@ import { downloadFileToCacheDir } from "./download-file-to-cache-dir"; export const DEFAULT_REVISION = "main"; -export async function snapshotDownload(params: { - repo: RepoDesignation; - cacheDir?: string; - /** - * An optional Git revision id which can be a branch name, a tag, or a commit hash. - * - * @default "main" - */ - revision?: string; - hubUrl?: string; - /** - * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. - */ - fetch?: typeof fetch; -}): Promise { +export async function snapshotDownload( + params: { + repo: RepoDesignation; + cacheDir?: string; + /** + * An optional Git revision id which can be a branch name, a tag, or a commit hash. + * + * @default "main" + */ + revision?: string; + hubUrl?: string; + /** + * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. + */ + fetch?: typeof fetch; + } & Partial +): Promise { let cacheDir: string; if (params.cacheDir) { cacheDir = params.cacheDir; @@ -47,6 +49,7 @@ export async function snapshotDownload(params: { switch (repoId.type) { case "space": repoInfo = await spaceInfo({ + ...params, name: repoId.name, additionalFields: ["sha"], revision: revision, @@ -54,6 +57,7 @@ export async function snapshotDownload(params: { break; case "dataset": repoInfo = await datasetInfo({ + ...params, name: repoId.name, additionalFields: ["sha"], revision: revision, @@ -61,6 +65,7 @@ export async function snapshotDownload(params: { break; case "model": repoInfo = await modelInfo({ + ...params, name: repoId.name, additionalFields: ["sha"], revision: revision, @@ -86,22 +91,19 @@ export async function snapshotDownload(params: { } const cursor = listFiles({ + ...params, repo: params.repo, recursive: true, revision: repoInfo.sha, - hubUrl: params.hubUrl, - fetch: params.fetch, }); for await (const entry of cursor) { switch (entry.type) { case "file": await downloadFileToCacheDir({ - repo: params.repo, + ...params, path: entry.path, revision: commitHash, - hubUrl: params.hubUrl, - fetch: params.fetch, cacheDir: cacheDir, }); break; From b452eb29cb6bf5aeffb56bb86b1c26819699df82 Mon Sep 17 00:00:00 2001 From: axel7083 <42176370+axel7083@users.noreply.github.com> Date: Mon, 18 Nov 2024 17:06:09 +0100 Subject: [PATCH 3/8] docs: adding code documentation to snapshotDownload --- packages/hub/src/lib/snapshot-download.ts | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/packages/hub/src/lib/snapshot-download.ts b/packages/hub/src/lib/snapshot-download.ts index f09f37e93f..9c0c2a6941 100644 --- a/packages/hub/src/lib/snapshot-download.ts +++ b/packages/hub/src/lib/snapshot-download.ts @@ -11,6 +11,12 @@ import { downloadFileToCacheDir } from "./download-file-to-cache-dir"; export const DEFAULT_REVISION = "main"; +/** + * Downloads an entire repository at a given revision in the cache directory {@link getHFHubCachePath}. + * You can list all cached repositories using {@link scanCachedRepo} + * @remarks It uses internally {@link downloadFileToCacheDir}. + * @param params + */ export async function snapshotDownload( params: { repo: RepoDesignation; From df29083a711a32bb80c2bd8ca05a3522e188f809 Mon Sep 17 00:00:00 2001 From: "Eliott C." Date: Mon, 18 Nov 2024 17:13:01 +0100 Subject: [PATCH 4/8] Update packages/hub/src/lib/snapshot-download.ts --- packages/hub/src/lib/snapshot-download.ts | 1 - 1 file changed, 1 deletion(-) diff --git a/packages/hub/src/lib/snapshot-download.ts b/packages/hub/src/lib/snapshot-download.ts index 9c0c2a6941..b3e30c13f1 100644 --- a/packages/hub/src/lib/snapshot-download.ts +++ b/packages/hub/src/lib/snapshot-download.ts @@ -15,7 +15,6 @@ export const DEFAULT_REVISION = "main"; * Downloads an entire repository at a given revision in the cache directory {@link getHFHubCachePath}. * You can list all cached repositories using {@link scanCachedRepo} * @remarks It uses internally {@link downloadFileToCacheDir}. - * @param params */ export async function snapshotDownload( params: { From 5aec3d5c1e1ad09bb0297729f2a139c0e293a4aa Mon Sep 17 00:00:00 2001 From: axel7083 <42176370+axel7083@users.noreply.github.com> Date: Mon, 18 Nov 2024 20:51:13 +0100 Subject: [PATCH 5/8] fix: tests --- packages/hub/src/lib/snapshot-download.spec.ts | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/packages/hub/src/lib/snapshot-download.spec.ts b/packages/hub/src/lib/snapshot-download.spec.ts index 23f0927646..0c44cc8424 100644 --- a/packages/hub/src/lib/snapshot-download.spec.ts +++ b/packages/hub/src/lib/snapshot-download.spec.ts @@ -90,6 +90,10 @@ describe("snapshotDownload", () => { name: "foo/bar", additionalFields: ["sha"], revision: "main", + repo: { + name: "foo/bar", + type: "model", + }, }); }); @@ -105,6 +109,10 @@ describe("snapshotDownload", () => { name: "foo/bar", additionalFields: ["sha"], revision: "main", + repo: { + name: "foo/bar", + type: "dataset", + }, }); }); @@ -120,6 +128,10 @@ describe("snapshotDownload", () => { name: "foo/bar", additionalFields: ["sha"], revision: "main", + repo: { + name: "foo/bar", + type: "space", + }, }); }); @@ -206,7 +218,7 @@ describe("snapshotDownload", () => { // fetch mock const fetchMock: typeof fetch = vi.fn(); const hubMock = "https://foor.bar"; - const accessTokenMock = 'dummy-access-token'; + const accessTokenMock = "dummy-access-token"; vi.mocked(listFiles).mockReturnValue( toAsyncGenerator([ From 92f453bd21804eb9591291c5989a4bf1d9ec6ff1 Mon Sep 17 00:00:00 2001 From: axel7083 <42176370+axel7083@users.noreply.github.com> Date: Mon, 18 Nov 2024 21:01:19 +0100 Subject: [PATCH 6/8] docs: adding cache related function to readme --- packages/hub/README.md | 39 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/packages/hub/README.md b/packages/hub/README.md index c9a650a5c6..0173741308 100644 --- a/packages/hub/README.md +++ b/packages/hub/README.md @@ -117,6 +117,10 @@ Checkout the demo: https://huggingface.co/spaces/huggingfacejs/client-side-oauth The `@huggingface/hub` package provide basic capabilities to scan the cache directory. Learn more about [Manage huggingface_hub cache-system](https://huggingface.co/docs/huggingface_hub/en/guides/manage-cache). +### `scanCacheDir` + +You can get the list of cached repository using the `scanCacheDir` function. + ```ts import { scanCacheDir } from "@huggingface/hub"; @@ -124,7 +128,40 @@ const result = await scanCacheDir(); console.log(result); ``` -Note that the cache directory is created and used only by the Python and Rust libraries. Downloading files using the `@huggingface/hub` package won't use the cache directory. +Note: this does not work in the browser + +### `downloadFileToCacheDir` + +You can cache a file of a repository using the `downloadFileToCacheDir` function. + +```ts +import { snapshotDownload } from "@huggingface/hub"; + +const file = await downloadFileToCacheDir({ + repo: 'foo/bar', + path: 'README.md' +}); + +console.log(file); +``` +Note: this does not work in the browser + +### `snapshotDownload` + +You can download an entire repository at a given revision in the cache directory using the `snapshotDownload` function. + +```ts +import { snapshotDownload } from "@huggingface/hub"; + +const directory = await snapshotDownload({ + repo: 'foo/bar', +}); + +console.log(directory); +``` +The code use internally the `downloadFileToCacheDir` function. + +Note: this does not work in the browser ## Performance considerations From 58389eb83c8f8656bf1da5f6370a2d9a8de5559a Mon Sep 17 00:00:00 2001 From: axel7083 <42176370+axel7083@users.noreply.github.com> Date: Mon, 18 Nov 2024 21:03:07 +0100 Subject: [PATCH 7/8] docs: typo --- packages/hub/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/hub/README.md b/packages/hub/README.md index 0173741308..d1b353c3dd 100644 --- a/packages/hub/README.md +++ b/packages/hub/README.md @@ -135,7 +135,7 @@ Note: this does not work in the browser You can cache a file of a repository using the `downloadFileToCacheDir` function. ```ts -import { snapshotDownload } from "@huggingface/hub"; +import { downloadFileToCacheDir } from "@huggingface/hub"; const file = await downloadFileToCacheDir({ repo: 'foo/bar', From ab7f73269f064ac1ff47c173d30a377f5975894f Mon Sep 17 00:00:00 2001 From: "Eliott C." Date: Tue, 19 Nov 2024 08:52:40 +0100 Subject: [PATCH 8/8] Update packages/hub/README.md --- packages/hub/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/hub/README.md b/packages/hub/README.md index d1b353c3dd..3d87c15dc0 100644 --- a/packages/hub/README.md +++ b/packages/hub/README.md @@ -119,7 +119,7 @@ The `@huggingface/hub` package provide basic capabilities to scan the cache dire ### `scanCacheDir` -You can get the list of cached repository using the `scanCacheDir` function. +You can get the list of cached repositories using the `scanCacheDir` function. ```ts import { scanCacheDir } from "@huggingface/hub";