Skip to content

Commit 36c0474

Browse files
tianshub and The tunix Authors
authored and committed
nit
PiperOrigin-RevId: 888800288
1 parent 8ea9ef7 commit 36c0474

File tree

1 file changed

+16
-18
lines changed

1 file changed

+16
-18
lines changed

tunix/rl/agentic/agentic_grpo_learner.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,9 @@ class GRPOConfig(agentic_rl_learner.AgenticRLConfig):
100100
max_concurrency: int = 16
101101
epsilon_high: float | None = None # 0.28 from DAPO.
102102
off_policy_steps: int = 0
103-
degenerate_group_masking: bool = True # Whether to mask out degenerate groups with all-0 advantages.
103+
degenerate_group_masking: bool = (
104+
True # Whether to mask out degenerate groups with all-0 advantages.
105+
)
104106

105107
def __post_init__(self):
106108
if self.num_generations <= 1:
@@ -568,11 +570,11 @@ def grpo_loss_fn(
568570
advantages = jnp.astype(train_example.advantages, jnp.float32)
569571
if advantages.ndim != 1:
570572
raise ValueError(
571-
f"Expected advantages to be a 1D array, but got shape {advantages.shape}"
573+
"Expected advantages to be a 1D array, but got shape"
574+
f" {advantages.shape}"
572575
)
573576

574577
# Mask out degenerate groups (all-0 sequence advantages).
575-
#
576578
# For group-relative advantages, all-0 indicates a no-signal group (e.g. all
577579
# rewards are equal). Such groups should not contribute to policy, KL, or
578580
# entropy terms in this loss.
@@ -582,11 +584,9 @@ def grpo_loss_fn(
582584
invalid_group = jnp.all(jnp.isclose(grouped_advantages, 0.0), axis=-1)
583585
valid_group_mask = jnp.logical_not(invalid_group)
584586
valid_sequence_mask = jnp.repeat(valid_group_mask, num_generations)[:, None]
585-
effective_completion_mask = completion_mask * valid_sequence_mask.astype(
586-
completion_mask.dtype
587+
completion_mask = completion_mask * valid_sequence_mask.astype(
588+
completion_mask.dtype
587589
)
588-
else:
589-
effective_completion_mask = completion_mask
590590

591591
if train_example.old_per_token_logps is None:
592592
old_per_token_logps = jax.lax.stop_gradient(per_token_logps)
@@ -597,15 +597,15 @@ def grpo_loss_fn(
597597

598598
seq_importance_ratio = per_token_logps - old_per_token_logps
599599
# Record KL divergence before clipping.
600-
ppo_kl = ppo_helpers.masked_mean(-seq_importance_ratio, effective_completion_mask)
600+
ppo_kl = ppo_helpers.masked_mean(-seq_importance_ratio, completion_mask)
601601

602602
seq_importance_ratio = jnp.clip(seq_importance_ratio, max=20.0, min=-20.0)
603603

604604
# TODO(sizhi): Refactor this to a separate function.
605605
if loss_algo == "gspo-token":
606-
seq_importance_ratio = (seq_importance_ratio * effective_completion_mask).sum(
606+
seq_importance_ratio = (seq_importance_ratio * completion_mask).sum(
607607
axis=-1
608-
) / jnp.clip(effective_completion_mask.sum(-1), min=1)
608+
) / jnp.clip(completion_mask.sum(-1), min=1)
609609
seq_importance_ratio = (
610610
per_token_logps
611611
- jax.lax.stop_gradient(per_token_logps)
@@ -616,17 +616,15 @@ def grpo_loss_fn(
616616
is_ratio = jnp.exp(seq_importance_ratio)
617617
advantages = advantages[:, None]
618618
pg_loss_1 = -advantages * is_ratio
619-
pg_loss_2 = -advantages * jnp.clip(
620-
is_ratio, 1 - epsilon, 1 + epsilon_high
621-
)
619+
pg_loss_2 = -advantages * jnp.clip(is_ratio, 1 - epsilon, 1 + epsilon_high)
622620

623621
per_token_loss = jnp.maximum(pg_loss_1, pg_loss_2).astype(jnp.float32)
624622

625623
clipped_fraction = ppo_helpers.masked_mean(
626-
jnp.greater(pg_loss_2, pg_loss_1), effective_completion_mask
624+
jnp.greater(pg_loss_2, pg_loss_1), completion_mask
627625
)
628626

629-
pg_loss = ppo_helpers.masked_mean(per_token_loss, effective_completion_mask)
627+
pg_loss = ppo_helpers.masked_mean(per_token_loss, completion_mask)
630628
aux = {
631629
"kl": 0.0,
632630
"pg_loss": pg_loss,
@@ -641,15 +639,15 @@ def grpo_loss_fn(
641639

642640
# Log mean KL.
643641
aux["kl"] = jnp.astype(
644-
(kl * effective_completion_mask).sum() / jnp.clip(effective_completion_mask.sum(), min=1),
642+
(kl * completion_mask).sum() / jnp.clip(completion_mask.sum(), min=1),
645643
jnp.float32,
646644
)
647645

648646
loss = common.aggregate_loss(
649-
per_token_loss, effective_completion_mask, loss_aggregation_mode
647+
per_token_loss, completion_mask, loss_aggregation_mode
650648
)
651649
token_entropy = ppo_helpers.compute_entropy_from_logits(logits)
652-
entropy_loss = ppo_helpers.masked_mean(token_entropy, effective_completion_mask)
650+
entropy_loss = ppo_helpers.masked_mean(token_entropy, completion_mask)
653651
aux["entropy"] = entropy_loss
654652

655653
return loss, aux

0 commit comments

Comments (0)