@@ -215,6 +215,75 @@ def simple_grpo_loss(
     return loss


+def dapo_loss(
+    logits: torch.Tensor,
+    response: torch.Tensor,
+    ref_logprobs: torch.Tensor,
+    advantages: torch.Tensor,
+    padding_mask: torch.Tensor,
+    beta: float = 0.005,
+    clip_eps_low: float = 0.2,
+    clip_eps_high: float = 0.28,
+) -> torch.Tensor:
228+ """
229+ DAPO (Direct Alignment Policy Optimization) loss function.
230+
231+ Implements PPO-style clipped objective with KL divergence penalty.
232+ Based on the compute_loss function from old_dapo.py.
233+
234+ Args:
235+ logits: Model output logits [batch_size, seq_len, vocab_size]
236+ response: Response token ids [batch_size, seq_len]
237+ ref_logprobs: Reference model log probabilities [batch_size, seq_len]
238+ advantages: Advantage values [batch_size, 1]
239+ padding_mask: Mask for valid tokens [batch_size, seq_len]
240+ beta: KL divergence coefficient
241+ clip_eps_low: Lower clipping bound for importance sampling ratio
242+ clip_eps_high: Upper clipping bound for importance sampling ratio
243+
244+ Returns:
245+ Scalar loss value
246+ """
+    # Compute current action log probabilities
+    action_log_probs = compute_logprobs(logits, response)
+
+    # Compute KL divergence term (k3 in DAPO)
+    if beta != 0.0:
+        log_ratio = ref_logprobs - action_log_probs
+        log_ratio = log_ratio * padding_mask
+        k3 = log_ratio.exp() - 1 - log_ratio
+
+    # Use detached log probs as "old" log probs (for single iteration)
+    # In a multi-iteration setting, these would be passed as input
+    old_action_log_probs = action_log_probs.detach()
+
+    # Compute importance sampling ratio
+    coef_1 = torch.exp(action_log_probs - old_action_log_probs)
+
+    # Clipped importance sampling ratio
+    coef_2 = torch.clamp(coef_1, 1 - clip_eps_low, 1 + clip_eps_high)
+
+    # Compute per-token losses with advantages
+    # advantages has shape [batch_size, 1] and broadcasts across the seq_len dimension
+    per_token_loss1 = coef_1 * advantages
+    per_token_loss2 = coef_2 * advantages
+
+    # Take minimum for clipped objective (negative because we minimize)
+    per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
+
+    # Apply action mask
+    per_token_loss = per_token_loss * padding_mask
+
+    # Add KL penalty (k3 is already zero at padded positions since log_ratio was masked)
+    if beta != 0.0:
+        per_token_loss = per_token_loss + beta * k3
+
+    # Average over valid tokens per sequence, then over the batch
+    loss = (per_token_loss.sum(dim=1) / padding_mask.sum(dim=1).clamp(min=1.0)).mean()
+
+    return loss
+
+
 @dataclass
 class JuliaRewardActor(ForgeActor):
     """Reward actor for Julia code execution using GenericOpenEnvActor.
@@ -550,9 +619,7 @@ async def main(cfg: DictConfig):
     ) = await asyncio.gather(
         JuliaDatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset),
         Policy.options(**cfg.services.policy).as_service(**cfg.policy),
-        RLTrainer.options(**cfg.actors.trainer).as_actor(
-            **cfg.trainer, loss=simple_grpo_loss
-        ),
+        RLTrainer.options(**cfg.actors.trainer).as_actor(**cfg.trainer, loss=dapo_loss),
         ReplayBuffer.options(**cfg.actors.replay_buffer).as_actor(
             **cfg.replay_buffer, collate=collate
         ),
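
A minimal way to sanity-check the new loss, sketched below under the assumption that this file's compute_logprobs helper accepts logits and response tensors with the shapes listed in the docstring (toy tensors, hypothetical sizes; not part of the PR):

# Hypothetical smoke test: random tensors only, to check shapes and gradient flow.
import torch

batch_size, seq_len, vocab_size = 2, 8, 32  # toy sizes

logits = torch.randn(batch_size, seq_len, vocab_size, requires_grad=True)
response = torch.randint(0, vocab_size, (batch_size, seq_len))

# Per-token log probs of `response` under a frozen reference model
ref_logits = torch.randn(batch_size, seq_len, vocab_size)
ref_logprobs = (
    torch.log_softmax(ref_logits, dim=-1)
    .gather(-1, response.unsqueeze(-1))
    .squeeze(-1)
)

advantages = torch.randn(batch_size, 1)          # one advantage per sequence
padding_mask = torch.ones(batch_size, seq_len)   # all tokens valid in this toy case

loss = dapo_loss(logits, response, ref_logprobs, advantages, padding_mask)
loss.backward()  # gradients flow only through `logits`; ref_logprobs stays constant
print(loss.item())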