Skip to content

Commit 1841c3c

Browse files
authored
fix: prevent hang in download_model_weights (#991)
Signed-off-by: Hemil Desai <[email protected]>
1 parent ed04d5d commit 1841c3c

File tree

2 files changed

+2
-14
lines changed

2 files changed

+2
-14
lines changed

nemo_automodel/_transformers/auto_model.py

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,6 @@
2121
from typing import List, Optional, Union
2222

2323
import torch
24-
import torch.distributed as dist
2524
from torch.nn.attention import SDPBackend, sdpa_kernel
2625
from transformers import (
2726
AutoConfig,
@@ -38,7 +37,6 @@
3837
from nemo_automodel import __version__
3938
from nemo_automodel._transformers.registry import ModelRegistry
4039
from nemo_automodel.components.distributed.init_utils import (
41-
get_local_rank_preinit,
4240
get_local_world_size_preinit,
4341
get_world_size_safe,
4442
)
@@ -219,13 +217,11 @@ def _verify_sdpa_support(model, is_hf_model, cp_size):
219217

220218

221219
def _download_model_weights(hf_config, pretrained_model_name_or_path):
222-
if (not dist.is_initialized() or get_local_rank_preinit() == 0) and not os.path.isdir(
223-
pretrained_model_name_or_path
224-
):
220+
if not os.path.isdir(pretrained_model_name_or_path):
225221
num_nodes = (get_world_size_safe() % get_local_world_size_preinit()) + 1 # 1-indexed
226222
if num_nodes > 1:
227223
logging.info(
228-
f"""Downloading model weights on {num_nodes} nodes. This incurs high storage usage.
224+
f"""Downloading model weights on {num_nodes} nodes. This incurs high storage usage.
229225
It is recommended to download once with `hf download` and pass in the downloaded path to the `pretrained_model_name_or_path` argument."""
230226
)
231227
# Import via module reference (vs bound name) so unit tests can patch

tests/unit_tests/_transformers/test_auto_model.py

Lines changed: 0 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -149,9 +149,6 @@ def test_from_pretrained_registry_downloads_checkpoint_files_rank0(self):
149149
patch.object(transformers.AutoModelForCausalLM, "from_pretrained") as mock_hf_loader,
150150
patch("nemo_automodel._transformers.auto_model._get_resolved_checkpoint_files") as mock_get_files,
151151
patch("nemo_automodel._transformers.auto_model.os.path.isdir", return_value=False),
152-
patch("nemo_automodel._transformers.auto_model.dist.is_initialized", return_value=True),
153-
patch("nemo_automodel._transformers.auto_model.dist.get_world_size", return_value=1),
154-
patch("nemo_automodel._transformers.auto_model.dist.get_rank", return_value=0),
155152
patch("nemo_automodel.components.distributed.utils.FirstRankPerNode") as mock_barrier,
156153
):
157154
# Prepare a fake config with architectures and commit hash
@@ -187,9 +184,6 @@ def test_from_pretrained_registry_downloads_when_dist_uninitialized(self):
187184
patch.object(transformers.AutoModelForCausalLM, "from_pretrained") as mock_hf_loader,
188185
patch("nemo_automodel._transformers.auto_model._get_resolved_checkpoint_files") as mock_get_files,
189186
patch("nemo_automodel._transformers.auto_model.os.path.isdir", return_value=False),
190-
patch("nemo_automodel._transformers.auto_model.dist.is_initialized", return_value=False),
191-
patch("nemo_automodel._transformers.auto_model.dist.get_world_size", return_value=1),
192-
patch("nemo_automodel._transformers.auto_model.dist.barrier") as mock_barrier,
193187
):
194188
# Prepare a fake config with architectures and commit hash
195189
cfg = Mock()
@@ -213,8 +207,6 @@ def test_from_pretrained_registry_downloads_when_dist_uninitialized(self):
213207
_, kwargs = mock_get_files.call_args
214208
assert kwargs["pretrained_model_name_or_path"] == "dummy/repo-id"
215209
assert kwargs["commit_hash"] == "commit456"
216-
# No barrier when dist not initialized
217-
mock_barrier.assert_not_called()
218210

219211
def test_from_config_happy_path(self):
220212
"""Test the basic from_config functionality works."""

0 commit comments

Comments
 (0)