42 changes: 42 additions & 0 deletions tests/common/experience_test.py
@@ -123,6 +123,48 @@ def test_gather(self):
self.assertEqual(batch.rewards[0], 0.1)
self.assertEqual(batch.rewards[1], 0.2)

def test_gather_with_token_level_reward(self):
# test empty gathering
batch = Experiences.gather_experiences([])
self.assertEqual(batch.tokens.numel(), 0)
self.assertEqual(batch.rewards.numel(), 0)
self.assertEqual(batch.token_level_rewards.numel(), 0)
self.assertEqual(batch.eids, [])

# test single experience gathering
exp = Experience(
tokens=torch.tensor([1, 2, 3]),
token_level_reward=torch.tensor([0, 1.0]),
prompt_length=1,
)
batch = Experiences.gather_experiences([exp])
self.assertEqual(batch.batch_size, 1)
self.assertTrue(
torch.equal(batch.tokens[0], torch.tensor([0, 1, 2, 3], dtype=torch.int64)[-3:])
)
self.assertEqual(batch.prompt_length, 1)
self.assertIsNone(batch.rewards)
self.assertTrue(torch.equal(batch.token_level_rewards[0], torch.tensor([0, 1.0])))

# test multiple experiences gathering
exps = [
Experience(
tokens=torch.tensor([1, 2]), token_level_reward=torch.tensor([0.1]), prompt_length=1
),
Experience(
tokens=torch.tensor([3, 4, 5]),
token_level_reward=torch.tensor([0.2]),
prompt_length=2,
),
]
batch = Experiences.gather_experiences(exps)
self.assertEqual(batch.batch_size, 2)
self.assertEqual(batch.prompt_length, 2)
self.assertEqual(batch.tokens.shape[1], 3)
self.assertIsNone(batch.rewards)
self.assertTrue(torch.equal(batch.token_level_rewards[0], torch.tensor([0.1])))
self.assertTrue(torch.equal(batch.token_level_rewards[1], torch.tensor([0.2])))

def test_action_mask_and_logprobs_type(self):
exp = Experience(tokens=[1, 2, 3], logprobs=[0.1, 0.2, 0.3], prompt_length=1)
self.assertIsInstance(exp.tokens, torch.Tensor)
84 changes: 32 additions & 52 deletions trinity/common/experience.py
@@ -101,6 +101,7 @@ class Experience:
prompt_length: int = 1 # Length of the prompt in tokens, used for generating attention masks
logprobs: Optional[Tensor] = None # [resp_length]
reward: Optional[float] = None
token_level_reward: Optional[Tensor] = None # [resp_length]
advantages: Optional[Tensor] = None # [resp_length]
returns: Optional[Tensor] = None # [resp_length]
info: dict = field(
@@ -136,6 +137,7 @@ def __init__( # noqa: C901
tokens,
logprobs=None,
reward=None,
token_level_reward=None,
advantages=None,
returns=None,
info=None,
@@ -182,6 +184,9 @@ def __init__( # noqa: C901
logprobs = torch.tensor(logprobs, dtype=torch.float32)
self.logprobs = logprobs
self.reward = reward
if isinstance(token_level_reward, list):
token_level_reward = torch.tensor(token_level_reward, dtype=torch.float32)
self.token_level_reward = token_level_reward
if isinstance(advantages, list):
advantages = torch.tensor(advantages, dtype=torch.float32)
self.advantages = advantages
@@ -286,6 +291,14 @@ def gather(
else:
rewards = None

# gather token level rewards
if all(exp.token_level_reward is not None for exp in experiences):
token_level_rewards = gather_response_attrs(
experiences, "token_level_reward", max_response_length
)
else:
token_level_rewards = None

# gather action_masks
action_masks = gather_action_masks(experiences, max_response_length)

@@ -295,21 +308,20 @@
)

# gather logprobs

if all(exp.logprobs is not None for exp in experiences):
logprobs = gather_logprobs(experiences, max_response_length)
logprobs = gather_response_attrs(experiences, "logprobs", max_response_length)
else:
logprobs = None

# gather advantages
if all(exp.advantages is not None for exp in experiences):
advantages = gather_advantages(experiences, max_response_length)
advantages = gather_response_attrs(experiences, "advantages", max_response_length)
else:
advantages = None

# gather returns
if all(exp.returns is not None for exp in experiences):
returns = gather_returns(experiences, max_response_length)
returns = gather_response_attrs(experiences, "returns", max_response_length)
else:
returns = None

@@ -323,6 +335,7 @@ def gather(
eids=eids,
tokens=tokens,
rewards=rewards,
token_level_rewards=token_level_rewards,
advantages=advantages,
returns=returns,
attention_masks=attention_masks,
@@ -403,7 +416,12 @@ class Experiences:

eids: List[EID] # Experience IDs of each experience in the batch
tokens: Tensor # [batch_size, seq_length]

# At least one of `rewards` or `token_level_rewards` must be provided (not None).
# If both are provided, `token_level_rewards` will be used and `rewards` will be ignored.
rewards: Tensor # [batch_size]
token_level_rewards: Tensor # [batch_size, response_length]

advantages: Optional[Tensor] # [batch_size, response_length]
returns: Optional[Tensor] # [batch_size, response_length]
attention_masks: Tensor # [batch_size, sequence_length]
@@ -447,6 +465,7 @@ def empty_experiences(custom_fields: Optional[List[CustomField]]) -> Experiences
exps = Experiences(
tokens=torch.empty(0, dtype=torch.int32),
rewards=torch.empty(0, dtype=torch.float32),
token_level_rewards=torch.empty(0, dtype=torch.float32),
advantages=torch.empty(0, dtype=torch.float32),
returns=torch.empty(0, dtype=torch.float32),
attention_masks=torch.empty(0, dtype=torch.bool),
@@ -522,59 +541,20 @@ def gather_attention_masks(experiences, max_prompt_length: int, max_response_len
return attention_masks


def gather_logprobs(experiences, max_response_length: int) -> Tensor:
logprob_dtype = experiences[0].logprobs.dtype # type: ignore [union-attr]
return torch.stack(
[
torch.cat(
[
exp.logprobs,
torch.full(
(max_response_length - len(exp.logprobs),),
0.0,
dtype=logprob_dtype,
),
]
)
for exp in experiences
]
)


def gather_advantages(experiences, max_response_length: int) -> Optional[Tensor]:
if experiences[0].advantages is None:
return None
advantages_dtype = experiences[0].advantages.dtype
return torch.stack(
[
torch.cat(
[
exp.advantages,
torch.full(
(max_response_length - len(exp.advantages),),
0.0,
dtype=advantages_dtype,
),
]
)
for exp in experiences
]
)


def gather_returns(experiences, max_response_length: int) -> Optional[dict[str, List[Tensor]]]:
if experiences[0].returns is None:
return None
returns_dtype = experiences[0].returns.dtype
def gather_response_attrs(
experiences, attr_name: str, max_response_length: int, pad_value: int = 0
) -> Tensor:
dtype = getattr(experiences[0], attr_name).dtype
pad_value = torch.tensor(pad_value, dtype=dtype)
return torch.stack(
[
torch.cat(
[
exp.returns,
getattr(exp, attr_name),
torch.full(
(max_response_length - len(exp.returns),),
0.0,
dtype=returns_dtype,
(max_response_length - len(getattr(exp, attr_name)),),
pad_value,
dtype=dtype,
),
]
)
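The three removed helpers (`gather_logprobs`, `gather_advantages`, `gather_returns`) all followed the same right-pad-and-stack pattern that the new `gather_response_attrs` now implements once, parameterized on the attribute name. A minimal standalone sketch of that pattern, using illustrative names that are not part of the codebase:

```python
import torch

def pad_and_stack(tensors, target_len, pad_value=0.0):
    """Right-pad each 1-D tensor to target_len, then stack into [batch, target_len]."""
    dtype = tensors[0].dtype
    return torch.stack(
        [
            torch.cat([t, torch.full((target_len - len(t),), pad_value, dtype=dtype)])
            for t in tensors
        ]
    )

# e.g. per-response logprobs of different lengths, padded to the batch maximum
logprobs = [torch.tensor([-0.1, -0.2]), torch.tensor([-0.3])]
print(pad_and_stack(logprobs, target_len=2))
# tensor([[-0.1000, -0.2000],
#         [-0.3000,  0.0000]])
```

Consolidating on a single helper is also what lets the new `token_level_rewards` field reuse the same gathering path as logprobs, advantages, and returns.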
26 changes: 18 additions & 8 deletions trinity/trainer/verl/utils.py
@@ -1,6 +1,7 @@
"""Utils for ccompatibility issues with verl."""

import os
from logging import Logger

import numpy as np
import torch
@@ -12,7 +13,7 @@
from trinity.common.experience import Experiences


def to_data_proto(experiences: Experiences) -> DataProto: # noqa: C901
def to_data_proto(experiences: Experiences, logger: Logger) -> DataProto: # noqa: C901
"""Convert Experiences to verl DataProto."""
attention_mask = experiences.attention_masks
cumsum = torch.cumsum(attention_mask, dim=-1)
@@ -31,13 +32,22 @@ def to_data_proto(experiences: Experiences) -> DataProto: # noqa: C901
),
}

if experiences.rewards is not None:
token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype)
eos_mask_idx = cumsum.argmax(dim=-1)
token_level_rewards[
torch.arange(experiences.batch_size), eos_mask_idx
] = experiences.rewards
token_level_rewards = token_level_rewards[:, experiences.prompt_length :]
if experiences.rewards is not None or experiences.token_level_rewards is not None:
assert experiences.logprobs is not None
if experiences.token_level_rewards is not None:
if experiences.rewards is not None:
logger.warning(
"Both experiences.rewards and experiences.token_level_rewards are provided. "
"Using experiences.token_level_rewards."
)
token_level_rewards = experiences.token_level_rewards
else:
token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype)
eos_mask_idx = cumsum.argmax(dim=-1)
token_level_rewards[
torch.arange(experiences.batch_size), eos_mask_idx
] = experiences.rewards
token_level_rewards = token_level_rewards[:, experiences.prompt_length :]
batch_dict.update(
{
"token_level_scores": token_level_rewards,
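When only scalar `rewards` are available (the `else` branch above), the scalar is scattered onto the last attended token of each sequence, and the prompt positions are then sliced off. A self-contained illustration of that scatter on a toy attention mask, with all names local to the example:

```python
import torch

# Toy batch: 2 sequences, prompt_length 2, padded to length 5 (1 = real token, 0 = pad).
attention_mask = torch.tensor([[1, 1, 1, 1, 0],
                               [1, 1, 1, 0, 0]])
rewards = torch.tensor([0.5, -1.0])
prompt_length = 2

# The cumulative sum is non-decreasing, so argmax returns the first index where it
# reaches its maximum, i.e. the last attended (EOS) position of each sequence.
eos_idx = torch.cumsum(attention_mask, dim=-1).argmax(dim=-1)

token_level = torch.zeros(attention_mask.shape, dtype=rewards.dtype)
token_level[torch.arange(rewards.size(0)), eos_idx] = rewards
token_level = token_level[:, prompt_length:]  # keep only the response positions
print(token_level)
# tensor([[ 0.0000,  0.5000,  0.0000],
#         [-1.0000,  0.0000,  0.0000]])
```

When `token_level_rewards` is supplied, this conversion is skipped and the per-token tensor is passed through as-is, with a warning if a scalar `rewards` tensor was also set.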
5 changes: 3 additions & 2 deletions trinity/trainer/verl_trainer.py
@@ -413,7 +413,7 @@ def upload_state_dict(self): # state dict sync
self.actor_rollout_wg.upload_state_dict(self.global_steps)

def train_step(self, batch: Experiences) -> Dict: # noqa C901
batch = to_data_proto(batch)
batch = to_data_proto(batch, self.logger)
batch = self.post_process_batch(batch)
metrics = {}
self.global_steps += 1
@@ -454,7 +454,8 @@ def train_step(self, batch: Experiences) -> Dict: # noqa C901
else:
# skip token_level_scores for sft/dpo
if "token_level_scores" in batch.batch.keys():
batch.batch["token_level_scores"] = batch.batch["token_level_scores"]
assert "token_level_rewards" not in batch.batch.keys()
batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]

# update critic
if self.algorithm.use_critic: