 from nemo_automodel.components._peft.lora import apply_lora_to_linear_modules
 from nemo_automodel.components.checkpoint.checkpointing import (
     Checkpointer,
+    CheckpointingConfig,
     _maybe_adapt_state_dict_to_hf,
 )
 from nemo_automodel.components.distributed.ddp import DDPManager
@@ -524,7 +525,6 @@ def apply_model_infrastructure(
     is_hf_model,
     is_meta_device,
     device,
-    checkpointer,
     model_wrapper=None,
     tp_size=1,
     cp_size=1,
@@ -536,9 +536,9 @@ def apply_model_infrastructure(
     autopipeline=None,
     parallelize_fn=None,
     compile_config=None,
-    model_name_or_path=None,
     load_base_model=False,
     cache_dir=None,
+    pretrained_model_name_or_path="",
     **_kwargs,
 ):
     """Apply sharding, PEFT, quantization, and checkpoint loading to a model.
@@ -558,7 +558,6 @@ def apply_model_infrastructure(
         is_hf_model: Whether this is an HF model (vs custom implementation)
         is_meta_device: Whether model was initialized on meta device
         device: Target device for model
-        checkpointer: Checkpointer instance for weight loading
         model_wrapper: Model wrapper (FSDP2Manager, DDPManager, etc.). Default: None
         tp_size: Tensor parallelism size. Default: 1
         cp_size: Context parallelism size. Default: 1
@@ -570,7 +569,7 @@ def apply_model_infrastructure(
         autopipeline: AutoPipeline instance for pipeline parallelism. Default: None
         parallelize_fn: Function to apply parallelization (EP + FSDP2). Default: None
         compile_config: Compilation configuration. Default: None
-        model_name_or_path: Model name or path for checkpoint loading. Default: None
+        pretrained_model_name_or_path: Model name or path for checkpoint loading. Default: ""
         load_base_model: Whether to load base model weights (True for from_pretrained). Default: False
         cache_dir: Cache directory for model weights. Default: None
         **_kwargs: Additional keyword arguments (ignored, allows passing extra kwargs)
@@ -580,6 +579,24 @@ def apply_model_infrastructure(
580579 """
581580 _verify_sdpa_support (model , is_hf_model , cp_size )
582581
582+ # Create a dummy checkpointer. We can pass in dummy values here since we are only loading the base weights.
583+ ckpt_config = CheckpointingConfig (
584+ enabled = True ,
585+ checkpoint_dir = "" ,
586+ model_save_format = "safetensors" ,
587+ model_cache_dir = cache_dir ,
588+ model_repo_id = pretrained_model_name_or_path ,
589+ save_consolidated = True ,
590+ is_peft = peft_config is not None ,
591+ )
592+ checkpointer = Checkpointer (
593+ ckpt_config ,
594+ 0 ,
595+ 0 ,
596+ 0 ,
597+ getattr (model_wrapper , "moe_mesh" , None ) if model_wrapper else None ,
598+ )
599+
583600 # Handle checkpointer config updates if checkpointer is provided
584601 dequantize_base_checkpoint = False
585602 if checkpointer is not None :
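The hunk above replaces the caller-supplied `checkpointer` argument with one built internally from dummy values. A rough standalone sketch of the same pattern follows; it reuses only names that appear in this diff, and the meaning of the three zero positional arguments is assumed (rank-style placeholders), not shown in the hunk.

```python
# Sketch only: mirrors the dummy-checkpointer construction added above.
from nemo_automodel.components.checkpoint.checkpointing import Checkpointer, CheckpointingConfig


def build_dummy_checkpointer(pretrained_model_name_or_path, cache_dir, peft_config=None, moe_mesh=None):
    cfg = CheckpointingConfig(
        enabled=True,
        checkpoint_dir="",                  # unused while only loading base weights
        model_save_format="safetensors",
        model_cache_dir=cache_dir,
        model_repo_id=pretrained_model_name_or_path,
        save_consolidated=True,
        is_peft=peft_config is not None,
    )
    # The three zeros mirror the positional arguments used in the diff; treating them
    # as rank indices is an assumption, not something this hunk confirms.
    return Checkpointer(cfg, 0, 0, 0, moe_mesh)
```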
@@ -599,11 +616,10 @@ def apply_model_infrastructure(
         model, tp_size, autopipeline, peft_config, quantization_config, fp8_config, qat_quantizer
     )
 
-    # hold a list copy of the model state dict keys before any parallelization
-    if checkpointer is not None:
-        checkpointer.config.model_state_dict_keys = list(
-            _maybe_adapt_state_dict_to_hf(model, model.state_dict(), quantization=dequantize_base_checkpoint).keys()
-        )
+    # hold a list copy of the model state dict keys before any parallelization; used later when saving checkpoints in safetensors format
+    pre_shard_hf_state_dict_keys = list(
+        _maybe_adapt_state_dict_to_hf(model, model.state_dict(), quantization=dequantize_base_checkpoint).keys()
+    )
 
     # Loss function check
     if not _supports_logits_to_keep(model) and not isinstance(loss_fn, MaskedCrossEntropy):
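Why the keys are captured before any parallelization: wrapping or sharding a module typically changes its `state_dict()` key names, while consolidated safetensors output should keep the original HF layout. A toy illustration in plain PyTorch (not nemo_automodel code):

```python
# Toy illustration, plain torch: wrapper modules prefix state_dict() keys,
# which is why the HF-style keys are recorded before sharding.
import torch.nn as nn


class Wrapper(nn.Module):
    """Stand-in for a DDP/FSDP-style wrapper that nests the original module."""

    def __init__(self, module):
        super().__init__()
        self.module = module


model = nn.Linear(4, 4)
pre_shard_keys = list(model.state_dict().keys())    # ['weight', 'bias']

wrapped = Wrapper(model)
post_wrap_keys = list(wrapped.state_dict().keys())  # ['module.weight', 'module.bias']

assert pre_shard_keys != post_wrap_keys  # hence _pre_shard_hf_state_dict_keys is stashed early
```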
@@ -613,24 +629,26 @@ def apply_model_infrastructure(
     # Note: AutoPipeline takes care of applying PP + EP + FSDP. _shard_ep_fsdp will take care of applying EP + FSDP if no PP.
     if autopipeline is not None:
         model = _shard_pp(autopipeline, model, loss_fn, parallelize_fn)
+        for part in model.parts:
+            setattr(part, "_pre_shard_hf_state_dict_keys", pre_shard_hf_state_dict_keys)
     else:
         model = _shard_ep_fsdp(model, model_wrapper, parallelize_fn)
     if compile_config is not None:
         model = compile_model(model, compile_config)
+    if isinstance(model_wrapper, DDPManager):
+        setattr(model.module, "_pre_shard_hf_state_dict_keys", pre_shard_hf_state_dict_keys)
+    else:
+        setattr(model, "_pre_shard_hf_state_dict_keys", pre_shard_hf_state_dict_keys)
 
     # Load the checkpoint if needed and return
     # Weights need to be loaded for meta device models that were parallelized:
     # 1. When parallelize_fn was used (which will internally apply FSDP2/EP sharding)
     # 2. When FSDP2Manager.parallelize was used (but not MegatronFSDP which handles weights internally)
-    should_load_checkpoint = (
-        is_meta_device
-        and checkpointer is not None
-        and any(
-            [
-                parallelize_fn is not None and get_world_size_safe() > 1,
-                callable(getattr(model_wrapper, "parallelize", None)),
-            ]
-        )
+    should_load_checkpoint = is_meta_device and any(
+        [
+            parallelize_fn is not None and get_world_size_safe() > 1,
+            callable(getattr(model_wrapper, "parallelize", None)),
+        ]
     )
     if should_load_checkpoint:
         models_to_load = model.parts if hasattr(model, "parts") else [model]
@@ -640,7 +658,7 @@ def apply_model_infrastructure(
640658 mp ,
641659 device ,
642660 cache_dir ,
643- model_name_or_path ,
661+ pretrained_model_name_or_path ,
644662 lora_a_init ,
645663 load_base_model = load_base_model ,
646664 )
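The simplified gate above drops the `checkpointer is not None` term, since a checkpointer is now always constructed. A condensed restatement of the same predicate, as a sketch rather than the library function (`get_world_size_safe` is assumed to return 1 when torch.distributed is not initialized):

```python
def should_load_checkpoint(is_meta_device, parallelize_fn, model_wrapper, world_size):
    # Load weights only for meta-device models that get parallelized: either via an
    # explicit parallelize_fn in a multi-rank run, or via a wrapper that exposes a
    # callable .parallelize (e.g. FSDP2Manager, per the comments in the diff).
    return is_meta_device and (
        (parallelize_fn is not None and world_size > 1)
        or callable(getattr(model_wrapper, "parallelize", None))
    )
```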
@@ -778,7 +796,6 @@ def from_pretrained(
     model_wrapper=None,
     autopipeline: AutoPipeline | None = None,
     parallelize_fn: Callable | None = None,
-    checkpointer: Optional[Checkpointer] = None,
     peft_config: Optional[dict] = None,
     fp8_config: Optional["FP8Config"] = None,
     qat_quantizer: Optional[Union["Int4WeightOnlyQATQuantizer", "Int8DynActInt4WeightQATQuantizer"]] = None,
@@ -824,9 +841,6 @@ def from_pretrained(
             pipeline stages. Default: None.
         parallelize_fn (Callable | None, optional): Custom function to apply
             parallelization (EP + FSDP2). Default: None.
-        checkpointer (Checkpointer, optional): Checkpointer instance for loading weights
-            and enabling save_pretrained() functionality. Required for weight loading
-            and checkpoint management.
         peft_config (dict | None, optional): PEFT/LoRA configuration dictionary.
             If provided, LoRA adapters will be applied to the model. Default: None.
         fp8_config (FP8Config | None, optional): FP8 quantization configuration.
@@ -882,7 +896,6 @@ def _retry(**override):
             fp8_config=fp8_config,
             qat_quantizer=qat_quantizer,
             loss_fn=loss_fn,
-            checkpointer=checkpointer,
             compile_config=compile_config,
             model_wrapper=model_wrapper,
             **kwargs,
@@ -899,11 +912,10 @@ def _retry(**override):
         device = torch.cuda.current_device()
 
     # Neither of these parallelization methods support meta device initialization
-    # Also require checkpointer for meta device init, as we need it to load weights
     is_meta_device = (
         not isinstance(model_wrapper, (MegatronFSDPManager, DDPManager))
         and not force_hf
-        and checkpointer is not None
+        and get_world_size_safe() > 1
     )
     init_ctx = ContextManagers([no_init_weights(), init_empty_weights()]) if is_meta_device else nullcontext()
 
@@ -948,10 +960,10 @@ def _retry(**override):
 
     model = apply_model_infrastructure(
         model=model,
+        pretrained_model_name_or_path=pretrained_model_name_or_path,
         is_hf_model=is_hf_model,
         cp_size=cp_size,
         tp_size=tp_size,
-        checkpointer=checkpointer,
         peft_config=peft_config,
         quantization_config=quantization_config,
         fp8_config=fp8_config,
@@ -963,7 +975,6 @@ def _retry(**override):
         is_meta_device=is_meta_device,
         device=device,
         compile_config=compile_config,
-        model_name_or_path=pretrained_model_name_or_path,
         load_base_model=True,
         cache_dir=kwargs.get("cache_dir", TRANSFORMERS_CACHE),
     )
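Net effect on the `from_pretrained` call site: callers no longer construct or pass a `Checkpointer`. A minimal usage sketch, assuming the `NeMoAutoModelForCausalLM` entry point exported by the package (the class name and model id are illustrative, not taken from this diff):

```python
from nemo_automodel import NeMoAutoModelForCausalLM

model = NeMoAutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    # checkpointer=...  <- no longer a parameter; a dummy Checkpointer is built internally
)
```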
@@ -990,7 +1001,6 @@ def from_config(
     qat_quantizer: Optional[Union["Int4WeightOnlyQATQuantizer", "Int8DynActInt4WeightQATQuantizer"]] = None,
     loss_fn: Optional[Callable] = None,
     compile_config: Optional["CompileConfig"] = None,
-    checkpointer: Optional[Checkpointer] = None,
     **kwargs,
 ) -> PreTrainedModel:
     """
@@ -1051,9 +1061,6 @@ def from_config(
             it will be replaced with MaskedCrossEntropy. This is passed to AutoPipeline. Default: None.
         compile_config (CompileConfig | None, optional): Configuration for torch.compile.
             If provided, the model will be compiled for improved performance. Default: None.
-        checkpointer (Checkpointer, optional): Checkpointer instance for checkpoint
-            management and enabling save_pretrained() functionality. Required for
-            proper checkpoint handling.
         **kwargs:
             Additional keyword arguments. Notable ones include:
             - tp_size (int): Tensor parallelism size. Default: 1.
@@ -1096,7 +1103,6 @@ def _retry(**override):
             qat_quantizer=qat_quantizer,
             loss_fn=loss_fn,
             compile_config=compile_config,
-            checkpointer=checkpointer,
             **kwargs,
         )
 
@@ -1117,11 +1123,10 @@ def _retry(**override):
         device = torch.cuda.current_device()
 
     # Neither of these parallelization methods support meta device initialization
-    # Also require checkpointer for meta device init, as we need it to load weights
     is_meta_device = (
         not isinstance(model_wrapper, (MegatronFSDPManager, DDPManager))
         and not force_hf
-        and checkpointer is not None
+        and get_world_size_safe() > 1
     )
     init_ctx = ContextManagers([no_init_weights(), init_empty_weights()]) if is_meta_device else nullcontext()
 
@@ -1162,7 +1167,6 @@ def _retry(**override):
         is_hf_model=is_hf_model,
         cp_size=cp_size,
         tp_size=tp_size,
-        checkpointer=checkpointer,
         peft_config=peft_config,
         quantization_config=quantization_config,
         fp8_config=fp8_config,
@@ -1174,7 +1178,7 @@ def _retry(**override):
         is_meta_device=is_meta_device,
         device=device,
         compile_config=compile_config,
-        model_name_or_path=getattr(config, "name_or_path"),
+        pretrained_model_name_or_path=getattr(config, "name_or_path"),
         load_base_model=False,
         cache_dir=kwargs.get("cache_dir", TRANSFORMERS_CACHE),
     )
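The config-driven path changes the same way: the `checkpointer` parameter is removed, and the repo id used for weight resolution now comes from `config.name_or_path`. A minimal sketch under the same assumptions as above (class and method names are illustrative, not shown in this diff):

```python
from transformers import AutoConfig

from nemo_automodel import NeMoAutoModelForCausalLM

config = AutoConfig.from_pretrained("meta-llama/Llama-3.2-1B")
# No checkpointer argument; per the diff, this path uses load_base_model=False,
# so pretrained weights are not loaded here.
model = NeMoAutoModelForCausalLM.from_config(config)
```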