
Commit b2ffe7d

add GRPO with trajectory-level deduplication support
1 parent 5f3093d commit b2ffe7d

File tree

1 file changed: +103, -8 lines

agentlightning/verl/trainer.py

Lines changed: 103 additions & 8 deletions
@@ -24,6 +24,8 @@
     compute_throughout_metrics,
     compute_timing_metrics,
 )
+from collections import defaultdict
+
 from verl.trainer.ppo.ray_trainer import (
     AdvantageEstimator,
     RayPPOTrainer,
@@ -45,6 +47,72 @@
 ]


+def compute_grpo_outcome_advantage(
+    token_level_rewards: torch.Tensor,
+    response_mask: torch.Tensor,
+    index: np.ndarray,
+    traj_index: np.ndarray | None = None,
+    epsilon: float = 1e-6,
+    norm_adv_by_std_in_grpo: bool = True,
+    compute_mean_std_cross_all_data: bool = True,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Compute advantage for GRPO with trajectory-level deduplication support.
+
+    This is a minimal extension of VeRL's GRPO implementation, adding support for
+    trajectory-level deduplication via `traj_index` and `compute_mean_std_cross_all_data`.
+
+    Args:
+        token_level_rewards: Shape (bs, response_length).
+        response_mask: Shape (bs, response_length).
+        index: Group index array (e.g., data_id).
+        traj_index: Trajectory index array (e.g., rollout_id). If None, no deduplication.
+        epsilon: Small value for numerical stability.
+        norm_adv_by_std_in_grpo: If True, normalize by std (original GRPO). If False, Dr.GRPO style.
+        compute_mean_std_cross_all_data: If True (default), compute mean/std across all data.
+            If False, compute mean/std per unique (index, traj_index) trajectory.
+
+    Returns:
+        Tuple of (advantages, returns), both shape (bs, response_length).
+    """
+    scores = token_level_rewards.sum(dim=-1)
+
+    id2score: dict = defaultdict(list)
+    id2mean: dict = {}
+    id2std: dict = {}
+    seen_pairs: set = set()
+
+    with torch.no_grad():
+        bsz = scores.shape[0]
+        for i in range(bsz):
+            # Trajectory deduplication: skip if (index, traj_index) already seen
+            if traj_index is not None and (index[i], traj_index[i]) in seen_pairs:
+                continue
+            id2score[index[i]].append(scores[i])
+            # Mark as seen only when compute_mean_std_cross_all_data is False
+            if traj_index is not None and not compute_mean_std_cross_all_data:
+                seen_pairs.add((index[i], traj_index[i]))
+
+        for idx in id2score:
+            if len(id2score[idx]) == 1:
+                id2mean[idx] = torch.tensor(0.0)
+                id2std[idx] = torch.tensor(1.0)
+            elif len(id2score[idx]) > 1:
+                scores_tensor = torch.stack(id2score[idx])
+                id2mean[idx] = torch.mean(scores_tensor)
+                id2std[idx] = torch.std(scores_tensor)
+            else:
+                raise ValueError(f"no score in prompt index: {idx}")
+
+        for i in range(bsz):
+            if norm_adv_by_std_in_grpo:
+                scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
+            else:
+                scores[i] = scores[i] - id2mean[index[i]]
+        scores = scores.unsqueeze(-1) * response_mask
+
+    return scores, scores
+
+
 @contextmanager
 def _timer(name: str, timing_raw: Dict[str, float]):
     with Timer(name=name, logger=None) as timer:
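
For intuition, here is a minimal, self-contained sketch of the new function (toy values; the import path is an assumption based on the file location agentlightning/verl/trainer.py). With compute_mean_std_cross_all_data=False, rows that share the same (index, traj_index) pair contribute a single score to the per-group mean/std, while every row still receives a normalized advantage.

import numpy as np
import torch

from agentlightning.verl.trainer import compute_grpo_outcome_advantage  # assumed import path

# Toy batch: 4 responses of 3 tokens each. Rows 0 and 1 come from the same
# prompt ("p0") and the same trajectory ("t0"), so their shared score is
# counted only once when estimating the group mean/std.
token_level_rewards = torch.tensor(
    [
        [0.0, 0.0, 1.0],
        [0.0, 0.0, 1.0],
        [0.0, 0.5, 0.0],
        [1.0, 0.0, 0.0],
    ]
)
response_mask = torch.ones_like(token_level_rewards)
index = np.array(["p0", "p0", "p0", "p1"], dtype=object)       # group id per row
traj_index = np.array(["t0", "t0", "t1", "t2"], dtype=object)  # trajectory id per row

advantages, returns = compute_grpo_outcome_advantage(
    token_level_rewards=token_level_rewards,
    response_mask=response_mask,
    index=index,
    traj_index=traj_index,
    compute_mean_std_cross_all_data=False,  # per-(index, traj_index) statistics
)
print(advantages.shape)  # torch.Size([4, 3])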
@@ -355,15 +423,42 @@ def _train_step(self, batch_dict: dict) -> dict:
             "norm_adv_by_std_in_grpo", True
         )  # GRPO adv normalization factor

-        batch = compute_advantage(
-            batch,
-            adv_estimator=self.config.algorithm.adv_estimator,
-            gamma=self.config.algorithm.gamma,
-            lam=self.config.algorithm.lam,
-            num_repeat=self.config.actor_rollout_ref.rollout.n,
-            norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
-            config=self.config.algorithm,
+        # compute_mean_std_cross_all_data: trajectory-level advantage computation
+        # Currently only supported for the GRPO algorithm
+        compute_mean_std_cross_all_data = self.config.algorithm.get(
+            "compute_mean_std_cross_all_data", True
         )
+        if not compute_mean_std_cross_all_data:
+            assert self.config.algorithm.adv_estimator == AdvantageEstimator.GRPO, (
+                f"compute_mean_std_cross_all_data=False is only supported for GRPO, "
+                f"got {self.config.algorithm.adv_estimator}"
+            )
+
+        # Route GRPO through the local implementation (required when
+        # compute_mean_std_cross_all_data is disabled); other estimators keep VeRL's path
+        if self.config.algorithm.adv_estimator == AdvantageEstimator.GRPO:
+            if "response_mask" not in batch.batch:
+                batch.batch["response_mask"] = compute_response_mask(batch)
+            traj_index = batch.non_tensor_batch["rollout_id_list"]
+            advantages, returns = compute_grpo_outcome_advantage(
+                token_level_rewards=batch.batch["token_level_rewards"],
+                response_mask=batch.batch["response_mask"],
+                index=batch.non_tensor_batch["uid"],
+                traj_index=traj_index,
+                norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
+                compute_mean_std_cross_all_data=compute_mean_std_cross_all_data,
+            )
+            batch.batch["advantages"] = advantages
+            batch.batch["returns"] = returns
+        else:
+            batch = compute_advantage(
+                batch,
+                adv_estimator=self.config.algorithm.adv_estimator,
+                gamma=self.config.algorithm.gamma,
+                lam=self.config.algorithm.lam,
+                num_repeat=self.config.actor_rollout_ref.rollout.n,
+                norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
+                config=self.config.algorithm,
+            )

         # Calculate the metrics before processing. Refer to the comments of function `compute_data_metrics` for details.
         metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic, suffix="_before_processing"))
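
How the new flag is set on the configuration side is outside this diff. A minimal sketch, assuming the usual OmegaConf-style algorithm config that the trainer reads via self.config.algorithm.get(...); only the keys relevant to the new GRPO path are shown, and the fragment itself is hypothetical.

from omegaconf import OmegaConf

# Hypothetical algorithm config fragment for the trajectory-level GRPO mode.
algorithm = OmegaConf.create(
    {
        "adv_estimator": "grpo",                   # required when the flag below is False
        "norm_adv_by_std_in_grpo": True,           # original GRPO normalization by std
        "compute_mean_std_cross_all_data": False,  # per-(index, traj_index) statistics
    }
)

# Mirrors the trainer: the default of True preserves VeRL's batch-wide behavior.
flag = algorithm.get("compute_mean_std_cross_all_data", True)
print(flag)  # False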
