 from nemo_automodel.components._peft.lora import apply_lora_to_linear_modules
 from nemo_automodel.components.checkpoint.checkpointing import (
     Checkpointer,
+    CheckpointingConfig,
     _maybe_adapt_state_dict_to_hf,
 )
 from nemo_automodel.components.distributed.ddp import DDPManager
@@ -524,7 +525,6 @@ def apply_model_infrastructure(
     is_hf_model,
     is_meta_device,
     device,
-    checkpointer,
     model_wrapper=None,
     tp_size=1,
     cp_size=1,
@@ -539,6 +539,7 @@ def apply_model_infrastructure(
     model_name_or_path=None,
     load_base_model=False,
     cache_dir=None,
+    pretrained_model_name_or_path="",
     **_kwargs,
 ):
     """Apply sharding, PEFT, quantization, and checkpoint loading to a model.
@@ -558,7 +559,6 @@ def apply_model_infrastructure(
         is_hf_model: Whether this is an HF model (vs custom implementation)
         is_meta_device: Whether model was initialized on meta device
         device: Target device for model
-        checkpointer: Checkpointer instance for weight loading
         model_wrapper: Model wrapper (FSDP2Manager, DDPManager, etc.). Default: None
         tp_size: Tensor parallelism size. Default: 1
         cp_size: Context parallelism size. Default: 1
@@ -580,6 +580,24 @@ def apply_model_infrastructure(
580580 """
581581 _verify_sdpa_support (model , is_hf_model , cp_size )
582582
583+ # Create a dummy checkpointer. We can pass in dummy values here since we are only loading the base weights.
584+ ckpt_config = CheckpointingConfig (
585+ enabled = True ,
586+ checkpoint_dir = "" ,
587+ model_save_format = "safetensors" ,
588+ model_cache_dir = cache_dir ,
589+ model_repo_id = pretrained_model_name_or_path ,
590+ save_consolidated = True ,
591+ is_peft = peft_config is not None ,
592+ )
593+ checkpointer = Checkpointer (
594+ ckpt_config ,
595+ 0 ,
596+ 0 ,
597+ 0 ,
598+ getattr (model_wrapper , "moe_mesh" , None ) if model_wrapper else None ,
599+ )
600+
583601 # Handle checkpointer config updates if checkpointer is provided
584602 dequantize_base_checkpoint = False
585603 if checkpointer is not None :
@@ -600,10 +618,9 @@ def apply_model_infrastructure(
     )
 
     # hold a list copy of the model state dict keys before any parallelization
-    if checkpointer is not None:
-        checkpointer.config.model_state_dict_keys = list(
-            _maybe_adapt_state_dict_to_hf(model, model.state_dict(), quantization=dequantize_base_checkpoint).keys()
-        )
+    checkpointer.config.model_state_dict_keys = list(
+        _maybe_adapt_state_dict_to_hf(model, model.state_dict(), quantization=dequantize_base_checkpoint).keys()
+    )
 
     # Loss function check
     if not _supports_logits_to_keep(model) and not isinstance(loss_fn, MaskedCrossEntropy):
@@ -622,15 +639,11 @@ def apply_model_infrastructure(
     # Weights need to be loaded for meta device models that were parallelized:
     # 1. When parallelize_fn was used (which will internally apply FSDP2/EP sharding)
     # 2. When FSDP2Manager.parallelize was used (but not MegatronFSDP which handles weights internally)
-    should_load_checkpoint = (
-        is_meta_device
-        and checkpointer is not None
-        and any(
-            [
-                parallelize_fn is not None and get_world_size_safe() > 1,
-                callable(getattr(model_wrapper, "parallelize", None)),
-            ]
-        )
+    should_load_checkpoint = is_meta_device and any(
+        [
+            parallelize_fn is not None and get_world_size_safe() > 1,
+            callable(getattr(model_wrapper, "parallelize", None)),
+        ]
     )
     if should_load_checkpoint:
         models_to_load = model.parts if hasattr(model, "parts") else [model]
@@ -778,7 +791,6 @@ def from_pretrained(
     model_wrapper=None,
     autopipeline: AutoPipeline | None = None,
     parallelize_fn: Callable | None = None,
-    checkpointer: Optional[Checkpointer] = None,
     peft_config: Optional[dict] = None,
     fp8_config: Optional["FP8Config"] = None,
     qat_quantizer: Optional[Union["Int4WeightOnlyQATQuantizer", "Int8DynActInt4WeightQATQuantizer"]] = None,
@@ -824,9 +836,6 @@ def from_pretrained(
             pipeline stages. Default: None.
         parallelize_fn (Callable | None, optional): Custom function to apply
             parallelization (EP + FSDP2). Default: None.
-        checkpointer (Checkpointer, optional): Checkpointer instance for loading weights
-            and enabling save_pretrained() functionality. Required for weight loading
-            and checkpoint management.
         peft_config (dict | None, optional): PEFT/LoRA configuration dictionary.
             If provided, LoRA adapters will be applied to the model. Default: None.
         fp8_config (FP8Config | None, optional): FP8 quantization configuration.
@@ -882,7 +891,6 @@ def _retry(**override):
             fp8_config=fp8_config,
             qat_quantizer=qat_quantizer,
             loss_fn=loss_fn,
-            checkpointer=checkpointer,
             compile_config=compile_config,
             model_wrapper=model_wrapper,
             **kwargs,
@@ -899,12 +907,7 @@ def _retry(**override):
         device = torch.cuda.current_device()
 
     # Neither of these parallelization methods support meta device initialization
-    # Also require checkpointer for meta device init, as we need it to load weights
-    is_meta_device = (
-        not isinstance(model_wrapper, (MegatronFSDPManager, DDPManager))
-        and not force_hf
-        and checkpointer is not None
-    )
+    is_meta_device = not isinstance(model_wrapper, (MegatronFSDPManager, DDPManager)) and not force_hf
     init_ctx = ContextManagers([no_init_weights(), init_empty_weights()]) if is_meta_device else nullcontext()
 
     try:
@@ -948,10 +951,10 @@ def _retry(**override):
 
         model = apply_model_infrastructure(
             model=model,
+            pretrained_model_name_or_path=pretrained_model_name_or_path,
             is_hf_model=is_hf_model,
             cp_size=cp_size,
             tp_size=tp_size,
-            checkpointer=checkpointer,
             peft_config=peft_config,
             quantization_config=quantization_config,
             fp8_config=fp8_config,
@@ -990,7 +993,6 @@ def from_config(
     qat_quantizer: Optional[Union["Int4WeightOnlyQATQuantizer", "Int8DynActInt4WeightQATQuantizer"]] = None,
     loss_fn: Optional[Callable] = None,
     compile_config: Optional["CompileConfig"] = None,
-    checkpointer: Optional[Checkpointer] = None,
     **kwargs,
 ) -> PreTrainedModel:
     """
@@ -1051,9 +1053,6 @@ def from_config(
             it will be replaced with MaskedCrossEntropy. This is passed to AutoPipeline. Default: None.
         compile_config (CompileConfig | None, optional): Configuration for torch.compile.
             If provided, the model will be compiled for improved performance. Default: None.
-        checkpointer (Checkpointer, optional): Checkpointer instance for checkpoint
-            management and enabling save_pretrained() functionality. Required for
-            proper checkpoint handling.
         **kwargs:
             Additional keyword arguments. Notable ones include:
             - tp_size (int): Tensor parallelism size. Default: 1.
@@ -1096,7 +1095,6 @@ def _retry(**override):
             qat_quantizer=qat_quantizer,
             loss_fn=loss_fn,
             compile_config=compile_config,
-            checkpointer=checkpointer,
             **kwargs,
         )
 
@@ -1117,12 +1115,7 @@ def _retry(**override):
         device = torch.cuda.current_device()
 
     # Neither of these parallelization methods support meta device initialization
-    # Also require checkpointer for meta device init, as we need it to load weights
-    is_meta_device = (
-        not isinstance(model_wrapper, (MegatronFSDPManager, DDPManager))
-        and not force_hf
-        and checkpointer is not None
-    )
+    is_meta_device = not isinstance(model_wrapper, (MegatronFSDPManager, DDPManager)) and not force_hf
     init_ctx = ContextManagers([no_init_weights(), init_empty_weights()]) if is_meta_device else nullcontext()
 
     try:
@@ -1162,7 +1155,6 @@ def _retry(**override):
             is_hf_model=is_hf_model,
             cp_size=cp_size,
             tp_size=tp_size,
-            checkpointer=checkpointer,
             peft_config=peft_config,
             quantization_config=quantization_config,
             fp8_config=fp8_config,
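
The net effect of this diff is that callers of from_pretrained and from_config no longer construct or pass a Checkpointer; apply_model_infrastructure now builds a throwaway CheckpointingConfig/Checkpointer internally from pretrained_model_name_or_path, cache_dir, and peft_config, since it is only needed to load base weights. A minimal caller-side sketch under those assumptions follows; the import path, repo id, and keyword values shown are illustrative and not part of this change.

    # Sketch only: the import path and model id are hypothetical; the keyword names mirror the
    # from_pretrained signature shown in the hunks above.
    from nemo_automodel import NeMoAutoModelForCausalLM  # assumed public entry point

    # Previously a Checkpointer had to be built and passed as `checkpointer=` just so base weights
    # could be loaded onto the meta-device model. After this change the same call needs only the
    # repo id; the dummy CheckpointingConfig/Checkpointer is assembled inside apply_model_infrastructure.
    model = NeMoAutoModelForCausalLM.from_pretrained(
        "org/example-model",  # forwarded as pretrained_model_name_or_path / model_repo_id
        peft_config=None,     # the internal config derives is_peft from this
    )

Keeping the dummy checkpointer internal works because, per the added comment, it is only used to load the base weights, so placeholder values for checkpoint_dir and the rank arguments are sufficient.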