
Commit b48eb36

yurekamiclaude committed
Fix W&B step mismatch by consolidating log calls (#15204)
The W&B _step counter was showing approximately 2x the actual global_step because `log('global_step')` and `log('step')` were called separately, causing WandbLogger's internal _step counter to increment twice per training step. This fix consolidates both metrics into a single `log_dict()` call in all three strategy files, ensuring the W&B step aligns correctly with trainer.global_step for accurate cross-run comparisons.

Files modified:
- nemo/lightning/pytorch/strategies/megatron_strategy.py
- nemo/lightning/pytorch/strategies/fsdp_strategy.py
- nemo/lightning/pytorch/strategies/fsdp2_strategy.py

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent 1a3c291 commit b48eb36

File tree: 3 files changed (+15, -24 lines)
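For context on the mechanism the commit message describes: when no explicit step is passed, wandb advances its internal step on each logging call, so two separate per-step calls drift ahead of trainer.global_step. Below is a minimal sketch using the raw wandb API (simplified relative to WandbLogger's actual bookkeeping; the project name is made up, and offline mode is used so no login is required):

```python
# Simplified sketch of why separate log calls inflate the W&B step counter.
# Assumes the `wandb` package is installed; "step-demo" is a made-up project name.
import wandb

run = wandb.init(project="step-demo", mode="offline")  # offline: no account needed

# Two separate calls per training step: wandb's internal step advances twice.
wandb.log({"global_step": 0})              # recorded at wandb step 0
wandb.log({"step": 0})                     # recorded at wandb step 1 -- already ahead

# One consolidated call: both metrics share a single wandb step.
wandb.log({"global_step": 1, "step": 1})   # recorded together at wandb step 2

run.finish()
```

WandbLogger adds its own bookkeeping on top of this, but the net effect of the two-call pattern was the roughly 2x step inflation described in the commit message.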

nemo/lightning/pytorch/strategies/fsdp2_strategy.py

Lines changed: 5 additions & 8 deletions
@@ -401,19 +401,16 @@ def training_step(self, batch, batch_idx=None) -> STEP_OUTPUT:
         else:
             loss = self.lightning_module.training_step(batch, batch_idx)
 
-        self.lightning_module.log(
-            'global_step',
-            self.trainer.global_step,
+        # Log global_step and step together using log_dict to prevent W&B from
+        # incrementing its internal _step counter multiple times per training step.
+        # This fixes issue #15204 where W&B showed ~2x the actual global_step.
+        self.lightning_module.log_dict(
+            {'global_step': self.trainer.global_step, 'step': self.trainer.global_step},
             prog_bar=True,
             rank_zero_only=True,
             batch_size=1,
         )
 
-        self.lightning_module.log(
-            'step',
-            self.trainer.global_step,
-        )
-
         return loss
 
     @override

nemo/lightning/pytorch/strategies/fsdp_strategy.py

Lines changed: 5 additions & 8 deletions
@@ -181,18 +181,15 @@ def training_step(self, batch, batch_idx=None) -> STEP_OUTPUT:
         with self.precision_plugin.train_step_context():
             loss, reduced = self._step_proxy("training", batch, batch_idx)
 
-            self.lightning_module.log(
-                'global_step',
-                self.trainer.global_step,
+            # Log global_step and step together using log_dict to prevent W&B from
+            # incrementing its internal _step counter multiple times per training step.
+            # This fixes issue #15204 where W&B showed ~2x the actual global_step.
+            self.lightning_module.log_dict(
+                {'global_step': self.trainer.global_step, 'step': self.trainer.global_step},
                 prog_bar=True,
                 rank_zero_only=True,
                 batch_size=1,
             )
-
-            self.lightning_module.log(
-                'step',
-                self.trainer.global_step,
-            )
             self.lightning_module.log(
                 'reduced_train_loss', reduced['avg'], prog_bar=True, rank_zero_only=True, batch_size=1
             )

nemo/lightning/pytorch/strategies/megatron_strategy.py

Lines changed: 5 additions & 8 deletions
@@ -800,18 +800,15 @@ def training_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
             if (torch.is_tensor(self.trainer.global_step) and self.trainer.global_step.is_cuda)
             else torch.tensor(self.trainer.global_step, pin_memory=True).to("cuda", non_blocking=True)
         )
-        self.lightning_module.log(
-            "global_step",
-            global_step,
+        # Log global_step and step together using log_dict to prevent W&B from
+        # incrementing its internal _step counter multiple times per training step.
+        # This fixes issue #15204 where W&B showed ~2x the actual global_step.
+        self.lightning_module.log_dict(
+            {"global_step": global_step, "step": global_step},
             prog_bar=True,
             batch_size=1,
         )
 
-        self.lightning_module.log(
-            "step",
-            global_step,
-        )
-
         if self.log_memory_usage:
             # maximum GPU memory that has been managed by the caching allocator
             max_memory_reserved = torch.cuda.max_memory_reserved()
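For readers who want to try the consolidated pattern outside NeMo, here is a hedged, self-contained sketch using plain PyTorch Lightning with WandbLogger. The module, project name, and toy data are made up for illustration and are not part of this commit; only the single `log_dict` call mirrors the change above.

```python
# Minimal sketch: one log_dict call per training step keeps the W&B step
# aligned with trainer.global_step. ToyModule and "step-alignment-check"
# are hypothetical names for illustration only.
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import lightning.pytorch as pl
from lightning.pytorch.loggers import WandbLogger


class ToyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = nn.functional.mse_loss(self.layer(x), y)
        # One consolidated call: both metrics land on the same logger flush,
        # so W&B's step counter advances once per training step.
        self.log_dict(
            {"global_step": float(self.trainer.global_step),
             "step": float(self.trainer.global_step)},
            prog_bar=True,
            rank_zero_only=True,
            batch_size=1,
        )
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


if __name__ == "__main__":
    data = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
    trainer = pl.Trainer(
        max_steps=10,
        log_every_n_steps=1,  # flush metrics every step for this demo
        logger=WandbLogger(project="step-alignment-check", offline=True),
        enable_checkpointing=False,
    )
    trainer.fit(ToyModule(), DataLoader(data, batch_size=8))
```

With the metrics bundled this way, each logger flush carries both keys at once, which is the behavior the three strategy files above now share.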
