
Commit ce553b2

feat: Creating a centralized script for latent creation (#101)
* Add support for a uniform data processor
* Adding processing logic for video models
* Linting fixes
* Fix secrets error
* Formatting fix

Signed-off-by: Pranav Prashant Thombre <pthombre@nvidia.com>
1 parent cabae30 commit ce553b2

File tree

17 files changed: +2749, -338 lines


.github/workflows/config/.secrets.baseline

Lines changed: 2 additions & 6 deletions
@@ -90,10 +90,6 @@
     {
       "path": "detect_secrets.filters.allowlist.is_line_allowlisted"
     },
-    {
-      "path": "detect_secrets.filters.common.is_baseline_file",
-      "filename": ".github/workflows/config/.secrets.baseline"
-    },
     {
       "path": "detect_secrets.filters.common.is_ignored_due_to_verification_policies",
       "min_level": 2
@@ -139,10 +135,10 @@
         "filename": "examples/automodel/pretrain/cicd/wan21_cicd_nightly_video.yaml",
         "hashed_secret": "c70f071570ba65f9c4079d6051e955ff4f802eea",
         "is_verified": false,
-        "line_number": 67,
+        "line_number": 72,
         "is_secret": false
       }
     ]
   },
-  "generated_at": "2026-01-30T18:50:34Z"
+  "generated_at": "2026-02-12T07:45:24Z"
 }

dfm/src/automodel/recipes/train.py

Lines changed: 98 additions & 17 deletions
@@ -25,6 +25,7 @@
 from nemo_automodel.components.checkpoint.checkpointing import Checkpointer, CheckpointingConfig
 from nemo_automodel.components.loggers.log_utils import setup_logging
 from nemo_automodel.components.loggers.wandb_utils import suppress_wandb_log_messages
+from nemo_automodel.components.optim.scheduler import OptimizerParamScheduler
 from nemo_automodel.components.training.rng import StatefulRNG
 from nemo_automodel.components.training.step_scheduler import StepScheduler
 from nemo_automodel.recipes.base_recipe import BaseRecipe
@@ -195,20 +196,93 @@ def build_model_and_optimizer(


 def build_lr_scheduler(
+    cfg,
     optimizer: torch.optim.Optimizer,
-    *,
-    num_epochs: int,
-    steps_per_epoch: int,
-    eta_min: float = 1e-6,
-) -> torch.optim.lr_scheduler.CosineAnnealingLR:
-    """Build the cosine annealing learning rate scheduler."""
-
-    total_steps = max(1, num_epochs * max(1, steps_per_epoch))
-    logging.info(f"[INFO] Scheduler configured for {total_steps} total steps")
-    return torch.optim.lr_scheduler.CosineAnnealingLR(
-        optimizer,
-        T_max=total_steps,
-        eta_min=eta_min,
+    total_steps: int,
+) -> Optional[OptimizerParamScheduler]:
+    """Build the learning rate scheduler.
+
+    Args:
+        cfg: Configuration for the OptimizerParamScheduler from YAML. If None, no scheduler
+            is created and constant LR is used. Supports:
+            - lr_decay_style: constant, linear, cosine, inverse-square-root, WSD
+            - lr_warmup_steps: Number of warmup steps (or fraction < 1 for percentage)
+            - min_lr: Minimum LR after decay
+            - init_lr: Initial LR for warmup (defaults to 10% of max_lr if warmup enabled)
+            - wd_incr_style: constant, linear, cosine (for weight decay scheduling)
+            - wsd_decay_steps: WSD-specific decay steps
+            - lr_wsd_decay_style: WSD-specific decay style (cosine, linear, exponential, minus_sqrt)
+        optimizer: The optimizer to be scheduled.
+        total_steps: Total number of optimizer steps for the training run.
+
+    Returns:
+        OptimizerParamScheduler instance, or None if cfg is None.
+    """
+    if cfg is None:
+        return None
+
+    user_cfg = cfg.to_dict() if hasattr(cfg, "to_dict") else dict(cfg)
+
+    base_lr = optimizer.param_groups[0]["lr"]
+    base_wd = optimizer.param_groups[0].get("weight_decay", 0.0)
+
+    # Compute defaults from runtime values
+    default_cfg: Dict[str, Any] = {
+        "optimizer": optimizer,
+        "lr_warmup_steps": min(1000, total_steps // 10),
+        "lr_decay_steps": total_steps,
+        "lr_decay_style": "cosine",
+        "init_lr": base_lr * 0.1,
+        "max_lr": base_lr,
+        "min_lr": base_lr * 0.01,
+        "start_wd": base_wd,
+        "end_wd": base_wd,
+        "wd_incr_steps": total_steps,
+        "wd_incr_style": "constant",
+    }
+
+    # Handle warmup as fraction before merging
+    if "lr_warmup_steps" in user_cfg:
+        warmup = user_cfg["lr_warmup_steps"]
+        if isinstance(warmup, float) and 0 < warmup < 1:
+            user_cfg["lr_warmup_steps"] = int(warmup * total_steps)
+
+    # WSD defaults if user specifies WSD style
+    if user_cfg.get("lr_decay_style") == "WSD":
+        default_cfg["wsd_decay_steps"] = max(1, total_steps // 10)
+        default_cfg["lr_wsd_decay_style"] = "cosine"
+
+    # User config overrides defaults
+    default_cfg.update(user_cfg)
+
+    # If user disabled warmup, set init_lr = max_lr
+    if default_cfg["lr_warmup_steps"] == 0:
+        default_cfg["init_lr"] = default_cfg["max_lr"]
+
+    # Ensure warmup < decay steps
+    if default_cfg["lr_warmup_steps"] >= default_cfg["lr_decay_steps"]:
+        default_cfg["lr_warmup_steps"] = max(0, default_cfg["lr_decay_steps"] - 1)
+
+    logging.info(
+        f"[INFO] LR Scheduler: style={default_cfg['lr_decay_style']}, "
+        f"warmup={default_cfg['lr_warmup_steps']}, total={default_cfg['lr_decay_steps']}, "
+        f"max_lr={default_cfg['max_lr']}, min_lr={default_cfg['min_lr']}"
+    )
+
+    return OptimizerParamScheduler(
+        optimizer=default_cfg["optimizer"],
+        init_lr=default_cfg["init_lr"],
+        max_lr=default_cfg["max_lr"],
+        min_lr=default_cfg["min_lr"],
+        lr_warmup_steps=default_cfg["lr_warmup_steps"],
+        lr_decay_steps=default_cfg["lr_decay_steps"],
+        lr_decay_style=default_cfg["lr_decay_style"],
+        start_wd=default_cfg["start_wd"],
+        end_wd=default_cfg["end_wd"],
+        wd_incr_steps=default_cfg["wd_incr_steps"],
+        wd_incr_style=default_cfg["wd_incr_style"],
+        wsd_decay_steps=default_cfg.get("wsd_decay_steps"),
+        lr_wsd_decay_style=default_cfg.get("lr_wsd_decay_style"),
     )

@@ -390,11 +464,17 @@ def setup(self):
         grad_acc_steps = max(1, self.global_batch_size // max(1, self.local_batch_size * self.dp_size))
         self.steps_per_epoch = ceil(self.raw_steps_per_epoch / grad_acc_steps)

-        self.lr_scheduler = build_lr_scheduler(
+        # Calculate total optimizer steps for LR scheduler
+        total_steps = self.num_epochs * self.steps_per_epoch
+
+        # Build LR scheduler (returns None if lr_scheduler not in config)
+        # Wrap in list for compatibility with checkpointing (OptimizerState expects list)
+        lr_scheduler = build_lr_scheduler(
+            self.cfg.get("lr_scheduler", None),
             self.optimizer,
-            num_epochs=self.num_epochs,
-            steps_per_epoch=self.steps_per_epoch,
+            total_steps,
         )
+        self.lr_scheduler = [lr_scheduler] if lr_scheduler is not None else None

         self.global_step = 0
         self.start_epoch = 0
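
For context on how the new factory and the lr_scheduler config section fit together, here is a minimal usage sketch. It is not part of the commit: the toy model, optimizer settings, step count, and the plain dict standing in for the parsed YAML section are illustrative assumptions, and the import path for build_lr_scheduler (the function added in train.py above) is a guess.

# Minimal usage sketch; model, optimizer, lr_cfg, and total_steps are placeholders,
# and the import path below is hypothetical (build_lr_scheduler is the function
# added in dfm/src/automodel/recipes/train.py in this commit).
import torch

from automodel.recipes.train import build_lr_scheduler  # hypothetical module path

model = torch.nn.Linear(16, 16)  # toy model
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)

# A plain dict stands in for the parsed YAML lr_scheduler section; the factory
# accepts anything dict-like and only calls .to_dict() when it is available.
lr_cfg = {
    "lr_decay_style": "cosine",
    "lr_warmup_steps": 0.05,  # fractions below 1 are converted to a step count
    "min_lr": 3e-6,
}

total_steps = 10_000  # num_epochs * steps_per_epoch in the recipe
scheduler = build_lr_scheduler(lr_cfg, optimizer, total_steps)

# Mirrors the training-loop change below: advance once per optimizer step.
optimizer.step()
if scheduler is not None:
    scheduler.step(1)
print(optimizer.param_groups[0]["lr"])  # warmup ramps up from 10% of the base LR by default

Passing None instead of lr_cfg exercises the fall-back path: the factory returns None, self.lr_scheduler stays None, and the loop below keeps the optimizer's constant learning rate.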
@@ -490,7 +570,8 @@ def run_train_validation_loop(self):
                     grad_norm = float(grad_norm) if torch.is_tensor(grad_norm) else grad_norm

                 self.optimizer.step()
-                self.lr_scheduler.step()
+                if self.lr_scheduler is not None:
+                    self.lr_scheduler[0].step(1)

                 group_loss_mean = float(sum(micro_losses) / len(micro_losses))
                 epoch_loss += group_loss_mean
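
One detail worth noting in the loop above: the schedule advances once per optimizer step, after gradient accumulation, which is why setup() sizes it with total_steps = num_epochs * steps_per_epoch rather than the raw micro-batch count. The sketch below illustrates that ordering only; the toy model, random micro-batches, and loss are placeholders, not recipe APIs.

import torch

# Schematic ordering of one gradient-accumulation group; everything here is a
# stand-in except the step pattern itself (optimizer.step, then scheduler step(1)).
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
lr_scheduler = None  # or [build_lr_scheduler(...)] as wired up in setup() above

micro_batches = [torch.randn(4, 8) for _ in range(4)]  # one accumulation group
for micro_batch in micro_batches:
    loss = model(micro_batch).pow(2).mean()
    (loss / len(micro_batches)).backward()  # accumulate scaled gradients

optimizer.step()      # one optimizer update per accumulation group
optimizer.zero_grad()
if lr_scheduler is not None:
    lr_scheduler[0].step(1)  # one schedule increment per optimizer step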
