2323import torch
2424import torch .distributed as dist
2525from torch .nn .attention import SDPBackend , sdpa_kernel
26- from transformers import (
27- AutoConfig ,
28- AutoModelForCausalLM ,
29- AutoModelForImageTextToText ,
30- AutoModelForSequenceClassification ,
31- AutoModelForTextToWaveform ,
32- PreTrainedModel ,
33- )
34- from transformers .modeling_utils import _get_resolved_checkpoint_files
35- from transformers .models .auto .auto_factory import _BaseAutoModelClass
3626
3727from nemo_automodel import __version__
3828from nemo_automodel ._transformers .registry import ModelRegistry
4131 get_local_world_size_preinit ,
4232 get_world_size_safe ,
4333)
34+ from nemo_automodel .components .distributed .utils import FirstRankPerNode
4435from nemo_automodel .components .utils .model_utils import resolve_trust_remote_code
4536from nemo_automodel .shared .import_utils import safe_import
4637from nemo_automodel .shared .utils import dtype_from_str
38+ from transformers import (
39+ AutoConfig ,
40+ AutoModelForCausalLM ,
41+ AutoModelForImageTextToText ,
42+ AutoModelForSequenceClassification ,
43+ AutoModelForTextToWaveform ,
44+ PreTrainedModel ,
45+ )
46+ from transformers .modeling_utils import _get_resolved_checkpoint_files
47+ from transformers .models .auto .auto_factory import _BaseAutoModelClass
4748
4849HAS_LIGER_KERNEL , liger_kernel_trf = safe_import ("liger_kernel.transformers" )
4950logger = logging .getLogger (__name__ )
@@ -216,6 +217,38 @@ def _verify_sdpa_support(model, is_hf_model, cp_size):
216217 raise ValueError ("Model does not support SDPA required for context parallelism" )
217218
218219
220+ def _download_model_weights (hf_config , pretrained_model_name_or_path ):
221+ if (not dist .is_initialized () or get_local_rank_preinit () == 0 ) and not os .path .isdir (
222+ pretrained_model_name_or_path
223+ ):
224+ num_nodes = (get_world_size_safe () % get_local_world_size_preinit ()) + 1 # 1-indexed
225+ if num_nodes > 1 :
226+ logging .info (
227+ f"""Downloading model weights on { num_nodes } nodes. This incurs high storage usage.
228+ It is recommended to download once with `hf download` and pass in the downloaded path to the `pretrained_model_name_or_path` argument."""
229+ )
230+ with FirstRankPerNode ():
231+ _get_resolved_checkpoint_files (
232+ pretrained_model_name_or_path = pretrained_model_name_or_path ,
233+ subfolder = "" ,
234+ variant = None ,
235+ gguf_file = None ,
236+ from_tf = False ,
237+ from_flax = False ,
238+ use_safetensors = None ,
239+ cache_dir = None ,
240+ force_download = False ,
241+ proxies = None ,
242+ local_files_only = False ,
243+ token = None ,
244+ user_agent = {"file_type" : "model" , "framework" : "pytorch" , "from_auto_class" : False },
245+ revision = "main" ,
246+ commit_hash = getattr (hf_config , "_commit_hash" , None ),
247+ is_remote_code = False ,
248+ transformers_explicit_filename = None ,
249+ )
250+
251+
219252class _BaseNeMoAutoModelClass (_BaseAutoModelClass ):
220253 """
221254 Drop-in replacement for ``_BaseAutoModelClass`` that includes custom-kernels.
@@ -238,6 +271,24 @@ class _BaseNeMoAutoModelClass(_BaseAutoModelClass):
238271 Liger patch. Unsupported models will silently fall back.
239272 """
240273
274+ @classmethod
275+ def _from_pretrained_parent_class (cls , * args , ** kwargs ):
276+ name = cls .__name__
277+ if name .startswith ("NeMo" ):
278+ cls .__name__ = name [4 :]
279+ model = super ().from_pretrained (* args , ** kwargs )
280+ cls .__name__ = name
281+ return model
282+
283+ @classmethod
284+ def _from_config_parent_class (cls , * args , ** kwargs ):
285+ name = cls .__name__
286+ if name .startswith ("NeMo" ):
287+ cls .__name__ = name [4 :]
288+ model = super ().from_config (* args , ** kwargs )
289+ cls .__name__ = name
290+ return model
291+
241292 @classmethod
242293 def from_pretrained (
243294 cls ,
@@ -324,66 +375,36 @@ def _retry(**override):
324375 ** kwargs ,
325376 )
326377
327- # load model
378+ # 1. if force_hf is True, we will use the parent class to load and return the model as is
379+ if force_hf :
380+ return _BaseNeMoAutoModelClass ._from_pretrained_parent_class (
381+ pretrained_model_name_or_path ,
382+ * model_args ,
383+ torch_dtype = torch_dtype ,
384+ attn_implementation = attn_implementation ,
385+ ** kwargs ,
386+ )
387+
388+ # 2. If we have a custom model implementation available, we prioritize that over HF
389+ if hf_config .architectures [0 ] in ModelRegistry .model_arch_name_to_cls :
390+ # if we are able to init the custom model, we will now download the model weights on local rank 0
391+ _download_model_weights (hf_config , pretrained_model_name_or_path )
392+ logger .info (f"Using custom model implementation for { hf_config .architectures [0 ]} " )
393+ kwargs .pop ("trust_remote_code" , None )
394+ return ModelRegistry .model_arch_name_to_cls [hf_config .architectures [0 ]](hf_config , * model_args , ** kwargs )
395+
396+ # 3. fallback to parent class
328397 model = None
329398 try :
330- name = cls .__name__
331- if name .startswith ("NeMo" ):
332- cls .__name__ = name [4 :]
333- if not force_hf :
334- try :
335- # if we have a custom model implementation available, we prioritize that over HF
336- if hf_config .architectures [0 ] in ModelRegistry .model_arch_name_to_cls :
337- kwargs .pop ("trust_remote_code" , None )
338- model = ModelRegistry .model_arch_name_to_cls [hf_config .architectures [0 ]](
339- hf_config , * model_args , ** kwargs
340- )
341- # if we are able to init the custom model, we will now download the model weights on local rank 0
342- if (not dist .is_initialized () or get_local_rank_preinit () == 0 ) and not os .path .isdir (
343- pretrained_model_name_or_path
344- ):
345- num_nodes = (get_world_size_safe () % get_local_world_size_preinit ()) + 1 # 1-indexed
346- if num_nodes > 1 :
347- logging .info (
348- f"""Downloading model weights on { num_nodes } nodes. This incurs high storage usage.
349- It is recommended to download once with `hf download` and pass in the downloaded path to the `pretrained_model_name_or_path` argument."""
350- )
351- _get_resolved_checkpoint_files (
352- pretrained_model_name_or_path = pretrained_model_name_or_path ,
353- subfolder = "" ,
354- variant = None ,
355- gguf_file = None ,
356- from_tf = False ,
357- from_flax = False ,
358- use_safetensors = None ,
359- cache_dir = None ,
360- force_download = False ,
361- proxies = None ,
362- local_files_only = False ,
363- token = None ,
364- user_agent = {"file_type" : "model" , "framework" : "pytorch" , "from_auto_class" : False },
365- revision = "main" ,
366- commit_hash = getattr (hf_config , "_commit_hash" , None ),
367- is_remote_code = False ,
368- transformers_explicit_filename = None ,
369- )
370- if dist .is_initialized ():
371- dist .barrier ()
372- logger .info (f"Using custom model implementation for { hf_config .architectures [0 ]} " )
373- return model
374- except Exception as e :
375- logger .error (f"Failed to use custom model implementation with error: { e } " )
376-
377399 if quantization_config is not None :
378400 kwargs ["quantization_config" ] = quantization_config
379- model = super (). from_pretrained (
401+ model = _BaseNeMoAutoModelClass . _from_pretrained_parent_class (
380402 pretrained_model_name_or_path ,
381403 * model_args ,
382404 torch_dtype = torch_dtype ,
383405 attn_implementation = attn_implementation ,
384406 ** kwargs ,
385407 )
386- cls .__name__ = name
387408 except ValueError as e :
388409 if "does not support" in str (e ):
389410 if model is not None :
@@ -499,35 +520,32 @@ def _retry(**override):
499520 ** kwargs ,
500521 )
501522
502- # load model
523+ # 1. if force_hf is True, we will use the parent class to load and return the model as is
524+ if force_hf :
525+ return _BaseNeMoAutoModelClass ._from_config_parent_class (
526+ config ,
527+ * model_args ,
528+ torch_dtype = torch_dtype ,
529+ attn_implementation = attn_implementation ,
530+ ** kwargs ,
531+ )
532+
533+ # 2. If we have a custom model implementation available, we prioritize that over HF
534+ if config .architectures [0 ] in ModelRegistry .model_arch_name_to_cls :
535+ raise NotImplementedError ("Custom model implementation is not supported for from_config" )
536+
537+ # 3. fallback to parent class
503538 model = None
504539 try :
505- name = cls .__name__
506- if name .startswith ("NeMo" ):
507- cls .__name__ = name [4 :]
508- if not force_hf :
509- try :
510- # if we have a custom model implementation available, we prioritize that over HF
511- if config .architectures [0 ] in ModelRegistry .model_arch_name_to_cls :
512- kwargs .pop ("trust_remote_code" , None )
513- model = ModelRegistry .model_arch_name_to_cls [config .architectures [0 ]](
514- config , * model_args , ** kwargs
515- )
516- logger .info (f"Using custom model implementation for { config .architectures [0 ]} " )
517- return model
518- except Exception as e :
519- logger .error (f"Failed to use custom model implementation with error: { e } " )
520-
521540 if quantization_config is not None :
522541 kwargs ["quantization_config" ] = quantization_config
523- model = super (). from_config (
542+ model = _BaseNeMoAutoModelClass . _from_config_parent_class (
524543 config ,
525544 * model_args ,
526- attn_implementation = attn_implementation ,
527545 torch_dtype = torch_dtype ,
546+ attn_implementation = attn_implementation ,
528547 ** kwargs ,
529548 )
530- cls .__name__ = name
531549 except ValueError as e :
532550 if "does not support" in str (e ):
533551 logging .warning ("Falling back to eager attention." )
0 commit comments