 import torch.distributed as dist
 from torch.nn.attention import SDPBackend, sdpa_kernel

+import nemo_automodel.components.distributed.utils as dist_utils
 from nemo_automodel import __version__
 from nemo_automodel._transformers.registry import ModelRegistry
 from nemo_automodel.components.distributed.init_utils import (
     get_local_rank_preinit,
     get_local_world_size_preinit,
     get_world_size_safe,
 )
-from nemo_automodel.components.distributed.utils import FirstRankPerNode
 from nemo_automodel.components.utils.model_utils import resolve_trust_remote_code
 from nemo_automodel.shared.import_utils import safe_import
 from nemo_automodel.shared.utils import dtype_from_str
@@ -227,7 +227,9 @@ def _download_model_weights(hf_config, pretrained_model_name_or_path):
         f"""Downloading model weights on {num_nodes} nodes. This incurs high storage usage.
             It is recommended to download once with `hf download` and pass in the downloaded path to the `pretrained_model_name_or_path` argument."""
     )
-    with FirstRankPerNode():
+    # Import via module reference (vs bound name) so unit tests can patch
+    # `nemo_automodel.components.distributed.utils.FirstRankPerNode`.
+    with dist_utils.FirstRankPerNode():
         _get_resolved_checkpoint_files(
             pretrained_model_name_or_path=pretrained_model_name_or_path,
             subfolder="",
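The point of routing the call through `dist_utils` is that the context manager is now looked up on the module at call time rather than captured at import time, so a test can swap it out. A minimal, self-contained sketch of that patching pattern (hypothetical test code, not taken from this PR):

    from unittest.mock import MagicMock, patch

    import nemo_automodel.components.distributed.utils as dist_utils  # same module the diff now references

    # Patching the attribute on the module redirects every call site that does
    # `dist_utils.FirstRankPerNode()`; a `from ... import FirstRankPerNode` binding
    # made at import time would NOT be affected by this patch.
    with patch("nemo_automodel.components.distributed.utils.FirstRankPerNode", MagicMock()) as fake_ctx:
        with dist_utils.FirstRankPerNode():  # resolves to the mock
            pass
        fake_ctx.assert_called_once()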
@@ -248,6 +250,14 @@ def _download_model_weights(hf_config, pretrained_model_name_or_path):
             transformers_explicit_filename=None,
         )

+def get_architectures(hf_config):
+    """
+    Get the architectures from the HF config.
+    """
+    architectures = []
+    if hasattr(hf_config, "architectures"):
+        architectures = hf_config.architectures or []
+    return architectures

 class _BaseNeMoAutoModelClass(_BaseAutoModelClass):
     """
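The helper replaces the direct `config.architectures[0]` indexing used further down; it is defensive because `architectures` can be missing or `None` on some configs (and on test mocks). A rough illustration of its behavior, using an architecture name that is purely illustrative:

    from types import SimpleNamespace

    # get_architectures as defined in this diff, exercised on three config shapes
    assert get_architectures(SimpleNamespace(architectures=["LlamaForCausalLM"])) == ["LlamaForCausalLM"]
    assert get_architectures(SimpleNamespace(architectures=None)) == []  # attribute present but empty
    assert get_architectures(SimpleNamespace()) == []                    # attribute missing entirely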
@@ -277,6 +287,10 @@ def _from_pretrained_parent_class(cls, *args, **kwargs):
         if name.startswith("NeMo"):
             cls.__name__ = name[4:]
         model = super().from_pretrained(*args, **kwargs)
+        # Some HF entrypoints (or tests/mocks) may return (model, unused_kwargs).
+        # Our NeMo wrappers always expect a model instance.
+        if isinstance(model, tuple) and len(model) == 2:
+            model, _ = model
         cls.__name__ = name
         return model

@@ -286,6 +300,10 @@ def _from_config_parent_class(cls, *args, **kwargs):
         if name.startswith("NeMo"):
             cls.__name__ = name[4:]
         model = super().from_config(*args, **kwargs)
+        # Some HF entrypoints (or tests/mocks) may return (model, unused_kwargs).
+        # Our NeMo wrappers always expect a model instance.
+        if isinstance(model, tuple) and len(model) == 2:
+            model, _ = model
         cls.__name__ = name
         return model

@@ -377,30 +395,32 @@ def _retry(**override):

         # 1. if force_hf is True, we will use the parent class to load and return the model as is
         if force_hf:
-            return _BaseNeMoAutoModelClass._from_pretrained_parent_class(
+            return cls._from_pretrained_parent_class(
                 pretrained_model_name_or_path,
                 *model_args,
+                config=hf_config,
                 torch_dtype=torch_dtype,
                 attn_implementation=attn_implementation,
                 **kwargs,
             )
-
+        architectures = get_architectures(hf_config)
         # 2. If we have a custom model implementation available, we prioritize that over HF
-        if hf_config.architectures[0] in ModelRegistry.model_arch_name_to_cls:
+        if len(architectures) > 0 and architectures[0] in ModelRegistry.model_arch_name_to_cls:
             # if we are able to init the custom model, we will now download the model weights on local rank 0
             _download_model_weights(hf_config, pretrained_model_name_or_path)
-            logger.info(f"Using custom model implementation for {hf_config.architectures[0]}")
+            logger.info(f"Using custom model implementation for {architectures[0]}")
             kwargs.pop("trust_remote_code", None)
-            return ModelRegistry.model_arch_name_to_cls[hf_config.architectures[0]](hf_config, *model_args, **kwargs)
+            return ModelRegistry.model_arch_name_to_cls[architectures[0]](hf_config, *model_args, **kwargs)

         # 3. fallback to parent class
         model = None
         try:
             if quantization_config is not None:
                 kwargs["quantization_config"] = quantization_config
-            model = _BaseNeMoAutoModelClass._from_pretrained_parent_class(
+            model = cls._from_pretrained_parent_class(
                 pretrained_model_name_or_path,
                 *model_args,
+                config=hf_config,
                 torch_dtype=torch_dtype,
                 attn_implementation=attn_implementation,
                 **kwargs,
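Two things change in this hunk beyond the `architectures` guard: the parent-class path is now invoked through `cls` instead of the hard-coded `_BaseNeMoAutoModelClass`, so the classmethod dispatches on the concrete NeMo auto-class, and `config=hf_config` is forwarded so the already-resolved config is reused rather than re-loaded. A tiny, self-contained sketch of why the `cls` call matters (illustrative classes only, not the real hierarchy):

    class _Base:
        @classmethod
        def _load(cls, *args, **kwargs):
            # `cls` is whatever class the call was made on, so subclass-specific
            # behavior (names, MRO, overridden hooks) is preserved.
            return cls.__name__

    class ForCausalLM(_Base):
        pass

    assert _Base._load() == "_Base"               # old pattern: always the base class
    assert ForCausalLM._load() == "ForCausalLM"   # new pattern: the class the caller actually used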
@@ -522,7 +542,7 @@ def _retry(**override):

         # 1. if force_hf is True, we will use the parent class to load and return the model as is
         if force_hf:
-            return _BaseNeMoAutoModelClass._from_config_parent_class(
+            return cls._from_config_parent_class(
                 config,
                 *model_args,
                 torch_dtype=torch_dtype,
@@ -531,15 +551,16 @@ def _retry(**override):
             )

         # 2. If we have a custom model implementation available, we prioritize that over HF
-        if config.architectures[0] in ModelRegistry.model_arch_name_to_cls:
-            raise NotImplementedError("Custom model implementation is not supported for from_config")
+        architectures = get_architectures(config)
+        if len(architectures) > 0 and architectures[0] in ModelRegistry.model_arch_name_to_cls:
+            return ModelRegistry.model_arch_name_to_cls[architectures[0]](config, *model_args, **kwargs)

         # 3. fallback to parent class
         model = None
         try:
             if quantization_config is not None:
                 kwargs["quantization_config"] = quantization_config
-            model = _BaseNeMoAutoModelClass._from_config_parent_class(
+            model = cls._from_config_parent_class(
                 config,
                 *model_args,
                 torch_dtype=torch_dtype,
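Functionally, `from_config` now mirrors the three-step policy of `from_pretrained` instead of raising `NotImplementedError` for registered architectures. A hedged usage sketch of what that enables (the auto-class import path and checkpoint id below are illustrative, not taken from this PR):

    from transformers import AutoConfig

    from nemo_automodel import NeMoAutoModelForCausalLM  # illustrative import path

    config = AutoConfig.from_pretrained("org/some-llama-style-model")  # hypothetical checkpoint id
    # If config.architectures[0] is registered in ModelRegistry, the custom NeMo
    # implementation is constructed from the config; otherwise the HF parent-class
    # from_config path runs as before.
    model = NeMoAutoModelForCausalLM.from_config(config)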