Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 36 additions & 15 deletions src/utils/hub.js
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,22 @@ function isValidUrl(string, protocols = null, validHosts = null) {
return true;
}

// Shape of a Hugging Face repo ID: an optional `namespace/` prefix followed by a
// repo name of 1 to 96 characters. Both parts may only contain word characters,
// `-` and `.`, and the `\b` anchors force each part to begin and end with a word
// character. At most one `/` can appear, since `/` is not allowed inside either part.
const REPO_ID_REGEX = /^(\b[\w\-.]+\b\/)?\b[\w\-.]{1,96}\b$/;

/**
 * Tests whether a string is a valid Hugging Face model ID or not.
 * Adapted from https://github.com/huggingface/huggingface_hub/blob/6378820ebb03f071988a96c7f3268f5bdf8f9449/src/huggingface_hub/utils/_validators.py#L119-L170
 *
 * A value is accepted only if it:
 *  - is a string,
 *  - matches `REPO_ID_REGEX`,
 *  - does not contain `..` or `--`, and
 *  - does not end with `.git` or `.ipynb`.
 *
 * @param {string} string The string to test
 * @returns {boolean} True if the string is a valid model ID, false otherwise.
 */
function isValidHfModelId(string) {
    // Non-strings can never be model IDs. Without this guard, `RegExp.prototype.test`
    // would coerce the value to a string first (e.g. `null` -> "null", which matches),
    // and the `includes`/`endsWith` calls below would then throw a TypeError.
    if (typeof string !== 'string') return false;
    if (!REPO_ID_REGEX.test(string)) return false;
    if (string.includes("..") || string.includes("--")) return false;
    if (string.endsWith(".git") || string.endsWith(".ipynb")) return false;
    return true;
}

/**
* Helper function to get a file, using either the Fetch API or FileSystem API.
*
Expand Down Expand Up @@ -442,27 +458,28 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti
}

const revision = options.revision ?? 'main';
const requestURL = pathJoin(path_or_repo_id, filename);

let requestURL = pathJoin(path_or_repo_id, filename);
let cachePath = pathJoin(env.localModelPath, requestURL);

let localPath = requestURL;
let remoteURL = pathJoin(
const validModelId = isValidHfModelId(path_or_repo_id);
const localPath = validModelId
? pathJoin(env.localModelPath, requestURL)
: requestURL;
const remoteURL = pathJoin(
env.remoteHost,
env.remotePathTemplate
.replaceAll('{model}', path_or_repo_id)
.replaceAll('{revision}', encodeURIComponent(revision)),
filename
);

// Choose cache key for filesystem cache
// When using the main revision (default), we use the request URL as the cache key.
// If a specific revision is requested, we account for this in the cache key.
let fsCacheKey = revision === 'main' ? requestURL : pathJoin(path_or_repo_id, revision, filename);

/** @type {string} */
let cacheKey;
let proposedCacheKey = cache instanceof FileCache ? fsCacheKey : remoteURL;
const proposedCacheKey = cache instanceof FileCache
// Choose cache key for filesystem cache
// When using the main revision (default), we use the request URL as the cache key.
// If a specific revision is requested, we account for this in the cache key.
? revision === 'main' ? requestURL : pathJoin(path_or_repo_id, revision, filename)
: remoteURL;

// Whether to cache the final response in the end.
let toCacheResponse = false;
Expand All @@ -475,11 +492,10 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti
// 1. We first try to get from cache using the local path. In some environments (like deno),
// non-URL cache keys are not allowed. In these cases, `response` will be undefined.
// 2. If no response is found, we try to get from cache using the remote URL or file system cache.
response = await tryCache(cache, cachePath, proposedCacheKey);
response = await tryCache(cache, localPath, proposedCacheKey);
}

const cacheHit = response !== undefined;

if (response === undefined) {
// Caching not available, or file is not cached, so we perform the request

Expand All @@ -497,9 +513,9 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti
console.warn(`Unable to load from local path "${localPath}": "${e}"`);
}
} else if (options.local_files_only) {
throw new Error(`\`local_files_only=true\`, but attempted to load a remote file from: ${localPath}.`);
throw new Error(`\`local_files_only=true\`, but attempted to load a remote file from: ${requestURL}.`);
} else if (!env.allowRemoteModels) {
throw new Error(`\`env.allowRemoteModels=false\`, but attempted to load a remote file from: ${localPath}.`);
throw new Error(`\`env.allowRemoteModels=false\`, but attempted to load a remote file from: ${requestURL}.`);
}
}

Expand All @@ -519,6 +535,11 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti
return null;
}
}
if (!validModelId) {
// Before making any requests to the remote server, we check if the model ID is valid.
// This prevents unnecessary network requests for invalid model IDs.
throw Error(`Local file missing at "${localPath}" and download aborted due to invalid model ID "${path_or_repo_id}".`);
}

// File not found locally, so we try to download it from the remote server
response = await getFile(remoteURL);
Expand Down
12 changes: 12 additions & 0 deletions tests/utils/hub.test.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { AutoModel, PreTrainedModel } from "../../src/models.js";

import { MAX_TEST_EXECUTION_TIME, DEFAULT_MODEL_OPTIONS } from "../init.js";
import fs from "fs";

// TODO: Set cache folder to a temp directory

Expand Down Expand Up @@ -36,5 +37,16 @@ describe("Hub", () => {
},
MAX_TEST_EXECUTION_TIME,
);

const localPath = "./models/hf-internal-testing/tiny-random-T5ForConditionalGeneration";
(fs.existsSync(localPath) ? it : it.skip)(
"should load a model from a local path",
async () => {
// 4. Ensure we can load a model from a local path
const model = await AutoModel.from_pretrained(localPath, DEFAULT_MODEL_OPTIONS);
expect(model).toBeInstanceOf(PreTrainedModel);
},
MAX_TEST_EXECUTION_TIME,
);
});
});