|
24 | 24 | compute_throughout_metrics, |
25 | 25 | compute_timing_metrics, |
26 | 26 | ) |
| 27 | +from collections import defaultdict |
| 28 | + |
27 | 29 | from verl.trainer.ppo.ray_trainer import ( |
28 | 30 | AdvantageEstimator, |
29 | 31 | RayPPOTrainer, |
|
45 | 47 | ] |
46 | 48 |
|
47 | 49 |
|
| 50 | +def compute_grpo_outcome_advantage( |
| 51 | + token_level_rewards: torch.Tensor, |
| 52 | + response_mask: torch.Tensor, |
| 53 | + index: np.ndarray, |
| 54 | + traj_index: np.ndarray | None = None, |
| 55 | + epsilon: float = 1e-6, |
| 56 | + norm_adv_by_std_in_grpo: bool = True, |
| 57 | + compute_mean_std_cross_all_data: bool = True, |
| 58 | +) -> tuple[torch.Tensor, torch.Tensor]: |
| 59 | + """Compute advantage for GRPO with trajectory-level deduplication support. |
| 60 | +
|
| 61 | + This is a minimal extension of VeRL's GRPO implementation, adding support for |
| 62 | + trajectory-level deduplication via `traj_index` and `compute_mean_std_cross_all_data`. |
| 63 | +
|
| 64 | + Args: |
| 65 | + token_level_rewards: Shape (bs, response_length). |
| 66 | + response_mask: Shape (bs, response_length). |
| 67 | + index: Group index array (e.g., data_id). |
| 68 | + traj_index: Trajectory index array (e.g., rollout_id). If None, no deduplication. |
| 69 | + epsilon: Small value for numerical stability. |
| 70 | + norm_adv_by_std_in_grpo: If True, normalize by std (original GRPO). If False, Dr.GRPO style. |
| 71 | + compute_mean_std_cross_all_data: If True (default), compute mean/std across all data. |
| 72 | + If False, compute mean/std per unique (index, traj_index) trajectory. |
| 73 | +
|
| 74 | + Returns: |
| 75 | + Tuple of (advantages, returns), both shape (bs, response_length). |
| 76 | + """ |
| 77 | + scores = token_level_rewards.sum(dim=-1) |
| 78 | + |
| 79 | + id2score: dict = defaultdict(list) |
| 80 | + id2mean: dict = {} |
| 81 | + id2std: dict = {} |
| 82 | + seen_pairs: set = set() |
| 83 | + |
| 84 | + with torch.no_grad(): |
| 85 | + bsz = scores.shape[0] |
| 86 | + for i in range(bsz): |
| 87 | + # Trajectory deduplication: skip if (index, traj_index) already seen |
| 88 | + if traj_index is not None and (index[i], traj_index[i]) in seen_pairs: |
| 89 | + continue |
| 90 | + id2score[index[i]].append(scores[i]) |
| 91 | + # Mark as seen only when compute_mean_std_cross_all_data is False |
| 92 | + if traj_index is not None and not compute_mean_std_cross_all_data: |
| 93 | + seen_pairs.add((index[i], traj_index[i])) |
| 94 | + |
| 95 | + for idx in id2score: |
| 96 | + if len(id2score[idx]) == 1: |
| 97 | + id2mean[idx] = torch.tensor(0.0) |
| 98 | + id2std[idx] = torch.tensor(1.0) |
| 99 | + elif len(id2score[idx]) > 1: |
| 100 | + scores_tensor = torch.stack(id2score[idx]) |
| 101 | + id2mean[idx] = torch.mean(scores_tensor) |
| 102 | + id2std[idx] = torch.std(scores_tensor) |
| 103 | + else: |
| 104 | + raise ValueError(f"no score in prompt index: {idx}") |
| 105 | + |
| 106 | + for i in range(bsz): |
| 107 | + if norm_adv_by_std_in_grpo: |
| 108 | + scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon) |
| 109 | + else: |
| 110 | + scores[i] = scores[i] - id2mean[index[i]] |
| 111 | + scores = scores.unsqueeze(-1) * response_mask |
| 112 | + |
| 113 | + return scores, scores |
| 114 | + |
| 115 | + |
48 | 116 | @contextmanager |
49 | 117 | def _timer(name: str, timing_raw: Dict[str, float]): |
50 | 118 | with Timer(name=name, logger=None) as timer: |
@@ -355,15 +423,42 @@ def _train_step(self, batch_dict: dict) -> dict: |
355 | 423 | "norm_adv_by_std_in_grpo", True |
356 | 424 | ) # GRPO adv normalization factor |
357 | 425 |
|
358 | | - batch = compute_advantage( |
359 | | - batch, |
360 | | - adv_estimator=self.config.algorithm.adv_estimator, |
361 | | - gamma=self.config.algorithm.gamma, |
362 | | - lam=self.config.algorithm.lam, |
363 | | - num_repeat=self.config.actor_rollout_ref.rollout.n, |
364 | | - norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, |
365 | | - config=self.config.algorithm, |
| 426 | + # compute_mean_std_cross_all_data: trajectory-level advantage computation |
| 427 | + # Currently only supported for GRPO algorithm |
| 428 | + compute_mean_std_cross_all_data = self.config.algorithm.get( |
| 429 | + "compute_mean_std_cross_all_data", True |
366 | 430 | ) |
| 431 | + if not compute_mean_std_cross_all_data: |
| 432 | + assert self.config.algorithm.adv_estimator == AdvantageEstimator.GRPO, ( |
| 433 | + f"compute_mean_std_cross_all_data=False is only supported for GRPO, " |
| 434 | + f"got {self.config.algorithm.adv_estimator}" |
| 435 | + ) |
| 436 | + |
| 437 | + # Use the local GRPO implementation for every GRPO run; compute_mean_std_cross_all_data only toggles trajectory-level deduplication inside it |
| 438 | + if self.config.algorithm.adv_estimator == AdvantageEstimator.GRPO: |
| 439 | + if "response_mask" not in batch.batch: |
| 440 | + batch.batch["response_mask"] = compute_response_mask(batch) |
| 441 | + traj_index = batch.non_tensor_batch["rollout_id_list"] |
| 442 | + advantages, returns = compute_grpo_outcome_advantage( |
| 443 | + token_level_rewards=batch.batch["token_level_rewards"], |
| 444 | + response_mask=batch.batch["response_mask"], |
| 445 | + index=batch.non_tensor_batch["uid"], |
| 446 | + traj_index=traj_index, |
| 447 | + norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, |
| 448 | + compute_mean_std_cross_all_data=compute_mean_std_cross_all_data, |
| 449 | + ) |
| 450 | + batch.batch["advantages"] = advantages |
| 451 | + batch.batch["returns"] = returns |
| 452 | + else: |
| 453 | + batch = compute_advantage( |
| 454 | + batch, |
| 455 | + adv_estimator=self.config.algorithm.adv_estimator, |
| 456 | + gamma=self.config.algorithm.gamma, |
| 457 | + lam=self.config.algorithm.lam, |
| 458 | + num_repeat=self.config.actor_rollout_ref.rollout.n, |
| 459 | + norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, |
| 460 | + config=self.config.algorithm, |
| 461 | + ) |
367 | 462 |
|
368 | 463 | # Calculate the metrics before processing. Refer to the comments of function `compute_data_metrics` for details. |
369 | 464 | metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic, suffix="_before_processing")) |
|
0 commit comments