From 1a233964d58a20805749185be7e0c81df820596f Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Mon, 24 Nov 2025 10:49:41 +0800 Subject: [PATCH 1/3] Add `token_level_reward` to `Experience` --- tests/common/experience_test.py | 42 +++++++++++++++++ trinity/common/experience.py | 80 ++++++++++++--------------------- trinity/trainer/verl/utils.py | 26 +++++++---- trinity/trainer/verl_trainer.py | 2 +- 4 files changed, 89 insertions(+), 61 deletions(-) diff --git a/tests/common/experience_test.py b/tests/common/experience_test.py index 7e4c47aa48..826c29d546 100644 --- a/tests/common/experience_test.py +++ b/tests/common/experience_test.py @@ -123,6 +123,48 @@ def test_gather(self): self.assertEqual(batch.rewards[0], 0.1) self.assertEqual(batch.rewards[1], 0.2) + def test_gather_with_token_level_reward(self): + # test empty gathering + batch = Experiences.gather_experiences([]) + self.assertEqual(batch.tokens.numel(), 0) + self.assertEqual(batch.rewards.numel(), 0) + self.assertEqual(batch.token_level_rewards.numel(), 0) + self.assertEqual(batch.eids, []) + + # test single experience gathering + exp = Experience( + tokens=torch.tensor([1, 2, 3]), + token_level_reward=torch.tensor([0, 1.0]), + prompt_length=1, + ) + batch = Experiences.gather_experiences([exp]) + self.assertEqual(batch.batch_size, 1) + self.assertTrue( + torch.equal(batch.tokens[0], torch.tensor([0, 1, 2, 3], dtype=torch.int64)[-3:]) + ) + self.assertEqual(batch.prompt_length, 1) + self.assertIsNone(batch.rewards) + self.assertTrue(torch.equal(batch.token_level_rewards[0], torch.tensor([0, 1.0]))) + + # test multiple experiences gathering + exps = [ + Experience( + tokens=torch.tensor([1, 2]), token_level_reward=torch.tensor([0.1]), prompt_length=1 + ), + Experience( + tokens=torch.tensor([3, 4, 5]), + token_level_reward=torch.tensor([0.2]), + prompt_length=2, + ), + ] + batch = Experiences.gather_experiences(exps) + self.assertEqual(batch.batch_size, 2) + self.assertEqual(batch.prompt_length, 2) + self.assertEqual(batch.tokens.shape[1], 3) + self.assertIsNone(batch.rewards) + self.assertTrue(torch.equal(batch.token_level_rewards[0], torch.tensor([0.1]))) + self.assertTrue(torch.equal(batch.token_level_rewards[1], torch.tensor([0.2]))) + def test_action_mask_and_logprobs_type(self): exp = Experience(tokens=[1, 2, 3], logprobs=[0.1, 0.2, 0.3], prompt_length=1) self.assertIsInstance(exp.tokens, torch.Tensor) diff --git a/trinity/common/experience.py b/trinity/common/experience.py index 42af635873..683d98b9a2 100644 --- a/trinity/common/experience.py +++ b/trinity/common/experience.py @@ -101,6 +101,7 @@ class Experience: prompt_length: int = 1 # Length of the prompt in tokens, used for generating attention masks logprobs: Optional[Tensor] = None # [resp_length] reward: Optional[float] = None + token_level_reward: Optional[Tensor] = None # [resp_length] advantages: Optional[Tensor] = None # [resp_length] returns: Optional[Tensor] = None # [resp_length] info: dict = field( @@ -136,6 +137,7 @@ def __init__( # noqa: C901 tokens, logprobs=None, reward=None, + token_level_reward=None, advantages=None, returns=None, info=None, @@ -182,6 +184,9 @@ def __init__( # noqa: C901 logprobs = torch.tensor(logprobs, dtype=torch.float32) self.logprobs = logprobs self.reward = reward + if isinstance(token_level_reward, list): + token_level_reward = torch.tensor(token_level_reward, dtype=torch.float32) + self.token_level_reward = token_level_reward if isinstance(advantages, list): advantages = torch.tensor(advantages, 
dtype=torch.float32) self.advantages = advantages @@ -286,6 +291,14 @@ def gather( else: rewards = None + # Gather token level rewards + if all(exp.token_level_reward is not None for exp in experiences): + token_level_rewards = gather_response_attrs( + experiences, "token_level_reward", max_response_length + ) + else: + token_level_rewards = None + # gather action_masks action_masks = gather_action_masks(experiences, max_response_length) @@ -295,21 +308,20 @@ def gather( ) # gather logprobs - if all(exp.logprobs is not None for exp in experiences): - logprobs = gather_logprobs(experiences, max_response_length) + logprobs = gather_response_attrs(experiences, "logprobs", max_response_length) else: logprobs = None # gather advantages if all(exp.advantages is not None for exp in experiences): - advantages = gather_advantages(experiences, max_response_length) + advantages = gather_response_attrs(experiences, "advantages", max_response_length) else: advantages = None # gather returns if all(exp.returns is not None for exp in experiences): - returns = gather_returns(experiences, max_response_length) + returns = gather_response_attrs(experiences, "returns", max_response_length) else: returns = None @@ -323,6 +335,7 @@ def gather( eids=eids, tokens=tokens, rewards=rewards, + token_level_rewards=token_level_rewards, advantages=advantages, returns=returns, attention_masks=attention_masks, @@ -404,6 +417,7 @@ class Experiences: eids: List[EID] # Experience IDs of each experience in the batch tokens: Tensor # [batch_size, seq_length] rewards: Tensor # [batch_size] + token_level_rewards: Tensor # [batch_size, response_length] advantages: Optional[Tensor] # [batch_size, response_length] returns: Optional[Tensor] # [batch_size, response_length] attention_masks: Tensor # [batch_size, sequence_length] @@ -447,6 +461,7 @@ def empty_experiences(custom_fields: Optional[List[CustomField]]) -> Experiences exps = Experiences( tokens=torch.empty(0, dtype=torch.int32), rewards=torch.empty(0, dtype=torch.float32), + token_level_rewards=torch.empty(0, dtype=torch.float32), advantages=torch.empty(0, dtype=torch.float32), returns=torch.empty(0, dtype=torch.float32), attention_masks=torch.empty(0, dtype=torch.bool), @@ -522,59 +537,20 @@ def gather_attention_masks(experiences, max_prompt_length: int, max_response_len return attention_masks -def gather_logprobs(experiences, max_response_length: int) -> Tensor: - logprob_dtype = experiences[0].logprobs.dtype # type: ignore [union-attr] - return torch.stack( - [ - torch.cat( - [ - exp.logprobs, - torch.full( - (max_response_length - len(exp.logprobs),), - 0.0, - dtype=logprob_dtype, - ), - ] - ) - for exp in experiences - ] - ) - - -def gather_advantages(experiences, max_response_length: int) -> Optional[Tensor]: - if experiences[0].advantages is None: - return None - advantages_dtype = experiences[0].advantages.dtype - return torch.stack( - [ - torch.cat( - [ - exp.advantages, - torch.full( - (max_response_length - len(exp.advantages),), - 0.0, - dtype=advantages_dtype, - ), - ] - ) - for exp in experiences - ] - ) - - -def gather_returns(experiences, max_response_length: int) -> Optional[dict[str, List[Tensor]]]: - if experiences[0].returns is None: - return None - returns_dtype = experiences[0].returns.dtype +def gather_response_attrs( + experiences, attr_name: str, max_response_length: int, pad_value: int = 0 +) -> Tensor: + dtype = getattr(experiences[0], attr_name).dtype + pad_value = torch.tensor(pad_value, dtype=dtype) return torch.stack( [ torch.cat( [ - 
exp.returns, + getattr(exp, attr_name), torch.full( - (max_response_length - len(exp.returns),), - 0.0, - dtype=returns_dtype, + (max_response_length - len(getattr(exp, attr_name)),), + pad_value, + dtype=dtype, ), ] ) diff --git a/trinity/trainer/verl/utils.py b/trinity/trainer/verl/utils.py index 6ef793d68d..9a35eb8a29 100644 --- a/trinity/trainer/verl/utils.py +++ b/trinity/trainer/verl/utils.py @@ -1,6 +1,7 @@ """Utils for ccompatibility issues with verl.""" import os +from logging import Logger import numpy as np import torch @@ -12,7 +13,7 @@ from trinity.common.experience import Experiences -def to_data_proto(experiences: Experiences) -> DataProto: # noqa: C901 +def to_data_proto(experiences: Experiences, logger: Logger) -> DataProto: # noqa: C901 """Convert Experiences to verl DataProto.""" attention_mask = experiences.attention_masks cumsum = torch.cumsum(attention_mask, dim=-1) @@ -31,13 +32,22 @@ def to_data_proto(experiences: Experiences) -> DataProto: # noqa: C901 ), } - if experiences.rewards is not None: - token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype) - eos_mask_idx = cumsum.argmax(dim=-1) - token_level_rewards[ - torch.arange(experiences.batch_size), eos_mask_idx - ] = experiences.rewards - token_level_rewards = token_level_rewards[:, experiences.prompt_length :] + if experiences.rewards is not None or experiences.token_level_rewards is not None: + assert experiences.logprobs is not None + if experiences.token_level_rewards is not None: + if experiences.rewards is not None: + logger.warning( + "Both experiences.rewards and experiences.token_level_rewards are provided. " + "Using experiences.token_level_rewards." + ) + token_level_rewards = experiences.token_level_rewards + else: + token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype) + eos_mask_idx = cumsum.argmax(dim=-1) + token_level_rewards[ + torch.arange(experiences.batch_size), eos_mask_idx + ] = experiences.rewards + token_level_rewards = token_level_rewards[:, experiences.prompt_length :] batch_dict.update( { "token_level_scores": token_level_rewards, diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index a2c11ff49f..e845502721 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -413,7 +413,7 @@ def upload_state_dict(self): # state dict sync self.actor_rollout_wg.upload_state_dict(self.global_steps) def train_step(self, batch: Experiences) -> Dict: # noqa C901 - batch = to_data_proto(batch) + batch = to_data_proto(batch, self.logger) batch = self.post_process_batch(batch) metrics = {} self.global_steps += 1 From 6385f90addd240fabe5a16fb1b4426d85b49eb36 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Mon, 24 Nov 2025 15:46:38 +0800 Subject: [PATCH 2/3] doc fix --- trinity/common/experience.py | 4 ++++ trinity/trainer/verl_trainer.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/trinity/common/experience.py b/trinity/common/experience.py index 683d98b9a2..4e3aa936be 100644 --- a/trinity/common/experience.py +++ b/trinity/common/experience.py @@ -416,8 +416,12 @@ class Experiences: eids: List[EID] # Experience IDs of each experience in the batch tokens: Tensor # [batch_size, seq_length] + + # At least one of `rewards` or `token_level_rewards` must be provided (not None). + # If both are provided, `token_level_rewards` will be used and `rewards` will be ignored. 
rewards: Tensor # [batch_size] token_level_rewards: Tensor # [batch_size, response_length] + advantages: Optional[Tensor] # [batch_size, response_length] returns: Optional[Tensor] # [batch_size, response_length] attention_masks: Tensor # [batch_size, sequence_length] diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index 5351709721..e5fd872350 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -454,7 +454,7 @@ def train_step(self, batch: Experiences) -> Dict: # noqa C901 else: # skip token_level_scores for sft/dpo if "token_level_scores" in batch.batch.keys(): - batch.batch["token_level_scores"] = batch.batch["token_level_scores"] + batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] # update critic if self.algorithm.use_critic: From 88307bec3291392b0aa4d25dbef83de043884839 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Mon, 24 Nov 2025 16:31:05 +0800 Subject: [PATCH 3/3] add assert --- trinity/trainer/verl_trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index e5fd872350..603f54c98e 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -454,6 +454,7 @@ def train_step(self, batch: Experiences) -> Dict: # noqa C901 else: # skip token_level_scores for sft/dpo if "token_level_scores" in batch.batch.keys(): + assert "token_level_rewards" not in batch.batch.keys() batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] # update critic
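
Below is a minimal usage sketch of the per-token reward path added by this series, pieced together from the new test and the `to_data_proto` change (the module paths and the logger setup are illustrative assumptions, not part of the patches): build an `Experience` with `token_level_reward` in place of the scalar `reward`, gather it, and convert the batch for verl.

    import logging

    import torch

    from trinity.common.experience import Experience, Experiences
    from trinity.trainer.verl.utils import to_data_proto

    logger = logging.getLogger(__name__)

    # One experience: 1 prompt token + 2 response tokens, so the per-token
    # reward tensor covers only the response and has length 2.
    exp = Experience(
        tokens=torch.tensor([1, 2, 3]),
        prompt_length=1,
        logprobs=torch.tensor([-0.5, -0.7]),          # needed: to_data_proto asserts logprobs is not None
        token_level_reward=torch.tensor([0.0, 1.0]),  # used instead of the scalar `reward`
    )

    batch = Experiences.gather_experiences([exp])
    assert batch.rewards is None                      # no scalar rewards were set
    assert batch.token_level_rewards.shape == (1, 2)  # [batch_size, response_length]

    # The gathered per-token rewards are handed to verl as `token_level_scores`.
    proto = to_data_proto(batch, logger)

If both `reward` and `token_level_reward` are set on the experiences, `to_data_proto` logs a warning and keeps the token-level values, as implemented in patch 1.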