@@ -638,6 +638,23 @@ def _submit_to_all_workers(
638638 else :
639639 return [task (* args , ** kwargs )]
640640
def _download_hf_model_if_needed(self,
                                 model_obj: _ModelWrapper,
                                 revision: Optional[str] = None) -> Path:
    """Return the model directory, fetching the model from the HF hub first when required.

    Args:
        model_obj: Model wrapper to resolve. Its ``is_hub_model`` flag
            decides whether a download is triggered, and its
            ``model_name`` is what gets requested.
        revision: Optional hub revision forwarded to the download task.

    Returns:
        ``model_obj.model_dir`` as-is when the model is not a hub model.
        Otherwise the first entry of the results collected from
        ``_submit_to_all_workers``; that same value is also written back
        to ``model_obj.model_dir`` before returning.
    """
    # Guard clause: non-hub models are returned with whatever directory
    # the wrapper already carries; nothing is submitted to the workers.
    if not model_obj.is_hub_model:
        return model_obj.model_dir

    # Dispatch the download task through the worker-submission helper and
    # keep only the first result as the canonical local directory.
    downloaded_dirs = self._submit_to_all_workers(
        CachedModelLoader._node_download_hf_model,
        model=model_obj.model_name,
        revision=revision)
    local_dir = downloaded_dirs[0]

    # Side effect: callers reading model_obj.model_dir afterwards see the
    # local directory rather than the hub identifier.
    model_obj.model_dir = local_dir
    return local_dir
657+
641658 def __call__ (self ) -> Tuple [Path , Union [Path , None ]]:
642659
643660 if self .llm_args .model_format is _ModelFormatKind .TLLM_ENGINE :
@@ -648,14 +665,9 @@ def __call__(self) -> Tuple[Path, Union[Path, None]]:
648665 self .model_loader = ModelLoader (self .llm_args )
649666
650667 # Download speculative model from HuggingFace if needed
651- if (self .model_loader .speculative_model_obj is not None
652- and self .model_loader .speculative_model_obj .is_hub_model ):
653- spec_model_dirs = self ._submit_to_all_workers (
654- CachedModelLoader ._node_download_hf_model ,
655- model = self .model_loader .speculative_model_obj .model_name ,
656- revision = None )
657- spec_model_dir = spec_model_dirs [0 ]
658- self .model_loader .speculative_model_obj .model_dir = spec_model_dir
668+ if self .model_loader .speculative_model_obj is not None :
669+ spec_model_dir = self ._download_hf_model_if_needed (
670+ self .model_loader .speculative_model_obj )
659671 # Update llm_args so PyTorch/AutoDeploy executor gets the local path
660672 if self .llm_args .speculative_config is not None :
661673 self .llm_args .speculative_config .speculative_model = spec_model_dir
@@ -668,14 +680,8 @@ def __call__(self) -> Tuple[Path, Union[Path, None]]:
668680 raise ValueError (
669681 f'backend { self .llm_args .backend } is not supported.' )
670682
671- if self .model_loader .model_obj .is_hub_model :
672- hf_model_dirs = self ._submit_to_all_workers (
673- CachedModelLoader ._node_download_hf_model ,
674- model = self .model_loader .model_obj .model_name ,
675- revision = self .llm_args .revision )
676- self ._hf_model_dir = hf_model_dirs [0 ]
677- else :
678- self ._hf_model_dir = self .model_loader .model_obj .model_dir
683+ self ._hf_model_dir = self ._download_hf_model_if_needed (
684+ self .model_loader .model_obj , revision = self .llm_args .revision )
679685
680686 if self .llm_args .quant_config .quant_algo is not None :
681687 logger .warning (
0 commit comments