Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/hub/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"./src/utils/FileBlob.ts": false,
"./src/lib/cache-management.ts": false,
"./src/lib/download-file-to-cache-dir.ts": false,
"./src/lib/snapshot-download.ts": false,
"./dist/index.js": "./dist/browser/index.js",
"./dist/index.mjs": "./dist/browser/index.mjs"
},
Expand Down
1 change: 1 addition & 0 deletions packages/hub/src/lib/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ export * from "./oauth-handle-redirect";
export * from "./oauth-login-url";
export * from "./parse-safetensors-metadata";
export * from "./paths-info";
export * from "./snapshot-download";
export * from "./space-info";
export * from "./upload-file";
export * from "./upload-files";
Expand Down
263 changes: 263 additions & 0 deletions packages/hub/src/lib/snapshot-download.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
import { expect, test, describe, vi, beforeEach } from "vitest";
import { dirname, join } from "node:path";
import { mkdir, writeFile } from "node:fs/promises";
import { getHFHubCachePath } from "./cache-management";
import { downloadFileToCacheDir } from "./download-file-to-cache-dir";
import { snapshotDownload } from "./snapshot-download";
import type { ListFileEntry } from "./list-files";
import { listFiles } from "./list-files";
import { modelInfo } from "./model-info";
import type { ModelEntry } from "./list-models";
import type { ApiModelInfo } from "../types/api/api-model";
import { datasetInfo } from "./dataset-info";
import type { DatasetEntry } from "./list-datasets";
import type { ApiDatasetInfo } from "../types/api/api-dataset";
import { spaceInfo } from "./space-info";
import type { SpaceEntry } from "./list-spaces";
import type { ApiSpaceInfo } from "../types/api/api-space";

// Swap node:fs/promises and every hub helper imported by snapshot-download
// for bare vi.fn() stubs; return values are configured per test run.
vi.mock("./dataset-info", () => {
	return { datasetInfo: vi.fn() };
});

vi.mock("./download-file-to-cache-dir", () => {
	return { downloadFileToCacheDir: vi.fn() };
});

vi.mock("./list-files", () => {
	return { listFiles: vi.fn() };
});

vi.mock("./model-info", () => {
	return { modelInfo: vi.fn() };
});

vi.mock("./space-info", () => {
	return { spaceInfo: vi.fn() };
});

vi.mock("node:fs/promises", () => {
	return {
		mkdir: vi.fn(),
		writeFile: vi.fn(),
	};
});

// Commit hash returned as `sha` by the mocked modelInfo / datasetInfo / spaceInfo helpers.
const DUMMY_SHA = "dummy-sha";

/**
 * Test utility: exposes an in-memory array of ListFileEntry as an
 * AsyncGenerator<ListFileEntry>, yielding the entries in array order.
 *
 * @param content entries to yield; an empty array produces a generator that completes immediately
 * @returns an async generator over `content`
 */
async function* toAsyncGenerator(content: ListFileEntry[]): AsyncGenerator<ListFileEntry> {
	// Delegating with `yield*` is enough: an async generator already awaits each
	// yielded value, so wrapping entries in Promise.resolve() adds nothing.
	yield* content;
}

// Runs before every test: clears all mock state, then installs default stubs.
beforeEach(() => {
	vi.resetAllMocks();

	// each repo-info helper resolves to an object carrying only the dummy commit hash
	const spaceStub = { sha: DUMMY_SHA } as SpaceEntry & ApiSpaceInfo;
	const datasetStub = { sha: DUMMY_SHA } as DatasetEntry & ApiDatasetInfo;
	const modelStub = { sha: DUMMY_SHA } as ModelEntry & ApiModelInfo;

	vi.mocked(spaceInfo).mockResolvedValue(spaceStub);
	vi.mocked(datasetInfo).mockResolvedValue(datasetStub);
	vi.mocked(modelInfo).mockResolvedValue(modelStub);

	// default listing: no entries at all
	const noEntries = toAsyncGenerator([]);
	vi.mocked(listFiles).mockReturnValue(noEntries);
});

// Unit tests for snapshotDownload. Filesystem and hub helpers are mocked, so
// the assertions only inspect which helpers were called and with what arguments.
describe("snapshotDownload", () => {
	test("empty AsyncGenerator should not call downloadFileToCacheDir", async () => {
		await snapshotDownload({
			repo: {
				name: "foo/bar",
				type: "space",
			},
		});

		expect(downloadFileToCacheDir).not.toHaveBeenCalled();
	});

	test("repo type model should use modelInfo", async () => {
		await snapshotDownload({
			repo: {
				name: "foo/bar",
				type: "model",
			},
		});
		expect(modelInfo).toHaveBeenCalledOnce();
		// objectContaining: pin the fields that matter here while tolerating any
		// caller options snapshotDownload forwards alongside them
		expect(modelInfo).toHaveBeenCalledWith(
			expect.objectContaining({
				name: "foo/bar",
				additionalFields: ["sha"],
				revision: "main",
			})
		);
	});

	test("repo type dataset should use datasetInfo", async () => {
		await snapshotDownload({
			repo: {
				name: "foo/bar",
				type: "dataset",
			},
		});
		expect(datasetInfo).toHaveBeenCalledOnce();
		// see the modelInfo test: only the relevant fields are pinned
		expect(datasetInfo).toHaveBeenCalledWith(
			expect.objectContaining({
				name: "foo/bar",
				additionalFields: ["sha"],
				revision: "main",
			})
		);
	});

	test("repo type space should use spaceInfo", async () => {
		await snapshotDownload({
			repo: {
				name: "foo/bar",
				type: "space",
			},
		});
		expect(spaceInfo).toHaveBeenCalledOnce();
		// see the modelInfo test: only the relevant fields are pinned
		expect(spaceInfo).toHaveBeenCalledWith(
			expect.objectContaining({
				name: "foo/bar",
				additionalFields: ["sha"],
				revision: "main",
			})
		);
	});

	test("commitHash should be saved to ref folder", async () => {
		await snapshotDownload({
			repo: {
				name: "foo/bar",
				type: "space",
			},
			revision: "dummy-revision",
		});

		// cross-platform testing
		const expectedPath = join(getHFHubCachePath(), "spaces--foo--bar", "refs", "dummy-revision");
		expect(mkdir).toHaveBeenCalledWith(dirname(expectedPath), { recursive: true });
		expect(writeFile).toHaveBeenCalledWith(expectedPath, DUMMY_SHA);
	});

	test("directory ListFileEntry should mkdir it", async () => {
		vi.mocked(listFiles).mockReturnValue(
			toAsyncGenerator([
				{
					oid: "dummy-etag",
					type: "directory",
					path: "potatoes",
					size: 0,
					lastCommit: {
						date: new Date().toISOString(),
						id: DUMMY_SHA,
						title: "feat: best commit",
					},
				},
			])
		);

		await snapshotDownload({
			repo: {
				name: "foo/bar",
				type: "space",
			},
		});

		// cross-platform testing
		const expectedPath = join(getHFHubCachePath(), "spaces--foo--bar", "snapshots", DUMMY_SHA, "potatoes");
		expect(mkdir).toHaveBeenCalledWith(expectedPath, { recursive: true });
	});

	test("files in ListFileEntry should download them", async () => {
		const entries: ListFileEntry[] = Array.from({ length: 10 }, (_, i) => ({
			oid: `dummy-etag-${i}`,
			type: "file",
			path: `file-${i}.txt`,
			size: i,
			lastCommit: {
				date: new Date().toISOString(),
				id: DUMMY_SHA,
				title: "feat: best commit",
			},
		}));
		vi.mocked(listFiles).mockReturnValue(toAsyncGenerator(entries));

		await snapshotDownload({
			repo: {
				name: "foo/bar",
				type: "space",
			},
		});

		// every listed file must be downloaded at the resolved commit hash
		for (const entry of entries) {
			expect(downloadFileToCacheDir).toHaveBeenCalledWith(
				expect.objectContaining({
					repo: {
						name: "foo/bar",
						type: "space",
					},
					path: entry.path,
					revision: DUMMY_SHA,
				})
			);
		}
	});

	test("custom params should be propagated", async () => {
		// fetch mock
		const fetchMock: typeof fetch = vi.fn();
		const hubMock = "https://foor.bar";
		const accessTokenMock = "dummy-access-token";

		vi.mocked(listFiles).mockReturnValue(
			toAsyncGenerator([
				{
					oid: `dummy-etag`,
					type: "file",
					path: `file.txt`,
					size: 10,
					lastCommit: {
						date: new Date().toISOString(),
						id: DUMMY_SHA,
						title: "feat: best commit",
					},
				},
			])
		);

		await snapshotDownload({
			repo: {
				name: "foo/bar",
				type: "space",
			},
			hubUrl: hubMock,
			fetch: fetchMock,
			accessToken: accessTokenMock,
		});

		// repo info lookup should receive the custom options
		expect(spaceInfo).toHaveBeenCalledWith(
			expect.objectContaining({
				fetch: fetchMock,
				hubUrl: hubMock,
				accessToken: accessTokenMock,
			})
		);

		// list files should receive custom fetch
		expect(listFiles).toHaveBeenCalledWith(
			expect.objectContaining({
				fetch: fetchMock,
				hubUrl: hubMock,
				accessToken: accessTokenMock,
			})
		);

		// download file to cache should receive custom fetch
		expect(downloadFileToCacheDir).toHaveBeenCalledWith(
			expect.objectContaining({
				fetch: fetchMock,
				hubUrl: hubMock,
				accessToken: accessTokenMock,
			})
		);
	});
});
119 changes: 119 additions & 0 deletions packages/hub/src/lib/snapshot-download.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import type { CredentialsParams, RepoDesignation } from "../types/public";
import { listFiles } from "./list-files";
import { getHFHubCachePath, getRepoFolderName } from "./cache-management";
import { spaceInfo } from "./space-info";
import { datasetInfo } from "./dataset-info";
import { modelInfo } from "./model-info";
import { toRepoId } from "../utils/toRepoId";
import { join, dirname } from "node:path";
import { mkdir, writeFile } from "node:fs/promises";
import { downloadFileToCacheDir } from "./download-file-to-cache-dir";

/**
 * Revision used by snapshotDownload when the caller does not provide one.
 */
export const DEFAULT_REVISION = "main";

/**
 * Download all files of a repository at a given revision into the cache directory.
 *
 * Steps, in order:
 * 1. resolve the revision to a commit hash (`sha`) through spaceInfo / datasetInfo / modelInfo;
 * 2. when the requested revision differs from that hash, write the hash to
 *    `<cacheDir>/<repo folder>/refs/<revision>`;
 * 3. list the repository recursively at that commit; call downloadFileToCacheDir for
 *    every file entry and create a folder under the snapshot folder for every directory entry.
 *
 * Entries are processed one after the other, in listing order.
 *
 * @param params.repo repository to download
 * @param params.cacheDir cache root; falls back to getHFHubCachePath() when omitted or empty
 * @param params.revision branch name, tag or commit hash; falls back to DEFAULT_REVISION when omitted or empty
 * @param params.hubUrl forwarded as-is to the repo-info, listing and download helpers
 * @param params.fetch forwarded as-is to the repo-info, listing and download helpers
 * @returns the snapshot folder path: `<cacheDir>/<repo folder>/snapshots/<commit hash>`
 * @throws Error when the repository type is not "space", "dataset" or "model",
 *         or when a listed entry is neither a "file" nor a "directory"
 */
export async function snapshotDownload(
	params: {
		repo: RepoDesignation;
		cacheDir?: string;
		/**
		 * An optional Git revision id which can be a branch name, a tag, or a commit hash.
		 *
		 * @default "main"
		 */
		revision?: string;
		hubUrl?: string;
		/**
		 * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
		 */
		fetch?: typeof fetch;
	} & Partial<CredentialsParams>
): Promise<string> {
	// Split the options consumed here from the ones that are only passed through
	// (hubUrl, fetch, credentials). Forwarding just `hubOptions` keeps `repo` and
	// `cacheDir` out of helpers that do not take them.
	const { repo, cacheDir: cacheDirParam, revision: revisionParam, ...hubOptions } = params;

	// `||` on purpose: an empty string is handled like an omitted value
	const cacheDir = cacheDirParam || getHFHubCachePath();
	const revision = revisionParam || DEFAULT_REVISION;

	const repoId = toRepoId(repo);

	// get repository revision value (sha)
	let repoInfo: { sha: string };
	switch (repoId.type) {
		case "space":
			repoInfo = await spaceInfo({
				...hubOptions,
				name: repoId.name,
				additionalFields: ["sha"],
				revision,
			});
			break;
		case "dataset":
			repoInfo = await datasetInfo({
				...hubOptions,
				name: repoId.name,
				additionalFields: ["sha"],
				revision,
			});
			break;
		case "model":
			repoInfo = await modelInfo({
				...hubOptions,
				name: repoId.name,
				additionalFields: ["sha"],
				revision,
			});
			break;
		default:
			throw new Error(`invalid repository type ${repoId.type}`);
	}

	const commitHash: string = repoInfo.sha;

	// get storage folder
	const storageFolder = join(cacheDir, getRepoFolderName(repoId));
	const snapshotFolder = join(storageFolder, "snapshots", commitHash);

	// if passed revision is not identical to commit_hash
	// then revision has to be a branch name or tag name.
	// In that case store a ref.
	if (revision !== commitHash) {
		const refPath = join(storageFolder, "refs", revision);
		// a revision may contain "/" so the ref file can live in a nested folder
		await mkdir(dirname(refPath), { recursive: true });
		await writeFile(refPath, commitHash);
	}

	// list at the resolved commit hash so listing and downloads target the same commit
	const cursor = listFiles({
		...hubOptions,
		repo,
		recursive: true,
		revision: commitHash,
	});

	for await (const entry of cursor) {
		switch (entry.type) {
			case "file":
				await downloadFileToCacheDir({
					...hubOptions,
					repo,
					path: entry.path,
					revision: commitHash,
					cacheDir,
				});
				break;
			case "directory":
				await mkdir(join(snapshotFolder, entry.path), { recursive: true });
				break;
			default:
				throw new Error(`unknown entry type: ${entry.type}`);
		}
	}

	return snapshotFolder;
}
Loading
Loading