Skip to content

Commit bb3142b

Browse files
committed
Add warning if downloading weights and HF cache not set
1 parent a7a5deb commit bb3142b

File tree

2 files changed

+53
-0
lines changed

2 files changed

+53
-0
lines changed

vec_inf/client/_slurm_script_generator.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
SLURM_SCRIPT_TEMPLATE,
1616
)
1717
from vec_inf.client._slurm_vars import CONTAINER_MODULE_NAME
18+
from vec_inf.client._utils import check_and_warn_hf_cache
1819

1920

2021
class SlurmScriptGenerator:
@@ -47,6 +48,11 @@ def __init__(self, params: dict[str, Any]):
4748
if self.model_weights_exists
4849
else self.params["model_name"]
4950
)
51+
check_and_warn_hf_cache(
52+
self.model_weights_exists,
53+
self.model_weights_path,
54+
self.params.get("env", {}),
55+
)
5056
self.env_str = self._generate_env_str()
5157

5258
def _generate_env_str(self) -> str:
@@ -253,6 +259,13 @@ def __init__(self, params: dict[str, Any]):
253259
self.params["models"][model_name]["model_source"] = (
254260
model_weights_path_str if model_weights_exists else model_name
255261
)
262+
check_and_warn_hf_cache(
263+
model_weights_exists,
264+
model_weights_path_str,
265+
self.params["models"][model_name].get("env", {}),
266+
model_name,
267+
)
268+
256269

257270
def _write_to_log_dir(self, script_content: list[str], script_name: str) -> Path:
258271
"""Write the generated Slurm script to the log directory.

vec_inf/client/_utils.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,3 +456,43 @@ def check_required_fields(params: dict[str, Any]) -> dict[str, Any]:
456456
f"{arg} is required, please set it in the command arguments or environment variables"
457457
)
458458
return env_overrides
459+
460+
461+
def check_and_warn_hf_cache(
462+
model_weights_exists: bool,
463+
model_weights_path: str,
464+
env_dict: dict[str, str],
465+
model_name: str | None = None,
466+
) -> None:
467+
"""Warn if model weights don't exist and HuggingFace cache directory is not set.
468+
469+
Parameters
470+
----------
471+
model_weights_exists : bool
472+
Whether the model weights exist at the expected path.
473+
model_weights_path : str
474+
The expected path to the model weights.
475+
env_dict : dict[str, str]
476+
Dictionary of environment variables to check (from --env parameter).
477+
model_name : str | None, optional
478+
Optional model name to include in the warning message (for batch mode).
479+
"""
480+
if model_weights_exists:
481+
return
482+
483+
hf_cache_vars = ["HF_HOME", "HF_HUB_CACHE", "HUGGINGFACE_HUB_CACHE"]
484+
hf_cache_set = any(
485+
os.environ.get(var) or env_dict.get(var) for var in hf_cache_vars
486+
)
487+
488+
if not hf_cache_set:
489+
model_prefix = f"Model weights for '{model_name}' " if model_name else "Model weights "
490+
warnings.warn(
491+
f"{model_prefix}not found at '{model_weights_path}' and no "
492+
f"HuggingFace cache directory is set (HF_HOME, HF_HUB_CACHE, or "
493+
f"HUGGINGFACE_HUB_CACHE). The model may be downloaded to your home "
494+
f"directory, which could consume your storage quota. Consider setting "
495+
f"one of these environment variables to a shared cache location.",
496+
UserWarning,
497+
stacklevel=4,
498+
)

0 commit comments

Comments
 (0)