@@ -117,7 +117,12 @@ def collate(
117117 return inputs , targets
118118
119119
120- # Note: This is also available in losses.grpo_loss via `SimpleGRPOLoss`
120+ # TODO (T245547773): Consolidate with SimpleGRPOLoss in losses/grpo_loss.py
121+ # Currently duplicated because of function signature differences:
122+ # - This function takes logits + response, computes logprobs internally
123+ # - SimpleGRPOLoss takes pre-computed logprobs
124+ # - TitanTrainer passes logits, so it would need a wrapper or a signature change
125+ # Consider refactoring TitanTrainer's loss interface to standardize this.
121126def simple_grpo_loss (
122127 logits : torch .Tensor ,
123128 response : torch .Tensor ,
@@ -129,11 +134,30 @@ def simple_grpo_loss(
129134 logprobs : torch .Tensor = compute_logprobs (logits , response )
130135 kl = torch .exp (ref_logprobs - logprobs ) - (ref_logprobs - logprobs ) - 1
131136 per_token_policy_loss = torch .exp (logprobs - logprobs .detach ()) * advantages
132- per_token_loss = - (per_token_policy_loss - beta * kl )
133- loss = (
134- ((per_token_loss * padding_mask ).sum (dim = 1 ))
137+
138+ # Per-sequence mean KL over unmasked tokens, then averaged across the batch
139+ mean_kl = (
140+ ((kl * padding_mask ).sum (dim = 1 )) / (padding_mask .sum (dim = 1 ).clamp (min = 1.0 ))
141+ ).mean ()
142+
143+ # Per-sequence mean policy loss over unmasked tokens, then averaged across the batch
144+ mean_policy_loss = (
145+ ((per_token_policy_loss * padding_mask ).sum (dim = 1 ))
135146 / (padding_mask .sum (dim = 1 ).clamp (min = 1.0 ))
136147 ).mean ()
148+
149+ # Combine the two means; by linearity this equals the masked mean of the per-token loss -(policy_loss - beta * kl)
150+ loss = - (mean_policy_loss - beta * mean_kl )
151+
152+ # Log metrics
153+ record_metric ("grpo_loss/kl_divergence_mean" , mean_kl .item (), Reduce .MEAN )
154+ record_metric (
155+ "grpo_loss/kl_divergence_max" , (kl * padding_mask ).max ().item (), Reduce .MAX
156+ )
157+ record_metric ("grpo_loss/policy_loss" , mean_policy_loss .item (), Reduce .MEAN )
158+ record_metric ("grpo_loss/advantage_mean" , advantages .mean ().item (), Reduce .MEAN )
159+ record_metric ("grpo_loss/advantage_std" , advantages .std ().item (), Reduce .MEAN )
160+
137161 return loss
138162
139163
0 commit comments