
Commit 9b7aa6f

Scale the loss by the number of valid tokens in the global batch when gradient accumulation is used with a cross-entropy loss and the user provides properly formatted batches
1 parent 385fd56 commit 9b7aa6f
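
For context (this illustration is not part of the commit): with gradient accumulation, dividing each micro-batch's mean loss by `accumulate_grad_batches` weights every micro-batch equally, which over-weights tokens from micro-batches that contain fewer valid (non-padding) tokens. Dividing a summed cross-entropy loss by the total number of valid tokens in the global batch instead recovers the true per-token average. A minimal standalone sketch with made-up tensors:

# Standalone illustration, not from the commit: two micro-batches with unequal
# numbers of valid (non-ignored) tokens. All tensors here are made up.
import torch
import torch.nn.functional as F

IGNORE = -100  # conventional ignore_index marking padded label positions

logits = [torch.randn(4, 10), torch.randn(4, 10)]
labels = [
    torch.tensor([1, 2, 3, IGNORE]),            # 3 valid tokens
    torch.tensor([IGNORE, IGNORE, 5, IGNORE]),  # 1 valid token
]

# Per-micro-batch losses: mean-reduced (the usual default) and sum-reduced
# (what the commit message assumes the user's training_step provides).
means = [F.cross_entropy(lg, lb, ignore_index=IGNORE, reduction="mean") for lg, lb in zip(logits, labels)]
sums = [F.cross_entropy(lg, lb, ignore_index=IGNORE, reduction="sum") for lg, lb in zip(logits, labels)]
num_global_valid_tokens = sum(int((lb != IGNORE).sum()) for lb in labels)  # 3 + 1 = 4

loss_old = sum(m / len(means) for m in means)              # `closure_loss / normalize`: average of per-batch means
loss_new = sum(s / num_global_valid_tokens for s in sums)  # `closure_loss / num_global_valid_tokens`: per-token mean
print(loss_old.item(), loss_new.item())  # differ whenever valid-token counts are unbalanced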

1 file changed: 11 additions, 3 deletions


src/lightning/pytorch/loops/optimization/automatic.py

@@ -59,7 +59,9 @@ def _clone_loss(self) -> None:
         self.loss = self.closure_loss.detach().clone()
 
     @classmethod
-    def from_training_step_output(cls, training_step_output: STEP_OUTPUT, normalize: int = 1) -> "ClosureResult":
+    def from_training_step_output(
+        cls, training_step_output: STEP_OUTPUT, normalize: int = 1, num_global_valid_tokens: Optional[int] = None
+    ) -> "ClosureResult":
         closure_loss, extra = None, {}
 
         if isinstance(training_step_output, Mapping):
@@ -80,7 +82,10 @@ def from_training_step_output(cls, training_step_output: STEP_OUTPUT, normalize:
         if closure_loss is not None:
             # accumulate the loss. If ``accumulate_grad_batches == 1``, no effect
             # note: avoid in-place operation `x /= y` here on purpose
-            closure_loss = closure_loss / normalize
+            if num_global_valid_tokens is not None:
+                closure_loss = closure_loss / num_global_valid_tokens
+            elif normalize > 1:
+                closure_loss = closure_loss / normalize
 
         return cls(closure_loss, extra=extra)
 
@@ -315,6 +320,7 @@ def _training_step(self, kwargs: OrderedDict) -> ClosureResult:
 
         """
         trainer = self.trainer
+        num_global_valid_tokens = kwargs.pop("num_global_valid_tokens", None)
 
         training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
         self.trainer.strategy.post_training_step()  # unused hook - call anyway for backward compatibility
@@ -326,4 +332,6 @@ def _training_step(self, kwargs: OrderedDict) -> ClosureResult:
                " place."
            )
 
-        return self.output_result_cls.from_training_step_output(training_step_output, trainer.accumulate_grad_batches)
+        return self.output_result_cls.from_training_step_output(
+            training_step_output, trainer.accumulate_grad_batches, num_global_valid_tokens=num_global_valid_tokens
+        )
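
On the user side, a hedged sketch of how this patch might be exercised (the module and batch field names below are assumptions, and this diff does not show how `num_global_valid_tokens` reaches the training-step kwargs; the patched loop is assumed to forward it): `training_step` returns a summed rather than averaged cross-entropy loss, so the loop-level division by the global valid-token count produces the final per-token mean.

# Hypothetical sketch, not from the commit: a LightningModule whose
# training_step returns a *summed* cross-entropy loss. Under this patch the
# loop divides that sum by num_global_valid_tokens instead of
# accumulate_grad_batches, so the loss must not be pre-averaged here.
import torch.nn.functional as F
from lightning.pytorch import LightningModule


class CausalLMModule(LightningModule):  # hypothetical name
    def __init__(self, model):
        super().__init__()
        self.model = model  # any token-level classifier or language model

    def training_step(self, batch, batch_idx):
        logits = self.model(batch["input_ids"])  # assumed batch layout
        # reduction="sum": the loop-level division yields the final mean.
        return F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            batch["labels"].view(-1),
            ignore_index=-100,
            reduction="sum",
        )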
