diff --git a/skyrl-train/skyrl_train/utils/ppo_utils.py b/skyrl-train/skyrl_train/utils/ppo_utils.py index d814aedd7..ee60a0008 100644 --- a/skyrl-train/skyrl_train/utils/ppo_utils.py +++ b/skyrl-train/skyrl_train/utils/ppo_utils.py @@ -1051,6 +1051,7 @@ def compute_grpo_outcome_advantage( """ # this assumes response-level rewards scores = token_level_rewards.sum(dim=-1) + returns = scores.clone() id2score = defaultdict(list) id2mean = {} @@ -1075,8 +1076,9 @@ def compute_grpo_outcome_advantage( else: scores[i] = scores[i] - id2mean[index[i]] scores = scores.unsqueeze(-1) * response_mask + returns = returns.unsqueeze(-1) * response_mask - return scores, scores + return scores, returns def repopulate_all_registries():