diff --git a/packages/hub/README.md b/packages/hub/README.md index c9a650a5c6..3d87c15dc0 100644 --- a/packages/hub/README.md +++ b/packages/hub/README.md @@ -117,6 +117,10 @@ Checkout the demo: https://huggingface.co/spaces/huggingfacejs/client-side-oauth The `@huggingface/hub` package provide basic capabilities to scan the cache directory. Learn more about [Manage huggingface_hub cache-system](https://huggingface.co/docs/huggingface_hub/en/guides/manage-cache). +### `scanCacheDir` + +You can get the list of cached repositories using the `scanCacheDir` function. + ```ts import { scanCacheDir } from "@huggingface/hub"; @@ -124,7 +128,40 @@ const result = await scanCacheDir(); console.log(result); ``` -Note that the cache directory is created and used only by the Python and Rust libraries. Downloading files using the `@huggingface/hub` package won't use the cache directory. +Note: this does not work in the browser + +### `downloadFileToCacheDir` + +You can cache a single file from a repository using the `downloadFileToCacheDir` function. + +```ts +import { downloadFileToCacheDir } from "@huggingface/hub"; + +const file = await downloadFileToCacheDir({ + repo: 'foo/bar', + path: 'README.md' +}); + +console.log(file); +``` +Note: this does not work in the browser + +### `snapshotDownload` + +You can download an entire repository at a given revision in the cache directory using the `snapshotDownload` function. + +```ts +import { snapshotDownload } from "@huggingface/hub"; + +const directory = await snapshotDownload({ + repo: 'foo/bar', +}); + +console.log(directory); +``` +It uses the `downloadFileToCacheDir` function internally. 
+ +Note: this does not work in the browser ## Performance considerations diff --git a/packages/hub/package.json b/packages/hub/package.json index 04486821d5..9e84de2c1a 100644 --- a/packages/hub/package.json +++ b/packages/hub/package.json @@ -22,6 +22,7 @@ "./src/utils/FileBlob.ts": false, "./src/lib/cache-management.ts": false, "./src/lib/download-file-to-cache-dir.ts": false, + "./src/lib/snapshot-download.ts": false, "./dist/index.js": "./dist/browser/index.js", "./dist/index.mjs": "./dist/browser/index.mjs" }, diff --git a/packages/hub/src/lib/index.ts b/packages/hub/src/lib/index.ts index c2a2fbe06c..24e239bdcb 100644 --- a/packages/hub/src/lib/index.ts +++ b/packages/hub/src/lib/index.ts @@ -21,6 +21,7 @@ export * from "./oauth-handle-redirect"; export * from "./oauth-login-url"; export * from "./parse-safetensors-metadata"; export * from "./paths-info"; +export * from "./snapshot-download"; export * from "./space-info"; export * from "./upload-file"; export * from "./upload-files"; diff --git a/packages/hub/src/lib/snapshot-download.spec.ts b/packages/hub/src/lib/snapshot-download.spec.ts new file mode 100644 index 0000000000..0c44cc8424 --- /dev/null +++ b/packages/hub/src/lib/snapshot-download.spec.ts @@ -0,0 +1,275 @@ +import { expect, test, describe, vi, beforeEach } from "vitest"; +import { dirname, join } from "node:path"; +import { mkdir, writeFile } from "node:fs/promises"; +import { getHFHubCachePath } from "./cache-management"; +import { downloadFileToCacheDir } from "./download-file-to-cache-dir"; +import { snapshotDownload } from "./snapshot-download"; +import type { ListFileEntry } from "./list-files"; +import { listFiles } from "./list-files"; +import { modelInfo } from "./model-info"; +import type { ModelEntry } from "./list-models"; +import type { ApiModelInfo } from "../types/api/api-model"; +import { datasetInfo } from "./dataset-info"; +import type { DatasetEntry } from "./list-datasets"; +import type { ApiDatasetInfo } from 
"../types/api/api-dataset"; +import { spaceInfo } from "./space-info"; +import type { SpaceEntry } from "./list-spaces"; +import type { ApiSpaceInfo } from "../types/api/api-space"; + +vi.mock("node:fs/promises", () => ({ + writeFile: vi.fn(), + mkdir: vi.fn(), +})); + +vi.mock("./space-info", () => ({ + spaceInfo: vi.fn(), +})); + +vi.mock("./dataset-info", () => ({ + datasetInfo: vi.fn(), +})); + +vi.mock("./model-info", () => ({ + modelInfo: vi.fn(), +})); + +vi.mock("./list-files", () => ({ + listFiles: vi.fn(), +})); + +vi.mock("./download-file-to-cache-dir", () => ({ + downloadFileToCacheDir: vi.fn(), +})); + +const DUMMY_SHA = "dummy-sha"; + +// utility method to transform an array of ListFileEntry to an AsyncGenerator<ListFileEntry> +async function* toAsyncGenerator(content: ListFileEntry[]): AsyncGenerator<ListFileEntry> { + for (const entry of content) { + yield Promise.resolve(entry); + } +} + +beforeEach(() => { + vi.resetAllMocks(); + vi.mocked(listFiles).mockReturnValue(toAsyncGenerator([])); + + // mock repo info + vi.mocked(modelInfo).mockResolvedValue({ + sha: DUMMY_SHA, + } as ModelEntry & ApiModelInfo); + vi.mocked(datasetInfo).mockResolvedValue({ + sha: DUMMY_SHA, + } as DatasetEntry & ApiDatasetInfo); + vi.mocked(spaceInfo).mockResolvedValue({ + sha: DUMMY_SHA, + } as SpaceEntry & ApiSpaceInfo); +}); + +describe("snapshotDownload", () => { + test("empty AsyncGenerator should not call downloadFileToCacheDir", async () => { + await snapshotDownload({ + repo: { + name: "foo/bar", + type: "space", + }, + }); + + expect(downloadFileToCacheDir).not.toHaveBeenCalled(); + }); + + test("repo type model should use modelInfo", async () => { + await snapshotDownload({ + repo: { + name: "foo/bar", + type: "model", + }, + }); + expect(modelInfo).toHaveBeenCalledOnce(); + expect(modelInfo).toHaveBeenCalledWith({ + name: "foo/bar", + additionalFields: ["sha"], + revision: "main", + repo: { + name: "foo/bar", + type: "model", + }, + }); + }); + + test("repo type dataset should use 
datasetInfo", async () => { + await snapshotDownload({ + repo: { + name: "foo/bar", + type: "dataset", + }, + }); + expect(datasetInfo).toHaveBeenCalledOnce(); + expect(datasetInfo).toHaveBeenCalledWith({ + name: "foo/bar", + additionalFields: ["sha"], + revision: "main", + repo: { + name: "foo/bar", + type: "dataset", + }, + }); + }); + + test("repo type space should use spaceInfo", async () => { + await snapshotDownload({ + repo: { + name: "foo/bar", + type: "space", + }, + }); + expect(spaceInfo).toHaveBeenCalledOnce(); + expect(spaceInfo).toHaveBeenCalledWith({ + name: "foo/bar", + additionalFields: ["sha"], + revision: "main", + repo: { + name: "foo/bar", + type: "space", + }, + }); + }); + + test("commitHash should be saved to ref folder", async () => { + await snapshotDownload({ + repo: { + name: "foo/bar", + type: "space", + }, + revision: "dummy-revision", + }); + + // cross-platform testing + const expectedPath = join(getHFHubCachePath(), "spaces--foo--bar", "refs", "dummy-revision"); + expect(mkdir).toHaveBeenCalledWith(dirname(expectedPath), { recursive: true }); + expect(writeFile).toHaveBeenCalledWith(expectedPath, DUMMY_SHA); + }); + + test("directory ListFileEntry should mkdir it", async () => { + vi.mocked(listFiles).mockReturnValue( + toAsyncGenerator([ + { + oid: "dummy-etag", + type: "directory", + path: "potatoes", + size: 0, + lastCommit: { + date: new Date().toISOString(), + id: DUMMY_SHA, + title: "feat: best commit", + }, + }, + ]) + ); + + await snapshotDownload({ + repo: { + name: "foo/bar", + type: "space", + }, + }); + + // cross-platform testing + const expectedPath = join(getHFHubCachePath(), "spaces--foo--bar", "snapshots", DUMMY_SHA, "potatoes"); + expect(mkdir).toHaveBeenCalledWith(expectedPath, { recursive: true }); + }); + + test("files in ListFileEntry should download them", async () => { + const entries: ListFileEntry[] = Array.from({ length: 10 }, (_, i) => ({ + oid: `dummy-etag-${i}`, + type: "file", + path: `file-${i}.txt`, 
+ size: i, + lastCommit: { + date: new Date().toISOString(), + id: DUMMY_SHA, + title: "feat: best commit", + }, + })); + vi.mocked(listFiles).mockReturnValue(toAsyncGenerator(entries)); + + await snapshotDownload({ + repo: { + name: "foo/bar", + type: "space", + }, + }); + + for (const entry of entries) { + expect(downloadFileToCacheDir).toHaveBeenCalledWith( + expect.objectContaining({ + repo: { + name: "foo/bar", + type: "space", + }, + path: entry.path, + revision: DUMMY_SHA, + }) + ); + } + }); + + test("custom params should be propagated", async () => { + // fetch mock + const fetchMock: typeof fetch = vi.fn(); + const hubMock = "https://foor.bar"; + const accessTokenMock = "dummy-access-token"; + + vi.mocked(listFiles).mockReturnValue( + toAsyncGenerator([ + { + oid: `dummy-etag`, + type: "file", + path: `file.txt`, + size: 10, + lastCommit: { + date: new Date().toISOString(), + id: DUMMY_SHA, + title: "feat: best commit", + }, + }, + ]) + ); + + await snapshotDownload({ + repo: { + name: "foo/bar", + type: "space", + }, + hubUrl: hubMock, + fetch: fetchMock, + accessToken: accessTokenMock, + }); + + expect(spaceInfo).toHaveBeenCalledWith( + expect.objectContaining({ + fetch: fetchMock, + hubUrl: hubMock, + accessToken: accessTokenMock, + }) + ); + + // list files should receive custom fetch + expect(listFiles).toHaveBeenCalledWith( + expect.objectContaining({ + fetch: fetchMock, + hubUrl: hubMock, + accessToken: accessTokenMock, + }) + ); + + // download file to cache should receive custom fetch + expect(downloadFileToCacheDir).toHaveBeenCalledWith( + expect.objectContaining({ + fetch: fetchMock, + hubUrl: hubMock, + accessToken: accessTokenMock, + }) + ); + }); +}); diff --git a/packages/hub/src/lib/snapshot-download.ts b/packages/hub/src/lib/snapshot-download.ts new file mode 100644 index 0000000000..b3e30c13f1 --- /dev/null +++ b/packages/hub/src/lib/snapshot-download.ts @@ -0,0 +1,124 @@ +import type { CredentialsParams, RepoDesignation } from 
"../types/public"; +import { listFiles } from "./list-files"; +import { getHFHubCachePath, getRepoFolderName } from "./cache-management"; +import { spaceInfo } from "./space-info"; +import { datasetInfo } from "./dataset-info"; +import { modelInfo } from "./model-info"; +import { toRepoId } from "../utils/toRepoId"; +import { join, dirname } from "node:path"; +import { mkdir, writeFile } from "node:fs/promises"; +import { downloadFileToCacheDir } from "./download-file-to-cache-dir"; + +export const DEFAULT_REVISION = "main"; + +/** + * Downloads an entire repository at a given revision into the cache directory (see {@link getHFHubCachePath}). + * You can list all cached repositories using {@link scanCacheDir}. + * @remarks It uses {@link downloadFileToCacheDir} internally. + */ +export async function snapshotDownload( + params: { + repo: RepoDesignation; + cacheDir?: string; + /** + * An optional Git revision id which can be a branch name, a tag, or a commit hash. + * + * @default "main" + */ + revision?: string; + hubUrl?: string; + /** + * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. 
+ */ + fetch?: typeof fetch; + } & Partial<CredentialsParams> +): Promise<string> { + let cacheDir: string; + if (params.cacheDir) { + cacheDir = params.cacheDir; + } else { + cacheDir = getHFHubCachePath(); + } + + let revision: string; + if (params.revision) { + revision = params.revision; + } else { + revision = DEFAULT_REVISION; + } + + const repoId = toRepoId(params.repo); + + // get repository revision value (sha) + let repoInfo: { sha: string }; + switch (repoId.type) { + case "space": + repoInfo = await spaceInfo({ + ...params, + name: repoId.name, + additionalFields: ["sha"], + revision: revision, + }); + break; + case "dataset": + repoInfo = await datasetInfo({ + ...params, + name: repoId.name, + additionalFields: ["sha"], + revision: revision, + }); + break; + case "model": + repoInfo = await modelInfo({ + ...params, + name: repoId.name, + additionalFields: ["sha"], + revision: revision, + }); + break; + default: + throw new Error(`invalid repository type ${repoId.type}`); + } + + const commitHash: string = repoInfo.sha; + + // get storage folder + const storageFolder = join(cacheDir, getRepoFolderName(repoId)); + const snapshotFolder = join(storageFolder, "snapshots", commitHash); + + // if passed revision is not identical to commit_hash + // then revision has to be a branch name or tag name. + // In that case store a ref. 
+ if (revision !== commitHash) { + const refPath = join(storageFolder, "refs", revision); + await mkdir(dirname(refPath), { recursive: true }); + await writeFile(refPath, commitHash); + } + + const cursor = listFiles({ + ...params, + repo: params.repo, + recursive: true, + revision: repoInfo.sha, + }); + + for await (const entry of cursor) { + switch (entry.type) { + case "file": + await downloadFileToCacheDir({ + ...params, + path: entry.path, + revision: commitHash, + cacheDir: cacheDir, + }); + break; + case "directory": + await mkdir(join(snapshotFolder, entry.path), { recursive: true }); + break; + default: + throw new Error(`unknown entry type: ${entry.type}`); + } + } + + return snapshotFolder; +} diff --git a/packages/hub/vitest-browser.config.mts b/packages/hub/vitest-browser.config.mts index 60fcbfbfcf..db22fb67cf 100644 --- a/packages/hub/vitest-browser.config.mts +++ b/packages/hub/vitest-browser.config.mts @@ -2,6 +2,12 @@ import { configDefaults, defineConfig } from "vitest/config"; export default defineConfig({ test: { - exclude: [...configDefaults.exclude, "src/utils/FileBlob.spec.ts", "src/lib/cache-management.spec.ts", "src/lib/download-file-to-cache-dir.spec.ts"], + exclude: [ + ...configDefaults.exclude, + "src/utils/FileBlob.spec.ts", + "src/lib/cache-management.spec.ts", + "src/lib/download-file-to-cache-dir.spec.ts", + "src/lib/snapshot-download.spec.ts", + ], }, });