`agentlightning/verl/trainer.py` (110 changes: 101 additions & 9 deletions)

Hunk `@@ -5,6 +5,7 @@`: the `defaultdict` import is the added line.

```python
from __future__ import annotations

import random
from collections import defaultdict
from contextlib import contextmanager
from copy import deepcopy
from pprint import pprint
```

Hunk `@@ -45,6 +46,72 @@` introduces the new advantage function (`torch` and `numpy as np` come from unchanged imports elsewhere in the file):

```python
]


def compute_grpo_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    traj_index: np.ndarray | None = None,
    epsilon: float = 1e-6,
    norm_adv_by_std_in_grpo: bool = True,
    compute_mean_std_cross_all_data: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute advantage for GRPO with trajectory-level deduplication support.

    This is a minimal extension of VeRL's GRPO implementation, adding support for
    trajectory-level deduplication via `traj_index` and `compute_mean_std_cross_all_data`.

    Args:
        token_level_rewards: Shape (bs, response_length).
        response_mask: Shape (bs, response_length).
        index: Group index array (e.g., data_id).
        traj_index: Trajectory index array (e.g., rollout_id). If None, no deduplication.
        epsilon: Small value for numerical stability.
        norm_adv_by_std_in_grpo: If True, normalize by std (original GRPO). If False, Dr.GRPO style.
        compute_mean_std_cross_all_data: If True (default), compute mean/std across all data.
            If False, compute mean/std per unique (index, traj_index) trajectory.

    Returns:
        Tuple of (advantages, returns), both of shape (bs, response_length).
    """
    scores = token_level_rewards.sum(dim=-1)

    id2score: dict = defaultdict(list)
    id2mean: dict = {}
    id2std: dict = {}
    seen_pairs: set = set()

    with torch.no_grad():
        bsz = scores.shape[0]
        for i in range(bsz):
            # Trajectory deduplication: skip if (index, traj_index) was already seen.
            if traj_index is not None and (index[i], traj_index[i]) in seen_pairs:
                continue
            id2score[index[i]].append(scores[i])
            # Mark as seen only when compute_mean_std_cross_all_data is False.
            if traj_index is not None and not compute_mean_std_cross_all_data:
                seen_pairs.add((index[i], traj_index[i]))

        for idx in id2score:
            if len(id2score[idx]) == 1:
                id2mean[idx] = torch.tensor(0.0)
                id2std[idx] = torch.tensor(1.0)
            elif len(id2score[idx]) > 1:
                scores_tensor = torch.stack(id2score[idx])
                id2mean[idx] = torch.mean(scores_tensor)
                id2std[idx] = torch.std(scores_tensor)
            else:
                raise ValueError(f"no score in prompt index: {idx}")

        for i in range(bsz):
            if norm_adv_by_std_in_grpo:
                scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
            else:
                scores[i] = scores[i] - id2mean[index[i]]
        scores = scores.unsqueeze(-1) * response_mask

    return scores, scores
```

Copilot AI commented on the `raise ValueError` line (Jan 21, 2026):

> The error message should include the actual index value that causes the failure, for easier debugging; verify that `raise ValueError(f"no score in prompt index: {idx}")` really renders `idx` in the output.

Copilot AI commented on lines +85 to +109 (Jan 21, 2026):

> The function accepts `index` as `np.ndarray` but uses its elements directly as keys into `id2score`, `id2mean`, and `id2std`. NumPy array elements may not hash correctly depending on their dtype; if `index` contains NumPy scalars (e.g., `np.int64`), this could cause issues. Consider converting elements to native Python types before using them as dictionary keys:
>
> ```python
> idx_key = int(index[i])
> id2score[idx_key].append(scores[i])
> ```
>
> or document that `index` must contain hashable types that work as dictionary keys.
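In symbols, `compute_grpo_outcome_advantage` computes the following (a restatement of the code above, not an addition to it):

$$
r_i = \sum_t R_{i,t}, \qquad
A_{i,t} = \frac{r_i - \mu_{g(i)}}{\sigma_{g(i)} + \varepsilon}\, m_{i,t},
$$

where $g(i)$ is sample $i$'s group (`index[i]`), $\mu_{g(i)}$ and $\sigma_{g(i)}$ are the mean and sample standard deviation of the group's retained scores (0 and 1 for singleton groups), and $m_{i,t}$ is the response mask. With `norm_adv_by_std_in_grpo=False` the denominator is dropped, and the same tensor is returned as both advantages and returns.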

Copilot AI commented on lines +49 to +113 (Jan 21, 2026):

> The new `compute_grpo_outcome_advantage` function lacks test coverage. Given that this is a critical mathematical computation affecting training outcomes, unit tests should be added to verify:
>
> 1. Correct behavior when `compute_mean_std_cross_all_data=True` vs `False`
> 2. Proper handling of trajectory deduplication with different `(index, traj_index)` combinations
> 3. Device consistency (tensors on GPU)
> 4. Edge cases: single-sample groups, all identical scores, etc.
> 5. Correct advantage normalization with and without std division
>
> Consider adding tests in the `tests/trainer/` directory or a new test file specifically for GRPO advantage computation.
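A minimal sketch of such tests, assuming `pytest`, CPU tensors, and that the function is importable from `agentlightning.verl.trainer` (the import path is an assumption); expected values are worked by hand from the code above:

```python
# Hypothetical test sketch; the import path is an assumption.
import numpy as np
import torch

from agentlightning.verl.trainer import compute_grpo_outcome_advantage


def test_grpo_advantage_normalizes_within_group():
    # Two prompts ("a", "b"), two single-token responses each.
    rewards = torch.tensor([[1.0], [3.0], [2.0], [2.0]])
    mask = torch.ones_like(rewards)
    index = np.array(["a", "a", "b", "b"])

    adv, ret = compute_grpo_outcome_advantage(
        token_level_rewards=rewards,
        response_mask=mask,
        index=index,
    )

    # Group "a": mean 2, sample std sqrt(2); scores (1, 3) map to roughly (-0.707, +0.707).
    assert torch.allclose(adv[:2, 0], torch.tensor([-0.7071, 0.7071]), atol=1e-3)
    # Group "b": identical scores, so advantages collapse to ~0.
    assert torch.allclose(adv[2:, 0], torch.zeros(2), atol=1e-3)
    # Outcome-reward GRPO returns the same tensor as advantages and returns.
    assert torch.equal(adv, ret)


def test_grpo_advantage_trajectory_dedup():
    # Rows 0 and 1 share (index, traj_index) = ("a", 0); with
    # compute_mean_std_cross_all_data=False the pair is counted once for the stats.
    rewards = torch.tensor([[1.0], [1.0], [3.0]])
    mask = torch.ones_like(rewards)
    index = np.array(["a", "a", "a"])
    traj = np.array([0, 0, 1])

    adv, _ = compute_grpo_outcome_advantage(
        token_level_rewards=rewards,
        response_mask=mask,
        index=index,
        traj_index=traj,
        compute_mean_std_cross_all_data=False,
    )

    # Stats come from the deduplicated scores {1, 3}: mean 2, sample std sqrt(2).
    expected = (torch.tensor([1.0, 1.0, 3.0]) - 2.0) / (2.0 ** 0.5 + 1e-6)
    assert torch.allclose(adv[:, 0], expected, atol=1e-3)
```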

```python
@contextmanager
def _timer(name: str, timing_raw: Dict[str, float]):
    with Timer(name=name, logger=None) as timer:
        # ... (body unchanged, collapsed in the diff view)
```

Hunk `@@ -355,15 +422,40 @@` in `def _train_step(self, batch_dict: dict) -> dict`:

```python
            # ... (start of the assignment collapsed in the diff view)
            "norm_adv_by_std_in_grpo", True
        )  # GRPO adv normalization factor
```

The previous unconditional `compute_advantage` call:

```python
        batch = compute_advantage(
            batch,
            adv_estimator=self.config.algorithm.adv_estimator,
            gamma=self.config.algorithm.gamma,
            lam=self.config.algorithm.lam,
            num_repeat=self.config.actor_rollout_ref.rollout.n,
            norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
            config=self.config.algorithm,
        )
```

is replaced with a GRPO-specific branch:

```python
        # compute_mean_std_cross_all_data: trajectory-level advantage computation.
        # Currently only supported for the GRPO algorithm.
        compute_mean_std_cross_all_data = self.config.algorithm.get("compute_mean_std_cross_all_data", True)
        if not compute_mean_std_cross_all_data:
            assert self.config.algorithm.adv_estimator == AdvantageEstimator.GRPO, (
                f"compute_mean_std_cross_all_data=False is only supported for GRPO, "
                f"got {self.config.algorithm.adv_estimator}"
            )

        # All GRPO runs use the local implementation, which handles both
        # settings of compute_mean_std_cross_all_data.
        if self.config.algorithm.adv_estimator == AdvantageEstimator.GRPO:
            if "response_mask" not in batch.batch:
                batch.batch["response_mask"] = compute_response_mask(batch)
            traj_index = batch.non_tensor_batch["rollout_id_list"]
            advantages, returns = compute_grpo_outcome_advantage(
                token_level_rewards=batch.batch["token_level_rewards"],
                response_mask=batch.batch["response_mask"],
                index=batch.non_tensor_batch["uid"],
                traj_index=traj_index,
                norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
                compute_mean_std_cross_all_data=compute_mean_std_cross_all_data,
            )
            batch.batch["advantages"] = advantages
            batch.batch["returns"] = returns
        else:
            batch = compute_advantage(
                batch,
                adv_estimator=self.config.algorithm.adv_estimator,
                gamma=self.config.algorithm.gamma,
                lam=self.config.algorithm.lam,
                num_repeat=self.config.actor_rollout_ref.rollout.n,
                norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
                config=self.config.algorithm,
            )
```

Copilot AI commented on lines +428 to +432 (Jan 21, 2026):

> The assertion only fires when `compute_mean_std_cross_all_data=False`, but the new GRPO implementation is used for ALL GRPO cases (the `if` condition below it). When `compute_mean_std_cross_all_data=True` with a non-GRPO estimator, the assertion is never checked and the code falls through to the `else` branch. That is not necessarily incorrect (the `else` branch handles non-GRPO cases properly), but the control flow could be clearer. Consider restructuring to make the relationship between the flag and the GRPO check more explicit, or add a comment explaining why the assertion only needs to check the `False` case.

```python
        # Calculate the metrics before processing. Refer to the comments of function `compute_data_metrics` for details.
        metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic, suffix="_before_processing"))
        # ... (rest of the method collapsed in the diff view)
```
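For reference, a sketch of the algorithm config this branch reads. Only the three key names are taken from the code above; the OmegaConf usage and the values shown are illustrative assumptions (VeRL configs are typically Hydra/OmegaConf-based):

```python
# Illustrative sketch: the algorithm-section keys consumed by _train_step above.
from omegaconf import OmegaConf

algorithm_cfg = OmegaConf.create(
    {
        "adv_estimator": "grpo",  # the local GRPO path only triggers for this estimator
        "norm_adv_by_std_in_grpo": True,  # False would give Dr.GRPO-style (no std division)
        # False computes per-(uid, rollout_id) statistics with trajectory
        # deduplication; the trainer asserts the estimator is GRPO in that case.
        "compute_mean_std_cross_all_data": False,
    }
)
```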