diff --git a/packages/hub/README.md b/packages/hub/README.md index 2510ef7814..0633f06271 100644 --- a/packages/hub/README.md +++ b/packages/hub/README.md @@ -57,6 +57,8 @@ await hub.uploadFiles({ }, // Local file URL pathToFileURL("./pytorch-model.bin"), + // Local folder URL + pathToFileURL("./models"), // Web URL new URL("https://huggingface.co/xlm-roberta-base/resolve/main/tokenizer.json"), // Path + Web URL diff --git a/packages/hub/package.json b/packages/hub/package.json index 6ec1d13b20..15d2e31c8d 100644 --- a/packages/hub/package.json +++ b/packages/hub/package.json @@ -19,6 +19,7 @@ }, "browser": { "./src/utils/sha256-node.ts": false, + "./src/utils/sub-paths.ts": false, "./src/utils/FileBlob.ts": false, "./src/lib/cache-management.ts": false, "./src/lib/download-file-to-cache-dir.ts": false, diff --git a/packages/hub/src/lib/commit.ts b/packages/hub/src/lib/commit.ts index bd623a96cd..a7acb3bcbe 100644 --- a/packages/hub/src/lib/commit.ts +++ b/packages/hub/src/lib/commit.ts @@ -18,10 +18,10 @@ import { promisesQueueStreaming } from "../utils/promisesQueueStreaming"; import { sha256 } from "../utils/sha256"; import { toRepoId } from "../utils/toRepoId"; import { WebBlob } from "../utils/WebBlob"; -import { createBlob } from "../utils/createBlob"; import { eventToGenerator } from "../utils/eventToGenerator"; import { base64FromBytes } from "../utils/base64FromBytes"; import { isFrontend } from "../utils/isFrontend"; +import { createBlobs } from "../utils/createBlobs"; const CONCURRENT_SHAS = 5; const CONCURRENT_LFS_UPLOADS = 5; @@ -73,9 +73,15 @@ export type CommitParams = { /** * Whether to use web workers to compute SHA256 hashes. * - * We load hash-wasm from a CDN inside the web worker. Not sure how to do otherwise and still have a "clean" bundle. + * @default false */ useWebWorkers?: boolean | { minSize?: number; poolSize?: number }; + /** + * Maximum depth of folders to upload. 
Files deeper than this will be ignored + * + * @default 10 + */ + maxFolderDepth?: number; /** * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. */ @@ -144,27 +150,33 @@ export async function* commitIter(params: CommitParams): AsyncGenerator { - if (operation.operation !== "addOrUpdate") { - return operation; - } - - if (!(operation.content instanceof URL)) { - /** TS trick to enforce `content` to be a `Blob` */ - return { ...operation, content: operation.content }; - } - - const lazyBlob = await createBlob(operation.content, { fetch: params.fetch }); + const allOperations = ( + await Promise.all( + params.operations.map(async (operation) => { + if (operation.operation !== "addOrUpdate") { + return operation; + } - abortSignal?.throwIfAborted(); + if (!(operation.content instanceof URL)) { + /** TS trick to enforce `content` to be a `Blob` */ + return { ...operation, content: operation.content }; + } - return { - ...operation, - content: lazyBlob, - }; - }) - ); + const lazyBlobs = await createBlobs(operation.content, operation.path, { + fetch: params.fetch, + maxFolderDepth: params.maxFolderDepth, + }); + + abortSignal?.throwIfAborted(); + + return lazyBlobs.map((blob) => ({ + ...operation, + content: blob.blob, + path: blob.path, + })); + }) + ) + ).flat(1); const gitAttributes = allOperations.filter(isFileOperation).find((op) => op.path === ".gitattributes")?.content; diff --git a/packages/hub/src/lib/create-repo.spec.ts b/packages/hub/src/lib/create-repo.spec.ts index c1a39b9f81..92d4d6b51a 100644 --- a/packages/hub/src/lib/create-repo.spec.ts +++ b/packages/hub/src/lib/create-repo.spec.ts @@ -100,4 +100,4 @@ describe("createRepo", () => { credentials: { accessToken: TEST_ACCESS_TOKEN }, }); }); -}, 10_000); +}); diff --git a/packages/hub/src/lib/delete-files.spec.ts b/packages/hub/src/lib/delete-files.spec.ts index 558da6a6ba..8124d9afa0 100644 --- a/packages/hub/src/lib/delete-files.spec.ts +++ 
b/packages/hub/src/lib/delete-files.spec.ts @@ -78,4 +78,4 @@ describe("deleteFiles", () => { }); } }); -}, 10_000); +}); diff --git a/packages/hub/src/lib/upload-files-with-progress.ts b/packages/hub/src/lib/upload-files-with-progress.ts index e0e4c9d7f9..f0a0af2525 100644 --- a/packages/hub/src/lib/upload-files-with-progress.ts +++ b/packages/hub/src/lib/upload-files-with-progress.ts @@ -28,6 +28,7 @@ export async function* uploadFilesWithProgress( isPullRequest?: CommitParams["isPullRequest"]; parentCommit?: CommitParams["parentCommit"]; abortSignal?: CommitParams["abortSignal"]; + maxFolderDepth?: CommitParams["maxFolderDepth"]; /** * Set this to true in order to have progress events for hashing */ diff --git a/packages/hub/src/lib/upload-files.fs.spec.ts b/packages/hub/src/lib/upload-files.fs.spec.ts new file mode 100644 index 0000000000..415d71fd98 --- /dev/null +++ b/packages/hub/src/lib/upload-files.fs.spec.ts @@ -0,0 +1,71 @@ +import { assert, it, describe } from "vitest"; + +import { TEST_ACCESS_TOKEN, TEST_HUB_URL, TEST_USER } from "../test/consts"; +import type { RepoId } from "../types/public"; +import { insecureRandomString } from "../utils/insecureRandomString"; +import { createRepo } from "./create-repo"; +import { deleteRepo } from "./delete-repo"; +import { downloadFile } from "./download-file"; +import { uploadFiles } from "./upload-files"; +import { mkdir } from "fs/promises"; +import { writeFile } from "fs/promises"; +import { pathToFileURL } from "url"; +import { tmpdir } from "os"; + +describe("uploadFiles", () => { + it("should upload local folder", async () => { + const tmpDir = tmpdir(); + + await mkdir(`${tmpDir}/test-folder/sub`, { recursive: true }); + + await writeFile(`${tmpDir}/test-folder/sub/file1.txt`, "file1"); + await writeFile(`${tmpDir}/test-folder/sub/file2.txt`, "file2"); + + await writeFile(`${tmpDir}/test-folder/file3.txt`, "file3"); + await writeFile(`${tmpDir}/test-folder/file4.txt`, "file4"); + + const repoName = 
`${TEST_USER}/TEST-${insecureRandomString()}`; + const repo = { type: "model", name: repoName } satisfies RepoId; + + try { + const result = await createRepo({ + accessToken: TEST_ACCESS_TOKEN, + repo, + hubUrl: TEST_HUB_URL, + }); + + assert.deepStrictEqual(result, { + repoUrl: `${TEST_HUB_URL}/${repoName}`, + }); + + await uploadFiles({ + accessToken: TEST_ACCESS_TOKEN, + repo, + files: [pathToFileURL(`${tmpDir}/test-folder`)], + hubUrl: TEST_HUB_URL, + }); + + let content = await downloadFile({ + repo, + path: "test-folder/sub/file1.txt", + hubUrl: TEST_HUB_URL, + }); + + assert.strictEqual(await content?.text(), "file1"); + + content = await downloadFile({ + repo, + path: "test-folder/file3.txt", + hubUrl: TEST_HUB_URL, + }); + + assert.strictEqual(await content?.text(), `file3`); + } finally { + await deleteRepo({ + repo, + accessToken: TEST_ACCESS_TOKEN, + hubUrl: TEST_HUB_URL, + }); + } + }); +}); diff --git a/packages/hub/src/lib/upload-files.spec.ts b/packages/hub/src/lib/upload-files.spec.ts index 89206c99c8..94258ad1b2 100644 --- a/packages/hub/src/lib/upload-files.spec.ts +++ b/packages/hub/src/lib/upload-files.spec.ts @@ -92,4 +92,4 @@ describe("uploadFiles", () => { }); } }); -}, 10_000); +}); diff --git a/packages/hub/src/lib/upload-files.ts b/packages/hub/src/lib/upload-files.ts index d205dba975..4eda11015e 100644 --- a/packages/hub/src/lib/upload-files.ts +++ b/packages/hub/src/lib/upload-files.ts @@ -14,6 +14,7 @@ export function uploadFiles( parentCommit?: CommitParams["parentCommit"]; fetch?: CommitParams["fetch"]; useWebWorkers?: CommitParams["useWebWorkers"]; + maxFolderDepth?: CommitParams["maxFolderDepth"]; abortSignal?: CommitParams["abortSignal"]; } & Partial ): Promise { diff --git a/packages/hub/src/utils/createBlobs.ts b/packages/hub/src/utils/createBlobs.ts new file mode 100644 index 0000000000..ebe63ed5d1 --- /dev/null +++ b/packages/hub/src/utils/createBlobs.ts @@ -0,0 +1,48 @@ +import { WebBlob } from "./WebBlob"; +import { 
isFrontend } from "./isFrontend"; + +/** + * This function retrieves FileBlobs or a WebBlob from a URL, returning one `{ path, blob }` entry per file. + * + * From the backend: + * - support local files + * - support local folders + * - support http resources with absolute URLs + * + * From the frontend: + * - support http resources with absolute or relative URLs + */ +export async function createBlobs( + url: URL, + destPath: string, + opts?: { fetch?: typeof fetch; maxFolderDepth?: number } +): Promise> { + if (url.protocol === "http:" || url.protocol === "https:") { + const blob = await WebBlob.create(url, { fetch: opts?.fetch }); + return [{ path: destPath, blob }]; + } + + if (isFrontend) { + throw new TypeError(`Unsupported URL protocol "${url.protocol}"`); + } + + if (url.protocol === "file:") { + const { FileBlob } = await import("./FileBlob"); + const { subPaths } = await import("./sub-paths"); + const paths = await subPaths(url, opts?.maxFolderDepth); + + if (paths.length === 1 && paths[0].relativePath === ".") { + const blob = await FileBlob.create(url); + return [{ path: destPath, blob }]; + } + + return Promise.all( + paths.map(async (path) => ({ + path: `${destPath}/${path.relativePath}`.replace(/\/[.]$/, "").replaceAll("//", "/"), + blob: await FileBlob.create(new URL(path.path)), + })) + ); + } + + throw new TypeError(`Unsupported URL protocol "${url.protocol}"`); +} diff --git a/packages/hub/src/utils/sub-paths.spec.ts b/packages/hub/src/utils/sub-paths.spec.ts new file mode 100644 index 0000000000..6dcb773ab7 --- /dev/null +++ b/packages/hub/src/utils/sub-paths.spec.ts @@ -0,0 +1,39 @@ +import { mkdir, writeFile } from "fs/promises"; +import { tmpdir } from "os"; +import { describe, expect, it } from "vitest"; +import { subPaths } from "./sub-paths"; +import { pathToFileURL } from "url"; + +describe("sub-paths", () => { + it("should retrieve all sub-paths of a directory", async () => { + const tmpDir = tmpdir(); + + await mkdir(`${tmpDir}/test-dir/sub`, 
{ recursive: true }); + 
+ await writeFile(`${tmpDir}/test-dir/sub/file1.txt`, "file1"); + await writeFile(`${tmpDir}/test-dir/sub/file2.txt`, "file2"); + await writeFile(`${tmpDir}/test-dir/file3.txt`, "file3"); + await writeFile(`${tmpDir}/test-dir/file4.txt`, "file4"); + const result = await subPaths(pathToFileURL(`${tmpDir}/test-dir`)); + + expect(result).toEqual([ + { + path: pathToFileURL(`${tmpDir}/test-dir/file3.txt`), + relativePath: "file3.txt", + }, + { + path: pathToFileURL(`${tmpDir}/test-dir/file4.txt`), + relativePath: "file4.txt", + }, + + { + path: pathToFileURL(`${tmpDir}/test-dir/sub/file1.txt`), + relativePath: "sub/file1.txt", + }, + { + path: pathToFileURL(`${tmpDir}/test-dir/sub/file2.txt`), + relativePath: "sub/file2.txt", + }, + ]); + }); +}); diff --git a/packages/hub/src/utils/sub-paths.ts b/packages/hub/src/utils/sub-paths.ts new file mode 100644 index 0000000000..15682c14f8 --- /dev/null +++ b/packages/hub/src/utils/sub-paths.ts @@ -0,0 +1,38 @@ +import { readdir, stat } from "node:fs/promises"; +import { fileURLToPath, pathToFileURL } from "node:url"; + +/** + * Recursively retrieves all file sub-paths of a given directory, descending at most `maxDepth` folder levels below it (deeper folders are skipped). A non-directory path yields a single entry whose relativePath is ".". + */ +export async function subPaths( + path: URL, + maxDepth = 10 +): Promise< + Array<{ + path: URL; + relativePath: string; + }> +> { + const state = await stat(path); + if (!state.isDirectory()) { + return [{ path, relativePath: "." 
}]; + } + + const files = await readdir(path, { withFileTypes: true }); + const ret: Array<{ path: URL; relativePath: string }> = []; + for (const file of files) { + const filePath = pathToFileURL(fileURLToPath(path) + "/" + file.name); + if (file.isDirectory()) { + ret.push( + ...(maxDepth > 0 ? await subPaths(filePath, maxDepth - 1) : []).map((subPath) => ({ + ...subPath, + relativePath: `${file.name}/${subPath.relativePath}`, + })) + ); + } else { + ret.push({ path: filePath, relativePath: file.name }); + } + } + + return ret; +} diff --git a/packages/hub/vitest-browser.config.mts b/packages/hub/vitest-browser.config.mts index e2e1e87f98..ad51e248d3 100644 --- a/packages/hub/vitest-browser.config.mts +++ b/packages/hub/vitest-browser.config.mts @@ -7,9 +7,11 @@ export default defineConfig({ ...configDefaults.exclude, "src/utils/FileBlob.spec.ts", "src/utils/symlink.spec.ts", + "src/utils/sub-paths.spec.ts", "src/lib/cache-management.spec.ts", "src/lib/download-file-to-cache-dir.spec.ts", "src/lib/snapshot-download.spec.ts", + "src/lib/upload-files.fs.spec.ts", // Because we use redirect: "manual" in the test "src/lib/oauth-handle-redirect.spec.ts", ],