Add trajectory-level deduplication for GRPO advantage normalization #462
base: main
```diff
@@ -5,6 +5,7 @@
 from __future__ import annotations

 import random
+from collections import defaultdict
 from contextlib import contextmanager
 from copy import deepcopy
 from pprint import pprint
```
```diff
@@ -45,6 +46,72 @@
 ]


+def compute_grpo_outcome_advantage(
+    token_level_rewards: torch.Tensor,
+    response_mask: torch.Tensor,
+    index: np.ndarray,
+    traj_index: np.ndarray | None = None,
+    epsilon: float = 1e-6,
+    norm_adv_by_std_in_grpo: bool = True,
+    compute_mean_std_cross_all_data: bool = True,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Compute advantage for GRPO with trajectory-level deduplication support.
+
+    This is a minimal extension of VeRL's GRPO implementation, adding support for
+    trajectory-level deduplication via `traj_index` and `compute_mean_std_cross_all_data`.
+
+    Args:
+        token_level_rewards: Shape (bs, response_length).
+        response_mask: Shape (bs, response_length).
+        index: Group index array (e.g., data_id).
+        traj_index: Trajectory index array (e.g., rollout_id). If None, no deduplication.
+        epsilon: Small value for numerical stability.
+        norm_adv_by_std_in_grpo: If True, normalize by std (original GRPO). If False, Dr.GRPO style.
+        compute_mean_std_cross_all_data: If True (default), compute mean/std across all data.
+            If False, compute mean/std per unique (index, traj_index) trajectory.
+
+    Returns:
+        Tuple of (advantages, returns), both shape (bs, response_length).
+    """
+    scores = token_level_rewards.sum(dim=-1)
+
+    id2score: dict = defaultdict(list)
+    id2mean: dict = {}
+    id2std: dict = {}
+    seen_pairs: set = set()
+
+    with torch.no_grad():
+        bsz = scores.shape[0]
+        for i in range(bsz):
+            # Trajectory deduplication: skip if (index, traj_index) already seen
+            if traj_index is not None and (index[i], traj_index[i]) in seen_pairs:
+                continue
+            id2score[index[i]].append(scores[i])
+            # Mark as seen only when compute_mean_std_cross_all_data is False
+            if traj_index is not None and not compute_mean_std_cross_all_data:
+                seen_pairs.add((index[i], traj_index[i]))
+
+        for idx in id2score:
+            if len(id2score[idx]) == 1:
+                id2mean[idx] = torch.tensor(0.0)
+                id2std[idx] = torch.tensor(1.0)
+            elif len(id2score[idx]) > 1:
+                scores_tensor = torch.stack(id2score[idx])
+                id2mean[idx] = torch.mean(scores_tensor)
+                id2std[idx] = torch.std(scores_tensor)
+            else:
+                raise ValueError(f"no score in prompt index: {idx}")
+
+        for i in range(bsz):
+            if norm_adv_by_std_in_grpo:
+                scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
+            else:
+                scores[i] = scores[i] - id2mean[index[i]]
+
+        scores = scores.unsqueeze(-1) * response_mask
+
+    return scores, scores
+
+
 @contextmanager
 def _timer(name: str, timing_raw: Dict[str, float]):
     with Timer(name=name, logger=None) as timer:
```
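A minimal sketch of how the new helper could be exercised on its own, assuming `compute_grpo_outcome_advantage` is importable as defined above; the toy rewards, uids, and rollout ids below are illustrative and not part of the PR:

```python
import numpy as np
import torch

# Two prompts (uid "a" and "b"), each with two rows. For uid "a" the two rows
# share the same rollout_id, so with compute_mean_std_cross_all_data=False the
# duplicate row is skipped when the group mean/std is estimated.
token_level_rewards = torch.tensor([
    [0.0, 1.0],   # uid "a", rollout 0
    [0.0, 1.0],   # uid "a", rollout 0 (duplicate trajectory)
    [0.0, 2.0],   # uid "b", rollout 0
    [0.0, 4.0],   # uid "b", rollout 1
])
response_mask = torch.ones_like(token_level_rewards)
index = np.array(["a", "a", "b", "b"])
traj_index = np.array([0, 0, 0, 1])

advantages, returns = compute_grpo_outcome_advantage(
    token_level_rewards=token_level_rewards,
    response_mask=response_mask,
    index=index,
    traj_index=traj_index,
    norm_adv_by_std_in_grpo=True,
    compute_mean_std_cross_all_data=False,
)
print(advantages.shape)  # torch.Size([4, 2])
```

With `compute_mean_std_cross_all_data=False`, the duplicate row for uid "a" is excluded from the group statistics, but every row still receives an advantage in the normalization pass.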
```diff
@@ -355,15 +422,40 @@ def _train_step(self, batch_dict: dict) -> dict:
             "norm_adv_by_std_in_grpo", True
         )  # GRPO adv normalization factor

-        batch = compute_advantage(
-            batch,
-            adv_estimator=self.config.algorithm.adv_estimator,
-            gamma=self.config.algorithm.gamma,
-            lam=self.config.algorithm.lam,
-            num_repeat=self.config.actor_rollout_ref.rollout.n,
-            norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
-            config=self.config.algorithm,
-        )
+        # compute_mean_std_cross_all_data: trajectory-level advantage computation
+        # Currently only supported for GRPO algorithm
+        compute_mean_std_cross_all_data = self.config.algorithm.get("compute_mean_std_cross_all_data", True)
+        if not compute_mean_std_cross_all_data:
+            assert self.config.algorithm.adv_estimator == AdvantageEstimator.GRPO, (
+                f"compute_mean_std_cross_all_data=False is only supported for GRPO, "
+                f"got {self.config.algorithm.adv_estimator}"
+            )
+
+        # Use local GRPO implementation when compute_mean_std_cross_all_data is disabled
+        if self.config.algorithm.adv_estimator == AdvantageEstimator.GRPO:
+            if "response_mask" not in batch.batch:
+                batch.batch["response_mask"] = compute_response_mask(batch)
+            traj_index = batch.non_tensor_batch["rollout_id_list"]
+            advantages, returns = compute_grpo_outcome_advantage(
+                token_level_rewards=batch.batch["token_level_rewards"],
+                response_mask=batch.batch["response_mask"],
+                index=batch.non_tensor_batch["uid"],
+                traj_index=traj_index,
+                norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
+                compute_mean_std_cross_all_data=compute_mean_std_cross_all_data,
+            )
+            batch.batch["advantages"] = advantages
+            batch.batch["returns"] = returns
+        else:
+            batch = compute_advantage(
+                batch,
+                adv_estimator=self.config.algorithm.adv_estimator,
+                gamma=self.config.algorithm.gamma,
+                lam=self.config.algorithm.lam,
+                num_repeat=self.config.actor_rollout_ref.rollout.n,
+                norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
+                config=self.config.algorithm,
+            )

         # Calculate the metrics before processing. Refer to the comments of function `compute_data_metrics` for details.
         metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic, suffix="_before_processing"))
```
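For context, a hedged sketch of how the new flag might be supplied through the algorithm config. The key names mirror what `_train_step` reads (`adv_estimator`, `norm_adv_by_std_in_grpo`, `compute_mean_std_cross_all_data`), but the OmegaConf wrapper and the `"grpo"` value are illustrative assumptions, not taken from this repository's config files:

```python
from omegaconf import OmegaConf

# Illustrative algorithm section only; other trainer config keys are omitted.
algorithm_cfg = OmegaConf.create({
    "adv_estimator": "grpo",                   # must be GRPO when the flag below is False
    "norm_adv_by_std_in_grpo": True,           # keep the original GRPO std normalization
    "compute_mean_std_cross_all_data": False,  # enable trajectory-level deduplication
})

# Mirrors the gating in _train_step: the flag defaults to True when absent.
flag = algorithm_cfg.get("compute_mean_std_cross_all_data", True)
print(flag)  # False
```

Note that in the diff above, every GRPO batch is routed through the local `compute_grpo_outcome_advantage` path; the flag only controls whether duplicate `(uid, rollout_id)` pairs are skipped when the per-group mean and std are computed.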