@@ -59,7 +59,9 @@ def _clone_loss(self) -> None:
         self.loss = self.closure_loss.detach().clone()
 
     @classmethod
-    def from_training_step_output(cls, training_step_output: STEP_OUTPUT, normalize: int = 1) -> "ClosureResult":
+    def from_training_step_output(
+        cls, training_step_output: STEP_OUTPUT, normalize: int = 1, num_global_valid_tokens: Optional[int] = None
+    ) -> "ClosureResult":
         closure_loss, extra = None, {}
 
         if isinstance(training_step_output, Mapping):
@@ -80,7 +82,10 @@ def from_training_step_output(cls, training_step_output: STEP_OUTPUT, normalize:
         if closure_loss is not None:
             # accumulate the loss. If ``accumulate_grad_batches == 1``, no effect
             # note: avoid in-place operation `x /= y` here on purpose
-            closure_loss = closure_loss / normalize
+            if num_global_valid_tokens is not None:
+                closure_loss = closure_loss / num_global_valid_tokens
+            elif normalize > 1:
+                closure_loss = closure_loss / normalize
 
         return cls(closure_loss, extra=extra)
 
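With this change, a supplied `num_global_valid_tokens` takes precedence over the gradient-accumulation divisor: the summed loss is divided once by the global count of valid tokens instead of by `accumulate_grad_batches`. Below is a minimal sketch of that branch in isolation; `normalize_loss` and the example values are illustrative, not part of the Lightning API.

```python
import torch

# Standalone sketch of the normalization branch above; `normalize_loss` is a
# hypothetical helper, not a Lightning function.
def normalize_loss(summed_loss, normalize=1, num_global_valid_tokens=None):
    if num_global_valid_tokens is not None:
        # per-token scaling: one division by the global count of valid tokens
        return summed_loss / num_global_valid_tokens
    elif normalize > 1:
        # default scaling over the gradient-accumulation window
        return summed_loss / normalize
    return summed_loss

loss = torch.tensor(48.0)  # e.g. a loss summed over the tokens of a micro-batch
print(normalize_loss(loss, normalize=4))                 # tensor(12.)
print(normalize_loss(loss, num_global_valid_tokens=16))  # tensor(3.)
```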
@@ -315,6 +320,7 @@ def _training_step(self, kwargs: OrderedDict) -> ClosureResult:
 
         """
         trainer = self.trainer
+        num_global_valid_tokens = kwargs.pop("num_global_valid_tokens", None)
 
         training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
         self.trainer.strategy.post_training_step()  # unused hook - call anyway for backward compatibility
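Because the key is popped from the step kwargs before the `training_step` hook is dispatched, the user's `training_step` never receives it, while the value stays available to the closure. A tiny illustration of that pop-before-dispatch pattern; the kwargs contents here are made up, not the dict Lightning actually builds.

```python
from collections import OrderedDict

# Illustrative only: the real kwargs are assembled by the loop, not by hand.
kwargs = OrderedDict(batch="<batch>", batch_idx=0, num_global_valid_tokens=128)

num_global_valid_tokens = kwargs.pop("num_global_valid_tokens", None)

assert "num_global_valid_tokens" not in kwargs  # training_step never sees the key
assert num_global_valid_tokens == 128           # but the closure can still use it
```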
@@ -326,4 +332,6 @@ def _training_step(self, kwargs: OrderedDict) -> ClosureResult:
326332 " place."
327333 )
328334
329- return self .output_result_cls .from_training_step_output (training_step_output , trainer .accumulate_grad_batches )
335+ return self .output_result_cls .from_training_step_output (
336+ training_step_output , trainer .accumulate_grad_batches , num_global_valid_tokens = num_global_valid_tokens
337+ )
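The motivation for dividing by a global token count: if each micro-batch returns a loss summed over its valid tokens, dividing every micro-batch by the same global count makes the accumulated gradient equal the per-token mean over the whole effective batch, even when micro-batches hold different numbers of valid tokens. A self-contained plain-PyTorch sketch of that equivalence follows; nothing in it is Lightning code.

```python
import torch

# Plain-PyTorch sketch of the motivation; the variable names are illustrative.
w = torch.zeros(1, requires_grad=True)

# per-token values for two micro-batches with 3 and 1 valid tokens
micro_batches = [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0])]
num_global_valid_tokens = sum(len(b) for b in micro_batches)  # 4

# gradient accumulation: each summed loss is divided by the *global* count
for tokens in micro_batches:
    loss = (w * tokens).sum() / num_global_valid_tokens
    loss.backward()

# reference: one large batch averaged over all valid tokens at once
w_ref = torch.zeros(1, requires_grad=True)
(w_ref * torch.cat(micro_batches)).mean().backward()

assert torch.allclose(w.grad, w_ref.grad)  # both gradients equal the true token mean
```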