Skip to content
Merged

nit #1293

Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 16 additions & 18 deletions tunix/rl/agentic/agentic_grpo_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ class GRPOConfig(agentic_rl_learner.AgenticRLConfig):
max_concurrency: int = 16
epsilon_high: float | None = None # 0.28 from DAPO.
off_policy_steps: int = 0
degenerate_group_masking: bool = True # Whether to mask out degenerate groups with all-0 advantages.
degenerate_group_masking: bool = (
True # Whether to mask out degenerate groups with all-0 advantages.
)

def __post_init__(self):
if self.num_generations <= 1:
Expand Down Expand Up @@ -568,11 +570,11 @@ def grpo_loss_fn(
advantages = jnp.astype(train_example.advantages, jnp.float32)
if advantages.ndim != 1:
raise ValueError(
f"Expected advantages to be a 1D array, but got shape {advantages.shape}"
"Expected advantages to be a 1D array, but got shape"
f" {advantages.shape}"
)

# Mask out degenerate groups (all-0 sequence advantages).
#
# For group-relative advantages, all-0 indicates a no-signal group (e.g. all
# rewards are equal). Such groups should not contribute to policy, KL, or
# entropy terms in this loss.
Expand All @@ -582,11 +584,9 @@ def grpo_loss_fn(
invalid_group = jnp.all(jnp.isclose(grouped_advantages, 0.0), axis=-1)
valid_group_mask = jnp.logical_not(invalid_group)
valid_sequence_mask = jnp.repeat(valid_group_mask, num_generations)[:, None]
effective_completion_mask = completion_mask * valid_sequence_mask.astype(
completion_mask.dtype
completion_mask = completion_mask * valid_sequence_mask.astype(
completion_mask.dtype
)
else:
effective_completion_mask = completion_mask

if train_example.old_per_token_logps is None:
old_per_token_logps = jax.lax.stop_gradient(per_token_logps)
Expand All @@ -597,15 +597,15 @@ def grpo_loss_fn(

seq_importance_ratio = per_token_logps - old_per_token_logps
# Record KL divergence before clipping.
ppo_kl = ppo_helpers.masked_mean(-seq_importance_ratio, effective_completion_mask)
ppo_kl = ppo_helpers.masked_mean(-seq_importance_ratio, completion_mask)

seq_importance_ratio = jnp.clip(seq_importance_ratio, max=20.0, min=-20.0)

# TODO(sizhi): Refactor this to a separate function.
if loss_algo == "gspo-token":
seq_importance_ratio = (seq_importance_ratio * effective_completion_mask).sum(
seq_importance_ratio = (seq_importance_ratio * completion_mask).sum(
axis=-1
) / jnp.clip(effective_completion_mask.sum(-1), min=1)
) / jnp.clip(completion_mask.sum(-1), min=1)
seq_importance_ratio = (
per_token_logps
- jax.lax.stop_gradient(per_token_logps)
Expand All @@ -616,17 +616,15 @@ def grpo_loss_fn(
is_ratio = jnp.exp(seq_importance_ratio)
advantages = advantages[:, None]
pg_loss_1 = -advantages * is_ratio
pg_loss_2 = -advantages * jnp.clip(
is_ratio, 1 - epsilon, 1 + epsilon_high
)
pg_loss_2 = -advantages * jnp.clip(is_ratio, 1 - epsilon, 1 + epsilon_high)

per_token_loss = jnp.maximum(pg_loss_1, pg_loss_2).astype(jnp.float32)

clipped_fraction = ppo_helpers.masked_mean(
jnp.greater(pg_loss_2, pg_loss_1), effective_completion_mask
jnp.greater(pg_loss_2, pg_loss_1), completion_mask
)

pg_loss = ppo_helpers.masked_mean(per_token_loss, effective_completion_mask)
pg_loss = ppo_helpers.masked_mean(per_token_loss, completion_mask)
aux = {
"kl": 0.0,
"pg_loss": pg_loss,
Expand All @@ -641,15 +639,15 @@ def grpo_loss_fn(

# Log mean KL.
aux["kl"] = jnp.astype(
(kl * effective_completion_mask).sum() / jnp.clip(effective_completion_mask.sum(), min=1),
(kl * completion_mask).sum() / jnp.clip(completion_mask.sum(), min=1),
jnp.float32,
)

loss = common.aggregate_loss(
per_token_loss, effective_completion_mask, loss_aggregation_mode
per_token_loss, completion_mask, loss_aggregation_mode
)
token_entropy = ppo_helpers.compute_entropy_from_logits(logits)
entropy_loss = ppo_helpers.masked_mean(token_entropy, effective_completion_mask)
entropy_loss = ppo_helpers.masked_mean(token_entropy, completion_mask)
aux["entropy"] = entropy_loss

return loss, aux
Expand Down
Loading