Skip to content

Commit a3589a9

Browse files
committed
refactor
Signed-off-by: adil-a <adil.asif2000@hotmail.com>
1 parent 801f63e commit a3589a9

File tree

7 files changed

+128
-481
lines changed

7 files changed

+128
-481
lines changed

nemo_automodel/_transformers/auto_model.py

Lines changed: 31 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from nemo_automodel.components._peft.lora import apply_lora_to_linear_modules
4747
from nemo_automodel.components.checkpoint.checkpointing import (
4848
Checkpointer,
49+
CheckpointingConfig,
4950
_maybe_adapt_state_dict_to_hf,
5051
)
5152
from nemo_automodel.components.distributed.ddp import DDPManager
@@ -524,7 +525,6 @@ def apply_model_infrastructure(
524525
is_hf_model,
525526
is_meta_device,
526527
device,
527-
checkpointer,
528528
model_wrapper=None,
529529
tp_size=1,
530530
cp_size=1,
@@ -539,6 +539,7 @@ def apply_model_infrastructure(
539539
model_name_or_path=None,
540540
load_base_model=False,
541541
cache_dir=None,
542+
pretrained_model_name_or_path="",
542543
**_kwargs,
543544
):
544545
"""Apply sharding, PEFT, quantization, and checkpoint loading to a model.
@@ -558,7 +559,6 @@ def apply_model_infrastructure(
558559
is_hf_model: Whether this is an HF model (vs custom implementation)
559560
is_meta_device: Whether model was initialized on meta device
560561
device: Target device for model
561-
checkpointer: Checkpointer instance for weight loading
562562
model_wrapper: Model wrapper (FSDP2Manager, DDPManager, etc.). Default: None
563563
tp_size: Tensor parallelism size. Default: 1
564564
cp_size: Context parallelism size. Default: 1
@@ -580,6 +580,24 @@ def apply_model_infrastructure(
580580
"""
581581
_verify_sdpa_support(model, is_hf_model, cp_size)
582582

583+
# Create a dummy checkpointer. We can pass in dummy values here since we are only loading the base weights.
584+
ckpt_config = CheckpointingConfig(
585+
enabled=True,
586+
checkpoint_dir="",
587+
model_save_format="safetensors",
588+
model_cache_dir=cache_dir,
589+
model_repo_id=pretrained_model_name_or_path,
590+
save_consolidated=True,
591+
is_peft=peft_config is not None,
592+
)
593+
checkpointer = Checkpointer(
594+
ckpt_config,
595+
0,
596+
0,
597+
0,
598+
getattr(model_wrapper, "moe_mesh", None) if model_wrapper else None,
599+
)
600+
583601
# Handle checkpointer config updates (the checkpointer is always constructed above)
584602
dequantize_base_checkpoint = False
585603
if checkpointer is not None:
@@ -600,10 +618,9 @@ def apply_model_infrastructure(
600618
)
601619

602620
# hold a list copy of the model state dict keys before any parallelization
603-
if checkpointer is not None:
604-
checkpointer.config.model_state_dict_keys = list(
605-
_maybe_adapt_state_dict_to_hf(model, model.state_dict(), quantization=dequantize_base_checkpoint).keys()
606-
)
621+
checkpointer.config.model_state_dict_keys = list(
622+
_maybe_adapt_state_dict_to_hf(model, model.state_dict(), quantization=dequantize_base_checkpoint).keys()
623+
)
607624

608625
# Loss function check
609626
if not _supports_logits_to_keep(model) and not isinstance(loss_fn, MaskedCrossEntropy):
@@ -622,15 +639,11 @@ def apply_model_infrastructure(
622639
# Weights need to be loaded for meta device models that were parallelized:
623640
# 1. When parallelize_fn was used (which will internally apply FSDP2/EP sharding)
624641
# 2. When FSDP2Manager.parallelize was used (but not MegatronFSDP which handles weights internally)
625-
should_load_checkpoint = (
626-
is_meta_device
627-
and checkpointer is not None
628-
and any(
629-
[
630-
parallelize_fn is not None and get_world_size_safe() > 1,
631-
callable(getattr(model_wrapper, "parallelize", None)),
632-
]
633-
)
642+
should_load_checkpoint = is_meta_device and any(
643+
[
644+
parallelize_fn is not None and get_world_size_safe() > 1,
645+
callable(getattr(model_wrapper, "parallelize", None)),
646+
]
634647
)
635648
if should_load_checkpoint:
636649
models_to_load = model.parts if hasattr(model, "parts") else [model]
@@ -778,7 +791,6 @@ def from_pretrained(
778791
model_wrapper=None,
779792
autopipeline: AutoPipeline | None = None,
780793
parallelize_fn: Callable | None = None,
781-
checkpointer: Optional[Checkpointer] = None,
782794
peft_config: Optional[dict] = None,
783795
fp8_config: Optional["FP8Config"] = None,
784796
qat_quantizer: Optional[Union["Int4WeightOnlyQATQuantizer", "Int8DynActInt4WeightQATQuantizer"]] = None,
@@ -824,9 +836,6 @@ def from_pretrained(
824836
pipeline stages. Default: None.
825837
parallelize_fn (Callable | None, optional): Custom function to apply
826838
parallelization (EP + FSDP2). Default: None.
827-
checkpointer (Checkpointer, optional): Checkpointer instance for loading weights
828-
and enabling save_pretrained() functionality. Required for weight loading
829-
and checkpoint management.
830839
peft_config (dict | None, optional): PEFT/LoRA configuration dictionary.
831840
If provided, LoRA adapters will be applied to the model. Default: None.
832841
fp8_config (FP8Config | None, optional): FP8 quantization configuration.
@@ -882,7 +891,6 @@ def _retry(**override):
882891
fp8_config=fp8_config,
883892
qat_quantizer=qat_quantizer,
884893
loss_fn=loss_fn,
885-
checkpointer=checkpointer,
886894
compile_config=compile_config,
887895
model_wrapper=model_wrapper,
888896
**kwargs,
@@ -899,12 +907,7 @@ def _retry(**override):
899907
device = torch.cuda.current_device()
900908

901909
# Neither of these parallelization methods supports meta device initialization
902-
# Also require checkpointer for meta device init, as we need it to load weights
903-
is_meta_device = (
904-
not isinstance(model_wrapper, (MegatronFSDPManager, DDPManager))
905-
and not force_hf
906-
and checkpointer is not None
907-
)
910+
is_meta_device = not isinstance(model_wrapper, (MegatronFSDPManager, DDPManager)) and not force_hf
908911
init_ctx = ContextManagers([no_init_weights(), init_empty_weights()]) if is_meta_device else nullcontext()
909912

910913
try:
@@ -948,10 +951,10 @@ def _retry(**override):
948951

949952
model = apply_model_infrastructure(
950953
model=model,
954+
pretrained_model_name_or_path=pretrained_model_name_or_path,
951955
is_hf_model=is_hf_model,
952956
cp_size=cp_size,
953957
tp_size=tp_size,
954-
checkpointer=checkpointer,
955958
peft_config=peft_config,
956959
quantization_config=quantization_config,
957960
fp8_config=fp8_config,
@@ -990,7 +993,6 @@ def from_config(
990993
qat_quantizer: Optional[Union["Int4WeightOnlyQATQuantizer", "Int8DynActInt4WeightQATQuantizer"]] = None,
991994
loss_fn: Optional[Callable] = None,
992995
compile_config: Optional["CompileConfig"] = None,
993-
checkpointer: Optional[Checkpointer] = None,
994996
**kwargs,
995997
) -> PreTrainedModel:
996998
"""
@@ -1051,9 +1053,6 @@ def from_config(
10511053
it will be replaced with MaskedCrossEntropy. This is passed to AutoPipeline. Default: None.
10521054
compile_config (CompileConfig | None, optional): Configuration for torch.compile.
10531055
If provided, the model will be compiled for improved performance. Default: None.
1054-
checkpointer (Checkpointer, optional): Checkpointer instance for checkpoint
1055-
management and enabling save_pretrained() functionality. Required for
1056-
proper checkpoint handling.
10571056
**kwargs:
10581057
Additional keyword arguments. Notable ones include:
10591058
- tp_size (int): Tensor parallelism size. Default: 1.
@@ -1096,7 +1095,6 @@ def _retry(**override):
10961095
qat_quantizer=qat_quantizer,
10971096
loss_fn=loss_fn,
10981097
compile_config=compile_config,
1099-
checkpointer=checkpointer,
11001098
**kwargs,
11011099
)
11021100

@@ -1117,12 +1115,7 @@ def _retry(**override):
11171115
device = torch.cuda.current_device()
11181116

11191117
# Neither of these parallelization methods supports meta device initialization
1120-
# Also require checkpointer for meta device init, as we need it to load weights
1121-
is_meta_device = (
1122-
not isinstance(model_wrapper, (MegatronFSDPManager, DDPManager))
1123-
and not force_hf
1124-
and checkpointer is not None
1125-
)
1118+
is_meta_device = not isinstance(model_wrapper, (MegatronFSDPManager, DDPManager)) and not force_hf
11261119
init_ctx = ContextManagers([no_init_weights(), init_empty_weights()]) if is_meta_device else nullcontext()
11271120

11281121
try:
@@ -1162,7 +1155,6 @@ def _retry(**override):
11621155
is_hf_model=is_hf_model,
11631156
cp_size=cp_size,
11641157
tp_size=tp_size,
1165-
checkpointer=checkpointer,
11661158
peft_config=peft_config,
11671159
quantization_config=quantization_config,
11681160
fp8_config=fp8_config,

nemo_automodel/recipes/llm/kd.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
import torch
4444
import wandb
4545
from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp
46-
from transformers.utils import TRANSFORMERS_CACHE
4746

4847
from nemo_automodel._transformers.auto_tokenizer import NeMoAutoTokenizer
4948
from nemo_automodel.components.config._arg_parser import parse_args_and_load_config
@@ -77,10 +76,6 @@ def _build_teacher_model(
7776
cp_size=1,
7877
parallelize_fn=None,
7978
device=None,
80-
dp_rank=0,
81-
tp_rank=0,
82-
pp_rank=0,
83-
moe_mesh=None,
8479
):
8580
"""Build and initialize the teacher model for knowledge distillation.
8681
@@ -104,37 +99,17 @@ def _build_teacher_model(
10499
The `offload_teacher_model` config option is not supported with this approach.
105100
Device placement is handled internally by NeMoAutoModelForCausalLM infrastructure.
106101
"""
107-
from nemo_automodel.components.checkpoint.checkpointing import Checkpointer, CheckpointingConfig
108102

109103
assert cfg_teacher is not None, "`teacher_model` section missing from YAML config"
110104
logger.info("Instantiating teacher model")
111105

112-
# Create a simple checkpointer for the teacher (just for weight loading)
113-
teacher_checkpointer = Checkpointer(
114-
CheckpointingConfig(
115-
model_repo_id=cfg_teacher.get("pretrained_model_name_or_path"),
116-
model_cache_dir=cfg_teacher.get("cache_dir", TRANSFORMERS_CACHE),
117-
# Dummy values
118-
is_peft=False,
119-
enabled=False,
120-
checkpoint_dir="",
121-
model_save_format="safetensors",
122-
save_consolidated=False,
123-
),
124-
dp_rank=dp_rank,
125-
tp_rank=tp_rank,
126-
pp_rank=pp_rank,
127-
moe_mesh=moe_mesh,
128-
)
129-
130106
# Build teacher model using the same infrastructure as student
131107
# but without PEFT/FP8/QAT (teacher should be frozen in full precision)
132108
with ScopedRNG(seed=seed, ranked=True):
133109
kwargs: Dict[str, Any] = {
134110
"tp_size": tp_size,
135111
"cp_size": cp_size,
136112
"has_packed_sequence": has_packed_sequence,
137-
"checkpointer": teacher_checkpointer,
138113
"model_wrapper": model_wrapper,
139114
"parallelize_fn": parallelize_fn,
140115
}
@@ -196,10 +171,6 @@ def setup(self): # noqa: C901 – same complexity as parent
196171
cp_size=self.cfg.get("distributed.cp_size", 1),
197172
parallelize_fn=getattr(self.cfg.get("parallelizer", None), "instantiate", None),
198173
device=teacher_device,
199-
dp_rank=self._get_dp_rank(include_cp=True),
200-
tp_rank=self._get_tp_rank(),
201-
pp_rank=self._get_pp_rank(),
202-
moe_mesh=self.moe_mesh,
203174
)
204175
logger.info("Teacher Model: " + str(self.teacher_model))
205176
# KD

nemo_automodel/recipes/llm/train_ft.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,6 @@ def build_model_and_optimizer(
133133
cfg_peft,
134134
model_wrapper,
135135
seed,
136-
checkpointer: Checkpointer,
137136
has_packed_sequence=False,
138137
tp_size=1,
139138
cp_size=1,
@@ -174,7 +173,6 @@ def build_model_and_optimizer(
174173
"has_packed_sequence": has_packed_sequence,
175174
"autopipeline": autopipeline,
176175
"parallelize_fn": parallelize_fn,
177-
"checkpointer": checkpointer,
178176
"peft_config": cfg_peft,
179177
"model_wrapper": model_wrapper,
180178
"loss_fn": loss_fn,
@@ -923,7 +921,6 @@ def setup(self):
923921
autopipeline=autopipeline,
924922
loss_fn=self.loss_fn,
925923
parallelize_fn=parallelize_fn,
926-
checkpointer=self.checkpointer,
927924
)
928925

929926
if isinstance(model, AutoPipeline):

nemo_automodel/recipes/llm/train_seq_cls.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,6 @@ def setup(self):
118118
autopipeline=None,
119119
loss_fn=self.loss_fn,
120120
parallelize_fn=None,
121-
checkpointer=self.checkpointer,
122121
unfreeze_modules=["classifier"] if self.peft_config is not None else None,
123122
)
124123

nemo_automodel/recipes/vlm/finetune.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,6 @@ def build_model_and_optimizer(
123123
cfg_peft,
124124
model_wrapper,
125125
seed,
126-
checkpointer: Checkpointer,
127126
tp_size=1,
128127
cp_size=1,
129128
freeze_embeddings=True,
@@ -144,7 +143,6 @@ def build_model_and_optimizer(
144143
"tp_size": tp_size,
145144
"cp_size": cp_size,
146145
"parallelize_fn": parallelize_fn,
147-
"checkpointer": checkpointer,
148146
"peft_config": cfg_peft,
149147
"model_wrapper": model_wrapper,
150148
"loss_fn": loss_fn,
@@ -653,7 +651,6 @@ def setup(self):
653651
cfg_compile=self.cfg.get("compile", None),
654652
loss_fn=self.loss_fn,
655653
parallelize_fn=parallelize_fn,
656-
checkpointer=self.checkpointer,
657654
autopipeline=autopipeline,
658655
)
659656

0 commit comments

Comments
 (0)