/**
 * Pattern for a Hugging Face repo ID: an optional `namespace/` prefix followed by a
 * repo name of 1 to 96 characters drawn from word characters, `-` and `.`.
 * The `\b` anchors force the namespace and the repo name to each start and end
 * on a word character (`[A-Za-z0-9_]`).
 */
const REPO_ID_REGEX = /^(\b[\w\-.]+\b\/)?\b[\w\-.]{1,96}\b$/;

/**
 * Tests whether a string is a valid Hugging Face model ID or not.
 * Adapted from https://github.com/huggingface/huggingface_hub/blob/6378820ebb03f071988a96c7f3268f5bdf8f9449/src/huggingface_hub/utils/_validators.py#L119-L170
 *
 * Non-string values are never valid model IDs and yield `false` rather than throwing.
 *
 * @param {string} string The string to test
 * @returns {boolean} True if the string is a valid model ID, false otherwise.
 */
function isValidHfModelId(string) {
    // `RegExp.prototype.test` coerces its argument to a string, so values such as
    // `undefined`, `null` or `123` would satisfy the pattern and then crash on the
    // string-method calls below. Reject them up front.
    if (typeof string !== "string") return false;

    // Overall shape: `[namespace/]name` with the allowed character set and length.
    if (!REPO_ID_REGEX.test(string)) return false;

    // The character class permits `.` and `-`, but doubled forms are rejected.
    if (string.includes("..") || string.includes("--")) return false;

    // IDs ending in these file-like suffixes are rejected.
    if (string.endsWith(".git") || string.endsWith(".ipynb")) return false;

    return true;
}
- let fsCacheKey = revision === 'main' ? requestURL : pathJoin(path_or_repo_id, revision, filename); - /** @type {string} */ let cacheKey; - let proposedCacheKey = cache instanceof FileCache ? fsCacheKey : remoteURL; + const proposedCacheKey = cache instanceof FileCache + // Choose cache key for filesystem cache + // When using the main revision (default), we use the request URL as the cache key. + // If a specific revision is requested, we account for this in the cache key. + ? revision === 'main' ? requestURL : pathJoin(path_or_repo_id, revision, filename) + : remoteURL; // Whether to cache the final response in the end. let toCacheResponse = false; @@ -475,11 +492,10 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti // 1. We first try to get from cache using the local path. In some environments (like deno), // non-URL cache keys are not allowed. In these cases, `response` will be undefined. // 2. If no response is found, we try to get from cache using the remote URL or file system cache. 
- response = await tryCache(cache, cachePath, proposedCacheKey); + response = await tryCache(cache, localPath, proposedCacheKey); } const cacheHit = response !== undefined; - if (response === undefined) { // Caching not available, or file is not cached, so we perform the request @@ -497,9 +513,9 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti console.warn(`Unable to load from local path "${localPath}": "${e}"`); } } else if (options.local_files_only) { - throw new Error(`\`local_files_only=true\`, but attempted to load a remote file from: ${localPath}.`); + throw new Error(`\`local_files_only=true\`, but attempted to load a remote file from: ${requestURL}.`); } else if (!env.allowRemoteModels) { - throw new Error(`\`env.allowRemoteModels=false\`, but attempted to load a remote file from: ${localPath}.`); + throw new Error(`\`env.allowRemoteModels=false\`, but attempted to load a remote file from: ${requestURL}.`); } } @@ -519,6 +535,11 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti return null; } } + if (!validModelId) { + // Before making any requests to the remote server, we check if the model ID is valid. + // This prevents unnecessary network requests for invalid model IDs. 
+ throw Error(`Local file missing at "${localPath}" and download aborted due to invalid model ID "${path_or_repo_id}".`); + } // File not found locally, so we try to download it from the remote server response = await getFile(remoteURL); diff --git a/tests/utils/hub.test.js b/tests/utils/hub.test.js index 3ef3f41f7..f819efc35 100644 --- a/tests/utils/hub.test.js +++ b/tests/utils/hub.test.js @@ -1,6 +1,7 @@ import { AutoModel, PreTrainedModel } from "../../src/models.js"; import { MAX_TEST_EXECUTION_TIME, DEFAULT_MODEL_OPTIONS } from "../init.js"; +import fs from "fs"; // TODO: Set cache folder to a temp directory @@ -36,5 +37,16 @@ describe("Hub", () => { }, MAX_TEST_EXECUTION_TIME, ); + + const localPath = "./models/hf-internal-testing/tiny-random-T5ForConditionalGeneration"; + (fs.existsSync(localPath) ? it : it.skip)( + "should load a model from a local path", + async () => { + // 4. Ensure we can load a model from a local path + const model = await AutoModel.from_pretrained(localPath, DEFAULT_MODEL_OPTIONS); + expect(model).toBeInstanceOf(PreTrainedModel); + }, + MAX_TEST_EXECUTION_TIME, + ); }); });