Commit 8a5f5f4 (parent: ab62f87)

Move download to shared helper

Signed-off-by: Anish Shanbhag <ashanbhag@nvidia.com>

File tree: 1 file changed, +22 -16 lines

tensorrt_llm/llmapi/llm_utils.py

Lines changed: 22 additions & 16 deletions
@@ -638,6 +638,23 @@ def _submit_to_all_workers(
         else:
             return [task(*args, **kwargs)]
 
+    def _download_hf_model_if_needed(self,
+                                     model_obj: _ModelWrapper,
+                                     revision: Optional[str] = None) -> Path:
+        """Download a model from HF hub if needed.
+
+        Also updates the model_obj.model_dir with the local model dir on rank 0.
+        """
+        if model_obj.is_hub_model:
+            model_dirs = self._submit_to_all_workers(
+                CachedModelLoader._node_download_hf_model,
+                model=model_obj.model_name,
+                revision=revision)
+            model_dir = model_dirs[0]
+            model_obj.model_dir = model_dir
+            return model_dir
+        return model_obj.model_dir
+
     def __call__(self) -> Tuple[Path, Union[Path, None]]:
 
         if self.llm_args.model_format is _ModelFormatKind.TLLM_ENGINE:
@@ -648,14 +665,9 @@ def __call__(self) -> Tuple[Path, Union[Path, None]]:
         self.model_loader = ModelLoader(self.llm_args)
 
         # Download speculative model from HuggingFace if needed
-        if (self.model_loader.speculative_model_obj is not None
-                and self.model_loader.speculative_model_obj.is_hub_model):
-            spec_model_dirs = self._submit_to_all_workers(
-                CachedModelLoader._node_download_hf_model,
-                model=self.model_loader.speculative_model_obj.model_name,
-                revision=None)
-            spec_model_dir = spec_model_dirs[0]
-            self.model_loader.speculative_model_obj.model_dir = spec_model_dir
+        if self.model_loader.speculative_model_obj is not None:
+            spec_model_dir = self._download_hf_model_if_needed(
+                self.model_loader.speculative_model_obj)
             # Update llm_args so PyTorch/AutoDeploy executor gets the local path
             if self.llm_args.speculative_config is not None:
                 self.llm_args.speculative_config.speculative_model = spec_model_dir
@@ -668,14 +680,8 @@ def __call__(self) -> Tuple[Path, Union[Path, None]]:
             raise ValueError(
                 f'backend {self.llm_args.backend} is not supported.')
 
-        if self.model_loader.model_obj.is_hub_model:
-            hf_model_dirs = self._submit_to_all_workers(
-                CachedModelLoader._node_download_hf_model,
-                model=self.model_loader.model_obj.model_name,
-                revision=self.llm_args.revision)
-            self._hf_model_dir = hf_model_dirs[0]
-        else:
-            self._hf_model_dir = self.model_loader.model_obj.model_dir
+        self._hf_model_dir = self._download_hf_model_if_needed(
+            self.model_loader.model_obj, revision=self.llm_args.revision)
 
         if self.llm_args.quant_config.quant_algo is not None:
             logger.warning(
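
For readers skimming the change, the pattern being extracted is: download from the HF hub only when the model reference is a hub ID, record the resolved local directory back on the wrapper so later consumers see a local path, and pass existing local paths through untouched. Below is a minimal standalone sketch of that shape; ModelRef and download_from_hub are illustrative stand-ins, not TensorRT-LLM APIs.

from dataclasses import dataclass
from pathlib import Path
from typing import Optional


@dataclass
class ModelRef:
    # Illustrative stand-in for _ModelWrapper: a hub ID plus an optional local dir.
    name: str
    model_dir: Optional[Path] = None

    @property
    def is_hub_model(self) -> bool:
        # Treat a reference with no local directory as a hub model ID.
        return self.model_dir is None


def download_from_hub(name: str, revision: Optional[str] = None) -> Path:
    # Illustrative stand-in for the per-node HF download; returns a cache dir.
    suffix = f"@{revision}" if revision else ""
    return Path("/tmp/hf-cache") / (name.replace("/", "--") + suffix)


def download_if_needed(ref: ModelRef, revision: Optional[str] = None) -> Path:
    # Same shape as the new helper: download only for hub models and record
    # the local dir on the wrapper; local paths pass through unchanged.
    if ref.is_hub_model:
        ref.model_dir = download_from_hub(ref.name, revision=revision)
    return ref.model_dir


# Both call sites in the diff collapse to a single call each:
spec_dir = download_if_needed(ModelRef("org/spec-model"))
main_dir = download_if_needed(ModelRef("org/main-model"), revision="main")
print(spec_dir, main_dir)

Folding the is_hub_model check into the helper is what lets the speculative-model call site drop its compound condition and the main-model call site drop its else branch.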
