
Commit 5ffc16d

feat: remove checkpointer from Automodel class (#1147)
Signed-off-by: adil-a <adil.asif2000@hotmail.com>
1 parent 801f63e commit 5ffc16d

8 files changed: +147 -482 lines changed

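In practical terms, callers of the Automodel entry points no longer construct or pass a Checkpointer: apply_model_infrastructure now builds a dummy one internally from pretrained_model_name_or_path and cache_dir. A minimal usage sketch follows; the top-level import path and the model id are illustrative assumptions, not taken from this commit.

# Hedged usage sketch; import path and model id are assumed, not shown in this commit.
from nemo_automodel import NeMoAutoModelForCausalLM  # assumed public import path

# Previously a Checkpointer had to be created and passed as checkpointer=...;
# after this change the same call works without it. tp_size/cp_size are shown as
# keyword arguments, mirroring the kwargs used by the recipes touched in this commit.
model = NeMoAutoModelForCausalLM.from_pretrained(
    "org/model-id",  # hypothetical model id
    tp_size=1,
    cp_size=1,
)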

nemo_automodel/_transformers/auto_model.py

Lines changed: 41 additions & 37 deletions
@@ -46,6 +46,7 @@
 from nemo_automodel.components._peft.lora import apply_lora_to_linear_modules
 from nemo_automodel.components.checkpoint.checkpointing import (
     Checkpointer,
+    CheckpointingConfig,
     _maybe_adapt_state_dict_to_hf,
 )
 from nemo_automodel.components.distributed.ddp import DDPManager
@@ -524,7 +525,6 @@ def apply_model_infrastructure(
     is_hf_model,
     is_meta_device,
     device,
-    checkpointer,
     model_wrapper=None,
     tp_size=1,
     cp_size=1,
@@ -536,9 +536,9 @@ def apply_model_infrastructure(
     autopipeline=None,
     parallelize_fn=None,
     compile_config=None,
-    model_name_or_path=None,
     load_base_model=False,
     cache_dir=None,
+    pretrained_model_name_or_path="",
     **_kwargs,
 ):
     """Apply sharding, PEFT, quantization, and checkpoint loading to a model.
@@ -558,7 +558,6 @@ def apply_model_infrastructure(
         is_hf_model: Whether this is an HF model (vs custom implementation)
         is_meta_device: Whether model was initialized on meta device
         device: Target device for model
-        checkpointer: Checkpointer instance for weight loading
         model_wrapper: Model wrapper (FSDP2Manager, DDPManager, etc.). Default: None
         tp_size: Tensor parallelism size. Default: 1
         cp_size: Context parallelism size. Default: 1
@@ -570,7 +569,7 @@ def apply_model_infrastructure(
         autopipeline: AutoPipeline instance for pipeline parallelism. Default: None
         parallelize_fn: Function to apply parallelization (EP + FSDP2). Default: None
         compile_config: Compilation configuration. Default: None
-        model_name_or_path: Model name or path for checkpoint loading. Default: None
+        pretrained_model_name_or_path: Model name or path for checkpoint loading. Default: None
         load_base_model: Whether to load base model weights (True for from_pretrained). Default: False
         cache_dir: Cache directory for model weights. Default: None
         **_kwargs: Additional keyword arguments (ignored, allows passing extra kwargs)
@@ -580,6 +579,24 @@ def apply_model_infrastructure(
     """
     _verify_sdpa_support(model, is_hf_model, cp_size)

+    # Create a dummy checkpointer. We can pass in dummy values here since we are only loading the base weights.
+    ckpt_config = CheckpointingConfig(
+        enabled=True,
+        checkpoint_dir="",
+        model_save_format="safetensors",
+        model_cache_dir=cache_dir,
+        model_repo_id=pretrained_model_name_or_path,
+        save_consolidated=True,
+        is_peft=peft_config is not None,
+    )
+    checkpointer = Checkpointer(
+        ckpt_config,
+        0,
+        0,
+        0,
+        getattr(model_wrapper, "moe_mesh", None) if model_wrapper else None,
+    )
+
     # Handle checkpointer config updates if checkpointer is provided
     dequantize_base_checkpoint = False
     if checkpointer is not None:
@@ -599,11 +616,10 @@ def apply_model_infrastructure(
         model, tp_size, autopipeline, peft_config, quantization_config, fp8_config, qat_quantizer
     )

-    # hold a list copy of the model state dict keys before any parallelization
-    if checkpointer is not None:
-        checkpointer.config.model_state_dict_keys = list(
-            _maybe_adapt_state_dict_to_hf(model, model.state_dict(), quantization=dequantize_base_checkpoint).keys()
-        )
+    # hold a list copy of the model state dict keys before any parallelization. To be used during checkpoint saving in safetensors format.
+    pre_shard_hf_state_dict_keys = list(
+        _maybe_adapt_state_dict_to_hf(model, model.state_dict(), quantization=dequantize_base_checkpoint).keys()
+    )

     # Loss function check
     if not _supports_logits_to_keep(model) and not isinstance(loss_fn, MaskedCrossEntropy):
@@ -613,24 +629,26 @@ def apply_model_infrastructure(
     # Note: AutoPipeline takes care of applying PP + EP + FSDP. _shard_ep_fsdp will take care of applying EP + FSDP if no PP.
     if autopipeline is not None:
         model = _shard_pp(autopipeline, model, loss_fn, parallelize_fn)
+        for part in model.parts:
+            setattr(part, "_pre_shard_hf_state_dict_keys", pre_shard_hf_state_dict_keys)
     else:
         model = _shard_ep_fsdp(model, model_wrapper, parallelize_fn)
         if compile_config is not None:
             model = compile_model(model, compile_config)
+        if isinstance(model_wrapper, DDPManager):
+            setattr(model.module, "_pre_shard_hf_state_dict_keys", pre_shard_hf_state_dict_keys)
+        else:
+            setattr(model, "_pre_shard_hf_state_dict_keys", pre_shard_hf_state_dict_keys)

     # Load the checkpoint if needed and return
     # Weights need to be loaded for meta device models that were parallelized:
     # 1. When parallelize_fn was used (which will internally apply FSDP2/EP sharding)
     # 2. When FSDP2Manager.parallelize was used (but not MegatronFSDP which handles weights internally)
-    should_load_checkpoint = (
-        is_meta_device
-        and checkpointer is not None
-        and any(
-            [
-                parallelize_fn is not None and get_world_size_safe() > 1,
-                callable(getattr(model_wrapper, "parallelize", None)),
-            ]
-        )
+    should_load_checkpoint = is_meta_device and any(
+        [
+            parallelize_fn is not None and get_world_size_safe() > 1,
+            callable(getattr(model_wrapper, "parallelize", None)),
+        ]
     )
     if should_load_checkpoint:
         models_to_load = model.parts if hasattr(model, "parts") else [model]
@@ -640,7 +658,7 @@ def apply_model_infrastructure(
             mp,
             device,
             cache_dir,
-            model_name_or_path,
+            pretrained_model_name_or_path,
             lora_a_init,
             load_base_model=load_base_model,
         )
@@ -778,7 +796,6 @@ def from_pretrained(
         model_wrapper=None,
         autopipeline: AutoPipeline | None = None,
         parallelize_fn: Callable | None = None,
-        checkpointer: Optional[Checkpointer] = None,
         peft_config: Optional[dict] = None,
         fp8_config: Optional["FP8Config"] = None,
         qat_quantizer: Optional[Union["Int4WeightOnlyQATQuantizer", "Int8DynActInt4WeightQATQuantizer"]] = None,
@@ -824,9 +841,6 @@ def from_pretrained(
                 pipeline stages. Default: None.
             parallelize_fn (Callable | None, optional): Custom function to apply
                 parallelization (EP + FSDP2). Default: None.
-            checkpointer (Checkpointer, optional): Checkpointer instance for loading weights
-                and enabling save_pretrained() functionality. Required for weight loading
-                and checkpoint management.
             peft_config (dict | None, optional): PEFT/LoRA configuration dictionary.
                 If provided, LoRA adapters will be applied to the model. Default: None.
             fp8_config (FP8Config | None, optional): FP8 quantization configuration.
@@ -882,7 +896,6 @@ def _retry(**override):
                 fp8_config=fp8_config,
                 qat_quantizer=qat_quantizer,
                 loss_fn=loss_fn,
-                checkpointer=checkpointer,
                 compile_config=compile_config,
                 model_wrapper=model_wrapper,
                 **kwargs,
@@ -899,11 +912,10 @@ def _retry(**override):
        device = torch.cuda.current_device()

        # Neither of these parallelization methods support meta device initialization
-        # Also require checkpointer for meta device init, as we need it to load weights
        is_meta_device = (
            not isinstance(model_wrapper, (MegatronFSDPManager, DDPManager))
            and not force_hf
-            and checkpointer is not None
+            and get_world_size_safe() > 1
        )
        init_ctx = ContextManagers([no_init_weights(), init_empty_weights()]) if is_meta_device else nullcontext()

@@ -948,10 +960,10 @@ def _retry(**override):

        model = apply_model_infrastructure(
            model=model,
+            pretrained_model_name_or_path=pretrained_model_name_or_path,
            is_hf_model=is_hf_model,
            cp_size=cp_size,
            tp_size=tp_size,
-            checkpointer=checkpointer,
            peft_config=peft_config,
            quantization_config=quantization_config,
            fp8_config=fp8_config,
@@ -963,7 +975,6 @@ def _retry(**override):
            is_meta_device=is_meta_device,
            device=device,
            compile_config=compile_config,
-            model_name_or_path=pretrained_model_name_or_path,
            load_base_model=True,
            cache_dir=kwargs.get("cache_dir", TRANSFORMERS_CACHE),
        )
@@ -990,7 +1001,6 @@ def from_config(
        qat_quantizer: Optional[Union["Int4WeightOnlyQATQuantizer", "Int8DynActInt4WeightQATQuantizer"]] = None,
        loss_fn: Optional[Callable] = None,
        compile_config: Optional["CompileConfig"] = None,
-        checkpointer: Optional[Checkpointer] = None,
        **kwargs,
    ) -> PreTrainedModel:
        """
@@ -1051,9 +1061,6 @@ def from_config(
                it will be replaced with MaskedCrossEntropy. This is passed to AutoPipeline. Default: None.
            compile_config (CompileConfig | None, optional): Configuration for torch.compile.
                If provided, the model will be compiled for improved performance. Default: None.
-            checkpointer (Checkpointer, optional): Checkpointer instance for checkpoint
-                management and enabling save_pretrained() functionality. Required for
-                proper checkpoint handling.
            **kwargs:
                Additional keyword arguments. Notable ones include:
                - tp_size (int): Tensor parallelism size. Default: 1.
@@ -1096,7 +1103,6 @@ def _retry(**override):
                qat_quantizer=qat_quantizer,
                loss_fn=loss_fn,
                compile_config=compile_config,
-                checkpointer=checkpointer,
                **kwargs,
            )

@@ -1117,11 +1123,10 @@ def _retry(**override):
        device = torch.cuda.current_device()

        # Neither of these parallelization methods support meta device initialization
-        # Also require checkpointer for meta device init, as we need it to load weights
        is_meta_device = (
            not isinstance(model_wrapper, (MegatronFSDPManager, DDPManager))
            and not force_hf
-            and checkpointer is not None
+            and get_world_size_safe() > 1
        )
        init_ctx = ContextManagers([no_init_weights(), init_empty_weights()]) if is_meta_device else nullcontext()

@@ -1162,7 +1167,6 @@ def _retry(**override):
            is_hf_model=is_hf_model,
            cp_size=cp_size,
            tp_size=tp_size,
-            checkpointer=checkpointer,
            peft_config=peft_config,
            quantization_config=quantization_config,
            fp8_config=fp8_config,
@@ -1174,7 +1178,7 @@ def _retry(**override):
            is_meta_device=is_meta_device,
            device=device,
            compile_config=compile_config,
-            model_name_or_path=getattr(config, "name_or_path"),
+            pretrained_model_name_or_path=getattr(config, "name_or_path"),
            load_base_model=False,
            cache_dir=kwargs.get("cache_dir", TRANSFORMERS_CACHE),
        )
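
The pre-shard key bookkeeping added above can be illustrated with a small, self-contained toy using plain PyTorch (none of the repository's helpers are used, and the sharding/compile wrappers are assumed away): keys are captured from the unsharded module, attached as _pre_shard_hf_state_dict_keys, and later used to drop keys that should not appear in the consolidated index.

# Toy illustration of the _pre_shard_hf_state_dict_keys mechanism added above.
import torch.nn as nn

model = nn.Linear(4, 4)
pre_shard_hf_state_dict_keys = list(model.state_dict().keys())  # ["weight", "bias"]
setattr(model, "_pre_shard_hf_state_dict_keys", pre_shard_hf_state_dict_keys)

# Later, when building a consolidated index, anything not in the captured key list
# (e.g. a non-persistent buffer name) is scheduled for removal:
index_keys = {"weight", "bias", "rotary_emb.inv_freq"}  # hypothetical index contents
keys_to_remove = index_keys - set(model._pre_shard_hf_state_dict_keys)
print(sorted(keys_to_remove))  # ['rotary_emb.inv_freq']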

nemo_automodel/components/checkpoint/checkpointing.py

Lines changed: 8 additions & 2 deletions
@@ -87,7 +87,9 @@ class CheckpointingConfig:
     model_repo_id: str
     save_consolidated: bool
     is_peft: bool
-    model_state_dict_keys: list[str] = None  # copy of the model state dict keys before any parallelization
+    model_state_dict_keys: list[str] = (
+        None  # copy of the model state dict keys before any parallelization. Kept for BW compatibility.
+    )
     is_async: bool = False
     dequantize_base_checkpoint: bool | None = None
     original_model_root_dir: str | None = None
@@ -587,7 +589,11 @@ def _maybe_build_consolidated_index(
        # some HF models like Moonlight-16B have non-persistent buffers in the base checkpoint
        # however, HF initializes buffers with persistent=False, so we need to make sure these
        # buffer keys are not saved during checkpointing
-        keys_to_remove = list(set(fqn_to_file_index_mapping.keys()) - set(self.config.model_state_dict_keys))
+        # The `_pre_shard_hf_state_dict_keys` attribute is set in the `apply_model_infrastructure` in auto_model.py
+        pre_shard_hf_state_dict_keys = (
+            getattr(model, "_pre_shard_hf_state_dict_keys", None) or self.config.model_state_dict_keys
+        )
+        keys_to_remove = list(set(fqn_to_file_index_mapping.keys()) - set(pre_shard_hf_state_dict_keys))
        if model_state.is_tied_lm_head:
            keys_to_remove.append(model_state.lm_head_param_name)
        for key in keys_to_remove:
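
The backward-compatible fallback in the hunk above can be sketched in isolation (toy stand-ins for the config and model objects, not the repository's classes): the per-model attribute wins when present, otherwise the legacy config field is used.

# Toy sketch of the fallback order introduced above (stand-in objects, not the real classes).
class _Cfg:
    model_state_dict_keys = ["weight", "bias"]  # legacy copy kept for BW compatibility

class _ModelWithoutAttr:
    pass

class _ModelWithAttr:
    _pre_shard_hf_state_dict_keys = ["weight"]

cfg = _Cfg()
for m in (_ModelWithAttr(), _ModelWithoutAttr()):
    keys = getattr(m, "_pre_shard_hf_state_dict_keys", None) or cfg.model_state_dict_keys
    print(type(m).__name__, keys)
# _ModelWithAttr ['weight']              <- attribute set during apply_model_infrastructure wins
# _ModelWithoutAttr ['weight', 'bias']   <- falls back to the legacy config field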

nemo_automodel/recipes/llm/kd.py

Lines changed: 0 additions & 29 deletions
@@ -43,7 +43,6 @@
 import torch
 import wandb
 from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp
-from transformers.utils import TRANSFORMERS_CACHE

 from nemo_automodel._transformers.auto_tokenizer import NeMoAutoTokenizer
 from nemo_automodel.components.config._arg_parser import parse_args_and_load_config
@@ -77,10 +76,6 @@ def _build_teacher_model(
     cp_size=1,
     parallelize_fn=None,
     device=None,
-    dp_rank=0,
-    tp_rank=0,
-    pp_rank=0,
-    moe_mesh=None,
 ):
     """Build and initialize the teacher model for knowledge distillation.

@@ -104,37 +99,17 @@ def _build_teacher_model(
     The `offload_teacher_model` config option is not supported with this approach.
     Device placement is handled internally by NeMoAutoModelForCausalLM infrastructure.
     """
-    from nemo_automodel.components.checkpoint.checkpointing import Checkpointer, CheckpointingConfig

     assert cfg_teacher is not None, "`teacher_model` section missing from YAML config"
     logger.info("Instantiating teacher model")

-    # Create a simple checkpointer for the teacher (just for weight loading)
-    teacher_checkpointer = Checkpointer(
-        CheckpointingConfig(
-            model_repo_id=cfg_teacher.get("pretrained_model_name_or_path"),
-            model_cache_dir=cfg_teacher.get("cache_dir", TRANSFORMERS_CACHE),
-            # Dummy values
-            is_peft=False,
-            enabled=False,
-            checkpoint_dir="",
-            model_save_format="safetensors",
-            save_consolidated=False,
-        ),
-        dp_rank=dp_rank,
-        tp_rank=tp_rank,
-        pp_rank=pp_rank,
-        moe_mesh=moe_mesh,
-    )
-
     # Build teacher model using the same infrastructure as student
     # but without PEFT/FP8/QAT (teacher should be frozen in full precision)
     with ScopedRNG(seed=seed, ranked=True):
         kwargs: Dict[str, Any] = {
             "tp_size": tp_size,
             "cp_size": cp_size,
             "has_packed_sequence": has_packed_sequence,
-            "checkpointer": teacher_checkpointer,
             "model_wrapper": model_wrapper,
             "parallelize_fn": parallelize_fn,
         }
@@ -196,10 +171,6 @@ def setup(self):  # noqa: C901 – same complexity as parent
            cp_size=self.cfg.get("distributed.cp_size", 1),
            parallelize_fn=getattr(self.cfg.get("parallelizer", None), "instantiate", None),
            device=teacher_device,
-            dp_rank=self._get_dp_rank(include_cp=True),
-            tp_rank=self._get_tp_rank(),
-            pp_rank=self._get_pp_rank(),
-            moe_mesh=self.moe_mesh,
        )
        logger.info("Teacher Model: " + str(self.teacher_model))
        # KD

nemo_automodel/recipes/llm/train_ft.py

Lines changed: 1 addition & 4 deletions
@@ -133,7 +133,6 @@ def build_model_and_optimizer(
     cfg_peft,
     model_wrapper,
     seed,
-    checkpointer: Checkpointer,
     has_packed_sequence=False,
     tp_size=1,
     cp_size=1,
@@ -174,7 +173,6 @@ def build_model_and_optimizer(
        "has_packed_sequence": has_packed_sequence,
        "autopipeline": autopipeline,
        "parallelize_fn": parallelize_fn,
-        "checkpointer": checkpointer,
        "peft_config": cfg_peft,
        "model_wrapper": model_wrapper,
        "loss_fn": loss_fn,
@@ -214,7 +212,7 @@ def build_model_and_optimizer(
        is_hf_model=False,
        is_meta_device=False,
        device=torch.cuda.current_device(),
-        model_name_or_path=None,
+        pretrained_model_name_or_path=None,
        load_base_model=False,
        cache_dir=TRANSFORMERS_CACHE,
        **kwargs,
@@ -923,7 +921,6 @@ def setup(self):
            autopipeline=autopipeline,
            loss_fn=self.loss_fn,
            parallelize_fn=parallelize_fn,
-            checkpointer=self.checkpointer,
        )

        if isinstance(model, AutoPipeline):

nemo_automodel/recipes/llm/train_seq_cls.py

Lines changed: 0 additions & 1 deletion
@@ -118,7 +118,6 @@ def setup(self):
            autopipeline=None,
            loss_fn=self.loss_fn,
            parallelize_fn=None,
-            checkpointer=self.checkpointer,
            unfreeze_modules=["classifier"] if self.peft_config is not None else None,
        )
