@@ -100,7 +100,9 @@ class GRPOConfig(agentic_rl_learner.AgenticRLConfig):
100100 max_concurrency : int = 16
101101 epsilon_high : float | None = None # 0.28 from DAPO.
102102 off_policy_steps : int = 0
103- degenerate_group_masking : bool = True # Whether to mask out degenerate groups with all-0 advantages.
103+ degenerate_group_masking : bool = (
104+ True # Whether to mask out degenerate groups with all-0 advantages.
105+ )
104106
105107 def __post_init__ (self ):
106108 if self .num_generations <= 1 :
@@ -568,11 +570,11 @@ def grpo_loss_fn(
568570 advantages = jnp .astype (train_example .advantages , jnp .float32 )
569571 if advantages .ndim != 1 :
570572 raise ValueError (
571- f"Expected advantages to be a 1D array, but got shape { advantages .shape } "
573+ "Expected advantages to be a 1D array, but got shape"
574+ f" { advantages .shape } "
572575 )
573576
574577 # Mask out degenerate groups (all-0 sequence advantages).
575- #
576578 # For group-relative advantages, all-0 indicates a no-signal group (e.g. all
577579 # rewards are equal). Such groups should not contribute to policy, KL, or
578580 # entropy terms in this loss.
@@ -582,11 +584,9 @@ def grpo_loss_fn(
582584 invalid_group = jnp .all (jnp .isclose (grouped_advantages , 0.0 ), axis = - 1 )
583585 valid_group_mask = jnp .logical_not (invalid_group )
584586 valid_sequence_mask = jnp .repeat (valid_group_mask , num_generations )[:, None ]
585- effective_completion_mask = completion_mask * valid_sequence_mask .astype (
586- completion_mask .dtype
587+ completion_mask = completion_mask * valid_sequence_mask .astype (
588+ completion_mask .dtype
587589 )
588- else :
589- effective_completion_mask = completion_mask
590590
591591 if train_example .old_per_token_logps is None :
592592 old_per_token_logps = jax .lax .stop_gradient (per_token_logps )
@@ -597,15 +597,15 @@ def grpo_loss_fn(
597597
598598 seq_importance_ratio = per_token_logps - old_per_token_logps
599599 # Record KL divergence before clipping.
600- ppo_kl = ppo_helpers .masked_mean (- seq_importance_ratio , effective_completion_mask )
600+ ppo_kl = ppo_helpers .masked_mean (- seq_importance_ratio , completion_mask )
601601
602602 seq_importance_ratio = jnp .clip (seq_importance_ratio , max = 20.0 , min = - 20.0 )
603603
604604 # TODO(sizhi): Refactor this to a separate function.
605605 if loss_algo == "gspo-token" :
606- seq_importance_ratio = (seq_importance_ratio * effective_completion_mask ).sum (
606+ seq_importance_ratio = (seq_importance_ratio * completion_mask ).sum (
607607 axis = - 1
608- ) / jnp .clip (effective_completion_mask .sum (- 1 ), min = 1 )
608+ ) / jnp .clip (completion_mask .sum (- 1 ), min = 1 )
609609 seq_importance_ratio = (
610610 per_token_logps
611611 - jax .lax .stop_gradient (per_token_logps )
@@ -616,17 +616,15 @@ def grpo_loss_fn(
616616 is_ratio = jnp .exp (seq_importance_ratio )
617617 advantages = advantages [:, None ]
618618 pg_loss_1 = - advantages * is_ratio
619- pg_loss_2 = - advantages * jnp .clip (
620- is_ratio , 1 - epsilon , 1 + epsilon_high
621- )
619+ pg_loss_2 = - advantages * jnp .clip (is_ratio , 1 - epsilon , 1 + epsilon_high )
622620
623621 per_token_loss = jnp .maximum (pg_loss_1 , pg_loss_2 ).astype (jnp .float32 )
624622
625623 clipped_fraction = ppo_helpers .masked_mean (
626- jnp .greater (pg_loss_2 , pg_loss_1 ), effective_completion_mask
624+ jnp .greater (pg_loss_2 , pg_loss_1 ), completion_mask
627625 )
628626
629- pg_loss = ppo_helpers .masked_mean (per_token_loss , effective_completion_mask )
627+ pg_loss = ppo_helpers .masked_mean (per_token_loss , completion_mask )
630628 aux = {
631629 "kl" : 0.0 ,
632630 "pg_loss" : pg_loss ,
@@ -641,15 +639,15 @@ def grpo_loss_fn(
641639
642640 # Log mean KL.
643641 aux ["kl" ] = jnp .astype (
644- (kl * effective_completion_mask ).sum () / jnp .clip (effective_completion_mask .sum (), min = 1 ),
642+ (kl * completion_mask ).sum () / jnp .clip (completion_mask .sum (), min = 1 ),
645643 jnp .float32 ,
646644 )
647645
648646 loss = common .aggregate_loss (
649- per_token_loss , effective_completion_mask , loss_aggregation_mode
647+ per_token_loss , completion_mask , loss_aggregation_mode
650648 )
651649 token_entropy = ppo_helpers .compute_entropy_from_logits (logits )
652- entropy_loss = ppo_helpers .masked_mean (token_entropy , effective_completion_mask )
650+ entropy_loss = ppo_helpers .masked_mean (token_entropy , completion_mask )
653651 aux ["entropy" ] = entropy_loss
654652
655653 return loss , aux
0 commit comments