Skip to content

Commit 7994035

Browse files
committed
fix
Signed-off-by: Alexandros Koumparoulis <[email protected]>
1 parent d22ea6a commit 7994035

File tree

1 file changed

+33
-12
lines changed

1 file changed

+33
-12
lines changed

nemo_automodel/_transformers/auto_model.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,14 @@
2424
import torch.distributed as dist
2525
from torch.nn.attention import SDPBackend, sdpa_kernel
2626

27+
import nemo_automodel.components.distributed.utils as dist_utils
2728
from nemo_automodel import __version__
2829
from nemo_automodel._transformers.registry import ModelRegistry
2930
from nemo_automodel.components.distributed.init_utils import (
3031
get_local_rank_preinit,
3132
get_local_world_size_preinit,
3233
get_world_size_safe,
3334
)
34-
from nemo_automodel.components.distributed.utils import FirstRankPerNode
3535
from nemo_automodel.components.utils.model_utils import resolve_trust_remote_code
3636
from nemo_automodel.shared.import_utils import safe_import
3737
from nemo_automodel.shared.utils import dtype_from_str
@@ -227,7 +227,9 @@ def _download_model_weights(hf_config, pretrained_model_name_or_path):
227227
f"""Downloading model weights on {num_nodes} nodes. This incurs high storage usage.
228228
It is recommended to download once with `hf download` and pass in the downloaded path to the `pretrained_model_name_or_path` argument."""
229229
)
230-
with FirstRankPerNode():
230+
# Import via module reference (vs bound name) so unit tests can patch
231+
# `nemo_automodel.components.distributed.utils.FirstRankPerNode`.
232+
with dist_utils.FirstRankPerNode():
231233
_get_resolved_checkpoint_files(
232234
pretrained_model_name_or_path=pretrained_model_name_or_path,
233235
subfolder="",
@@ -248,6 +250,14 @@ def _download_model_weights(hf_config, pretrained_model_name_or_path):
248250
transformers_explicit_filename=None,
249251
)
250252

253+
def get_architectures(hf_config):
    """
    Return the list of architecture names declared on an HF config.

    Yields an empty list when the config has no ``architectures``
    attribute or when that attribute is falsy (``None`` or ``[]``).
    """
    # Guard clause: configs without the attribute contribute no architectures.
    if not hasattr(hf_config, "architectures"):
        return []
    # `architectures` may legitimately be None on some configs; normalize to [].
    return hf_config.architectures or []
251261

252262
class _BaseNeMoAutoModelClass(_BaseAutoModelClass):
253263
"""
@@ -277,6 +287,10 @@ def _from_pretrained_parent_class(cls, *args, **kwargs):
277287
if name.startswith("NeMo"):
278288
cls.__name__ = name[4:]
279289
model = super().from_pretrained(*args, **kwargs)
290+
# Some HF entrypoints (or tests/mocks) may return (model, unused_kwargs).
291+
# Our NeMo wrappers always expect a model instance.
292+
if isinstance(model, tuple) and len(model) == 2:
293+
model, _ = model
280294
cls.__name__ = name
281295
return model
282296

@@ -286,6 +300,10 @@ def _from_config_parent_class(cls, *args, **kwargs):
286300
if name.startswith("NeMo"):
287301
cls.__name__ = name[4:]
288302
model = super().from_config(*args, **kwargs)
303+
# Some HF entrypoints (or tests/mocks) may return (model, unused_kwargs).
304+
# Our NeMo wrappers always expect a model instance.
305+
if isinstance(model, tuple) and len(model) == 2:
306+
model, _ = model
289307
cls.__name__ = name
290308
return model
291309

@@ -377,30 +395,32 @@ def _retry(**override):
377395

378396
# 1. if force_hf is True, we will use the parent class to load and return the model as is
379397
if force_hf:
380-
return _BaseNeMoAutoModelClass._from_pretrained_parent_class(
398+
return cls._from_pretrained_parent_class(
381399
pretrained_model_name_or_path,
382400
*model_args,
401+
config=hf_config,
383402
torch_dtype=torch_dtype,
384403
attn_implementation=attn_implementation,
385404
**kwargs,
386405
)
387-
406+
architectures = get_architectures(hf_config)
388407
# 2. If we have a custom model implementation available, we prioritize that over HF
389-
if hf_config.architectures[0] in ModelRegistry.model_arch_name_to_cls:
408+
if len(architectures) > 0 and architectures[0] in ModelRegistry.model_arch_name_to_cls:
390409
# if we are able to init the custom model, we will now download the model weights on local rank 0
391410
_download_model_weights(hf_config, pretrained_model_name_or_path)
392-
logger.info(f"Using custom model implementation for {hf_config.architectures[0]}")
411+
logger.info(f"Using custom model implementation for {architectures[0]}")
393412
kwargs.pop("trust_remote_code", None)
394-
return ModelRegistry.model_arch_name_to_cls[hf_config.architectures[0]](hf_config, *model_args, **kwargs)
413+
return ModelRegistry.model_arch_name_to_cls[architectures[0]](hf_config, *model_args, **kwargs)
395414

396415
# 3. fallback to parent class
397416
model = None
398417
try:
399418
if quantization_config is not None:
400419
kwargs["quantization_config"] = quantization_config
401-
model = _BaseNeMoAutoModelClass._from_pretrained_parent_class(
420+
model = cls._from_pretrained_parent_class(
402421
pretrained_model_name_or_path,
403422
*model_args,
423+
config=hf_config,
404424
torch_dtype=torch_dtype,
405425
attn_implementation=attn_implementation,
406426
**kwargs,
@@ -522,7 +542,7 @@ def _retry(**override):
522542

523543
# 1. if force_hf is True, we will use the parent class to load and return the model as is
524544
if force_hf:
525-
return _BaseNeMoAutoModelClass._from_config_parent_class(
545+
return cls._from_config_parent_class(
526546
config,
527547
*model_args,
528548
torch_dtype=torch_dtype,
@@ -531,15 +551,16 @@ def _retry(**override):
531551
)
532552

533553
# 2. If we have a custom model implementation available, we prioritize that over HF
534-
if config.architectures[0] in ModelRegistry.model_arch_name_to_cls:
535-
raise NotImplementedError("Custom model implementation is not supported for from_config")
554+
architectures = get_architectures(config)
555+
if len(architectures) > 0 and architectures[0] in ModelRegistry.model_arch_name_to_cls:
556+
return ModelRegistry.model_arch_name_to_cls[architectures[0]](config, *model_args, **kwargs)
536557

537558
# 3. fallback to parent class
538559
model = None
539560
try:
540561
if quantization_config is not None:
541562
kwargs["quantization_config"] = quantization_config
542-
model = _BaseNeMoAutoModelClass._from_config_parent_class(
563+
model = cls._from_config_parent_class(
543564
config,
544565
*model_args,
545566
torch_dtype=torch_dtype,

0 commit comments

Comments (0)