Skip to content

Commit 57df98e

Browse files
committed
refactor: moving to download-file-cache.ts file
1 parent f2f2972 commit 57df98e

File tree

5 files changed

+367
-355
lines changed

5 files changed

+367
-355
lines changed
Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
import { expect, test, describe, vi, beforeEach } from "vitest";
2+
import type { RepoDesignation, RepoId } from "../types/public";
3+
import { dirname, join } from "node:path";
4+
import { lstat, mkdir, stat, symlink, writeFile, rename } from "node:fs/promises";
5+
import { pathsInfo } from "./paths-info";
6+
import type { Stats } from "node:fs";
7+
import { getHFHubCache, getRepoFolderName } from "./cache-management";
8+
import { toRepoId } from "../utils/toRepoId";
9+
import { downloadFileToCacheDir } from "./download-file-cache";
10+
11+
// Swap every `node:fs/promises` function imported by this test file for a spy,
// so each test can script and inspect the file-system calls.
vi.mock("node:fs/promises", () => {
	return {
		writeFile: vi.fn(),
		rename: vi.fn(),
		symlink: vi.fn(),
		lstat: vi.fn(),
		mkdir: vi.fn(),
		stat: vi.fn(),
	};
});

// Stub the path lookup; each test decides what `pathsInfo` resolves to.
vi.mock("./paths-info", () => {
	return { pathsInfo: vi.fn() };
});
23+
24+
// Repository designation shared by every test case.
const DUMMY_REPO: RepoId = { name: "hello-world", type: "model" };

// Value used as the file `oid` in mocked `pathsInfo` results, hence as the blob file name.
const DUMMY_ETAG = "dummy-etag";
30+
31+
/**
 * Test helper: expected location of a blob inside the cache,
 * i.e. `<cacheDir>/<repo folder>/blobs/<etag>`.
 *
 * `cacheDir` falls back to {@link getHFHubCache} when omitted.
 */
function _getBlobFile(params: { repo: RepoDesignation; etag: string; cacheDir?: string }) {
	const cacheRoot = params.cacheDir ?? getHFHubCache();
	const repoFolder = getRepoFolderName(toRepoId(params.repo));
	return join(cacheRoot, repoFolder, "blobs", params.etag);
}
39+
40+
/**
 * Test helper: expected location of a snapshot pointer inside the cache,
 * i.e. `<cacheDir>/<repo folder>/snapshots/<revision>/<path>`.
 *
 * `cacheDir` falls back to {@link getHFHubCache} when omitted.
 */
function _getSnapshotFile(params: { repo: RepoDesignation; path: string; revision: string; cacheDir?: string }) {
	const cacheRoot = params.cacheDir ?? getHFHubCache();
	const repoFolder = getRepoFolderName(toRepoId(params.repo));
	return join(cacheRoot, repoFolder, "snapshots", params.revision, params.path);
}
49+
50+
describe("downloadFileToCacheDir", () => {
	const fetchMock: typeof fetch = vi.fn();

	/**
	 * Make `pathsInfo` resolve to a single non-LFS `README.md` entry whose oid is
	 * `DUMMY_ETAG` and whose last commit id is `dummy-commit-hash`.
	 */
	function mockReadmePathsInfo(): void {
		vi.mocked(pathsInfo).mockResolvedValue([
			{
				oid: DUMMY_ETAG,
				size: 55,
				path: "README.md",
				type: "file",
				lastCommit: {
					date: new Date(),
					id: "dummy-commit-hash",
					title: "Commit msg",
				},
			},
		]);
	}

	beforeEach(() => {
		vi.resetAllMocks();
		// mock a 200 response for any download
		vi.mocked(fetchMock).mockResolvedValue({
			status: 200,
			ok: true,
			body: "dummy-body",
		} as unknown as Response);

		// by default nothing exists on disk, so no cache shortcut is taken
		vi.mocked(stat).mockRejectedValue(new Error("Do not exists"));
		vi.mocked(lstat).mockRejectedValue(new Error("Do not exists"));
	});

	test("should throw an error if fileDownloadInfo return nothing", async () => {
		// `pathsInfo` has no mocked value after the reset, so it resolves to undefined
		await expect(async () => {
			await downloadFileToCacheDir({
				repo: DUMMY_REPO,
				path: "/README.md",
				fetch: fetchMock,
			});
		}).rejects.toThrowError("cannot get path info for /README.md");

		expect(pathsInfo).toHaveBeenCalledWith(
			expect.objectContaining({
				repo: DUMMY_REPO,
				paths: ["/README.md"],
				fetch: fetchMock,
			})
		);
	});

	test("existing symlinked and blob should not re-download it", async () => {
		// <cache>/<repo>/snapshots/<revision>/README.md
		const expectPointer = _getSnapshotFile({
			repo: DUMMY_REPO,
			path: "/README.md",
			revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7",
		});
		// stat resolving means both the symlink and the file it points to exist
		vi.mocked(stat).mockResolvedValue({} as Stats); // override the default mocked reject

		const output = await downloadFileToCacheDir({
			repo: DUMMY_REPO,
			path: "/README.md",
			fetch: fetchMock,
			revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7",
		});

		expect(stat).toHaveBeenCalledOnce();
		// first argument of the single stat call
		const statArg = vi.mocked(stat).mock.calls[0][0];

		expect(statArg).toBe(expectPointer);
		// no download at all: `not.toHaveBeenCalledWith()` would only exclude a zero-argument call
		expect(fetchMock).not.toHaveBeenCalled();

		expect(output).toBe(expectPointer);
	});

	test("existing blob should only create the symlink", async () => {
		// <cache>/<repo>/snapshots/<revision>/README.md
		const expectPointer = _getSnapshotFile({
			repo: DUMMY_REPO,
			path: "/README.md",
			revision: "dummy-commit-hash",
		});
		// <cache>/<repo>/blobs/<etag>
		const expectedBlob = _getBlobFile({
			repo: DUMMY_REPO,
			etag: DUMMY_ETAG,
		});

		// the blob exists (lstat resolves) but the symlink does not (stat keeps rejecting)
		vi.mocked(lstat).mockResolvedValue({} as Stats);
		mockReadmePathsInfo();

		const output = await downloadFileToCacheDir({
			repo: DUMMY_REPO,
			path: "/README.md",
			fetch: fetchMock,
		});

		expect(stat).not.toHaveBeenCalled();
		// should have checked for the blob
		expect(lstat).toHaveBeenCalled();
		expect(vi.mocked(lstat).mock.calls[0][0]).toBe(expectedBlob);

		// symlink should have been created
		expect(symlink).toHaveBeenCalledOnce();
		// no download done
		expect(fetchMock).not.toHaveBeenCalled();

		expect(output).toBe(expectPointer);
	});

	test("expect resolve value to be the pointer path of downloaded file", async () => {
		// <cache>/<repo>/snapshots/<revision>/README.md
		const expectPointer = _getSnapshotFile({
			repo: DUMMY_REPO,
			path: "/README.md",
			revision: "dummy-commit-hash",
		});
		// <cache>/<repo>/blobs/<etag>
		const expectedBlob = _getBlobFile({
			repo: DUMMY_REPO,
			etag: DUMMY_ETAG,
		});

		mockReadmePathsInfo();

		const output = await downloadFileToCacheDir({
			repo: DUMMY_REPO,
			path: "/README.md",
			fetch: fetchMock,
		});

		// expect the blobs and snapshots folders to have been created, in that order
		expect(vi.mocked(mkdir).mock.calls[0][0]).toBe(dirname(expectedBlob));
		expect(vi.mocked(mkdir).mock.calls[1][0]).toBe(dirname(expectPointer));

		expect(output).toBe(expectPointer);
	});

	test("should write fetch response to blob", async () => {
		// <cache>/<repo>/snapshots/<revision>/README.md
		const expectPointer = _getSnapshotFile({
			repo: DUMMY_REPO,
			path: "/README.md",
			revision: "dummy-commit-hash",
		});
		// <cache>/<repo>/blobs/<etag>
		const expectedBlob = _getBlobFile({
			repo: DUMMY_REPO,
			etag: DUMMY_ETAG,
		});

		mockReadmePathsInfo();

		await downloadFileToCacheDir({
			repo: DUMMY_REPO,
			path: "/README.md",
			fetch: fetchMock,
		});

		const incomplete = `${expectedBlob}.incomplete`;
		// 1. should write fetch#response#body to incomplete file
		expect(writeFile).toHaveBeenCalledWith(incomplete, "dummy-body");
		// 2. should rename the incomplete to the blob expected name
		expect(rename).toHaveBeenCalledWith(incomplete, expectedBlob);
		// 3. should create symlink pointing to blob
		expect(symlink).toHaveBeenCalledWith(expectedBlob, expectPointer);
	});
});
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import { getHFHubCache, getRepoFolderName } from "./cache-management";
2+
import { dirname, join } from "node:path";
3+
import { writeFile, rename, symlink, lstat, mkdir, stat } from "node:fs/promises";
4+
import type { CommitInfo, PathInfo } from "./paths-info";
5+
import { pathsInfo } from "./paths-info";
6+
import type { CredentialsParams, RepoDesignation } from "../types/public";
7+
import { toRepoId } from "../utils/toRepoId";
8+
import { downloadFile } from "./download-file";
9+
10+
/** Matches a string made of exactly 40 lowercase hexadecimal characters. */
export const REGEX_COMMIT_HASH: RegExp = /^[0-9a-f]{40}$/;
11+
12+
/**
 * Build the path of a file pointer inside the snapshots tree:
 * `<storageFolder>/snapshots/<revision>/<relativeFilename>`.
 */
function getFilePointer(storageFolder: string, revision: string, relativeFilename: string): string {
	return join(storageFolder, "snapshots", revision, relativeFilename);
}
16+
17+
/**
 * Tell whether something exists at `path`.
 *
 * Any error raised by the underlying call is reported as "does not exist".
 *
 * @param path location to probe
 * @param followSymlinks when truthy, probe with `stat` (resolves symlinks to their target);
 *   otherwise probe with `lstat` (looks at the entry itself)
 * @returns `true` when the probe succeeds, `false` when it throws
 */
async function exists(path: string, followSymlinks?: boolean): Promise<boolean> {
	try {
		await (followSymlinks ? stat(path) : lstat(path));
		return true;
	} catch {
		return false;
	}
}
34+
35+
/**
 * Create the `pointerPath` -> `blobPath` symlink, treating an already existing
 * entry at `pointerPath` as success.
 */
async function symlinkIgnoringExisting(blobPath: string, pointerPath: string): Promise<void> {
	try {
		await symlink(blobPath, pointerPath);
	} catch (err: unknown) {
		// EEXIST: an earlier call already created this pointer, nothing left to do.
		if ((err as { code?: unknown } | null | undefined)?.code === "EEXIST") return;
		throw err;
	}
}

/**
 * Download a given file if it's not already present in the local cache.
 *
 * The content is written to `<cache>/<repo folder>/blobs/<etag>` (through a temporary
 * `.incomplete` file) and a symlink to it is created at
 * `<cache>/<repo folder>/snapshots/<commit>/<path>`.
 *
 * @param params download options, see the inline documentation of each field
 * @returns the path of the symlink (in the snapshots folder) pointing to the blob object
 * @throws Error `cannot get path info for <path>` when the path lookup does not return exactly one entry
 * @throws Error `invalid response for file <path>` when the download yields no usable response body
 */
export async function downloadFileToCacheDir(
	params: {
		repo: RepoDesignation;
		path: string;
		/**
		 * If true, will download the raw git file.
		 *
		 * For example, when calling on a file stored with Git LFS, the pointer file will be downloaded instead.
		 */
		raw?: boolean;
		/**
		 * An optional Git revision id which can be a branch name, a tag, or a commit hash.
		 *
		 * @default "main"
		 */
		revision?: string;
		hubUrl?: string;
		cacheDir?: string;
		/**
		 * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
		 */
		fetch?: typeof fetch;
	} & Partial<CredentialsParams>
): Promise<string> {
	// get revision provided or default to main
	const revision = params.revision ?? "main";
	const cacheDir = params.cacheDir ?? getHFHubCache();
	// get repo id
	const repoId = toRepoId(params.repo);
	// get storage folder
	const storageFolder = join(cacheDir, getRepoFolderName(repoId));

	let commitHash: string | undefined;

	// if user provides a commitHash as revision, and they already have the file on disk, shortcut everything.
	if (REGEX_COMMIT_HASH.test(revision)) {
		commitHash = revision;
		const pointerPath = getFilePointer(storageFolder, revision, params.path);
		if (await exists(pointerPath, true)) return pointerPath;
	}

	const pathsInformation: (PathInfo & { lastCommit: CommitInfo })[] = await pathsInfo({
		...params,
		paths: [params.path],
		revision: revision,
		expand: true,
	});
	if (!pathsInformation || pathsInformation.length !== 1) throw new Error(`cannot get path info for ${params.path}`);

	const fileInfo = pathsInformation[0];
	// LFS files are keyed by the oid of the pointed file, regular files by their own oid
	const etag: string = fileInfo.lfs ? fileInfo.lfs.oid : fileInfo.oid;

	const pointerPath = getFilePointer(storageFolder, commitHash ?? fileInfo.lastCommit.id, params.path);
	const blobPath = join(storageFolder, "blobs", etag);

	// mkdir blob and pointer path parent directory
	await mkdir(dirname(blobPath), { recursive: true });
	await mkdir(dirname(pointerPath), { recursive: true });

	// We might already have the blob but not the pointer
	// shortcut the download if needed
	if (await exists(blobPath)) {
		// create symlinks in snapshot folder to blob object
		await symlinkIgnoringExisting(blobPath, pointerPath);
		return pointerPath;
	}

	const incomplete = `${blobPath}.incomplete`;
	console.debug(`Downloading ${params.path} to ${incomplete}`);

	const response: Response | null = await downloadFile({
		...params,
		// Download from the revision the metadata above was resolved for. Passing only `commitHash`
		// would hand `undefined` to downloadFile whenever a branch or tag name was requested.
		revision: commitHash ?? revision,
	});

	if (!response || !response.ok || !response.body) throw new Error(`invalid response for file ${params.path}`);

	// @ts-expect-error resp.body is a Stream, but Stream in internal to node
	await writeFile(incomplete, response.body);

	// rename .incomplete file to expect blob
	await rename(incomplete, blobPath);
	// create symlinks in snapshot folder to blob object
	await symlinkIgnoringExisting(blobPath, pointerPath);
	return pointerPath;
}

0 commit comments

Comments
 (0)