Commit 01423d6

Add token_level_reward to Experience (#404)

1 parent b2282ae commit 01423d6

4 files changed: +95 −62 lines changed

tests/common/experience_test.py

Lines changed: 42 additions & 0 deletions
@@ -123,6 +123,48 @@ def test_gather(self):
         self.assertEqual(batch.rewards[0], 0.1)
         self.assertEqual(batch.rewards[1], 0.2)
 
+    def test_gather_with_token_level_reward(self):
+        # test empty gathering
+        batch = Experiences.gather_experiences([])
+        self.assertEqual(batch.tokens.numel(), 0)
+        self.assertEqual(batch.rewards.numel(), 0)
+        self.assertEqual(batch.token_level_rewards.numel(), 0)
+        self.assertEqual(batch.eids, [])
+
+        # test single experience gathering
+        exp = Experience(
+            tokens=torch.tensor([1, 2, 3]),
+            token_level_reward=torch.tensor([0, 1.0]),
+            prompt_length=1,
+        )
+        batch = Experiences.gather_experiences([exp])
+        self.assertEqual(batch.batch_size, 1)
+        self.assertTrue(
+            torch.equal(batch.tokens[0], torch.tensor([0, 1, 2, 3], dtype=torch.int64)[-3:])
+        )
+        self.assertEqual(batch.prompt_length, 1)
+        self.assertIsNone(batch.rewards)
+        self.assertTrue(torch.equal(batch.token_level_rewards[0], torch.tensor([0, 1.0])))
+
+        # test multiple experiences gathering
+        exps = [
+            Experience(
+                tokens=torch.tensor([1, 2]), token_level_reward=torch.tensor([0.1]), prompt_length=1
+            ),
+            Experience(
+                tokens=torch.tensor([3, 4, 5]),
+                token_level_reward=torch.tensor([0.2]),
+                prompt_length=2,
+            ),
+        ]
+        batch = Experiences.gather_experiences(exps)
+        self.assertEqual(batch.batch_size, 2)
+        self.assertEqual(batch.prompt_length, 2)
+        self.assertEqual(batch.tokens.shape[1], 3)
+        self.assertIsNone(batch.rewards)
+        self.assertTrue(torch.equal(batch.token_level_rewards[0], torch.tensor([0.1])))
+        self.assertTrue(torch.equal(batch.token_level_rewards[1], torch.tensor([0.2])))
+
     def test_action_mask_and_logprobs_type(self):
         exp = Experience(tokens=[1, 2, 3], logprobs=[0.1, 0.2, 0.3], prompt_length=1)
         self.assertIsInstance(exp.tokens, torch.Tensor)
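Note on the test above: in the multi-experience case both responses are a single token long, so no padding is visible there; in general, gathering pads each per-response attribute (including token_level_reward) up to the longest response in the batch, by default with zeros. Below is a minimal, self-contained sketch of that padding step in plain PyTorch with made-up values; it is not the repository's own helper (that is gather_response_attrs, added in the next file).

import torch

# Hypothetical per-token rewards of different lengths, right-padded with zeros
# into one [batch_size, max_response_length] tensor before stacking.
token_level_rewards = [torch.tensor([0.1]), torch.tensor([0.2, 0.3, 0.4])]
max_response_length = max(len(t) for t in token_level_rewards)

padded = torch.stack(
    [
        torch.cat([t, torch.zeros(max_response_length - len(t), dtype=t.dtype)])
        for t in token_level_rewards
    ]
)
print(padded)
# tensor([[0.1000, 0.0000, 0.0000],
#         [0.2000, 0.3000, 0.4000]])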

trinity/common/experience.py

Lines changed: 32 additions & 52 deletions
@@ -101,6 +101,7 @@ class Experience:
     prompt_length: int = 1  # Length of the prompt in tokens, used for generating attention masks
     logprobs: Optional[Tensor] = None  # [resp_length]
     reward: Optional[float] = None
+    token_level_reward: Optional[Tensor] = None  # [resp_length]
     advantages: Optional[Tensor] = None  # [resp_length]
     returns: Optional[Tensor] = None  # [resp_length]
     info: dict = field(
@@ -136,6 +137,7 @@ def __init__(  # noqa: C901
         tokens,
         logprobs=None,
         reward=None,
+        token_level_reward=None,
         advantages=None,
         returns=None,
         info=None,
@@ -182,6 +184,9 @@ def __init__(  # noqa: C901
             logprobs = torch.tensor(logprobs, dtype=torch.float32)
         self.logprobs = logprobs
         self.reward = reward
+        if isinstance(token_level_reward, list):
+            token_level_reward = torch.tensor(token_level_reward, dtype=torch.float32)
+        self.token_level_reward = token_level_reward
         if isinstance(advantages, list):
             advantages = torch.tensor(advantages, dtype=torch.float32)
         self.advantages = advantages
@@ -286,6 +291,14 @@ def gather(
         else:
             rewards = None
 
+        # Gather token level rewards
+        if all(exp.token_level_reward is not None for exp in experiences):
+            token_level_rewards = gather_response_attrs(
+                experiences, "token_level_reward", max_response_length
+            )
+        else:
+            token_level_rewards = None
+
         # gather action_masks
         action_masks = gather_action_masks(experiences, max_response_length)
 
@@ -295,21 +308,20 @@ def gather(
         )
 
         # gather logprobs
-
         if all(exp.logprobs is not None for exp in experiences):
-            logprobs = gather_logprobs(experiences, max_response_length)
+            logprobs = gather_response_attrs(experiences, "logprobs", max_response_length)
         else:
             logprobs = None
 
         # gather advantages
         if all(exp.advantages is not None for exp in experiences):
-            advantages = gather_advantages(experiences, max_response_length)
+            advantages = gather_response_attrs(experiences, "advantages", max_response_length)
         else:
            advantages = None
 
         # gather returns
         if all(exp.returns is not None for exp in experiences):
-            returns = gather_returns(experiences, max_response_length)
+            returns = gather_response_attrs(experiences, "returns", max_response_length)
         else:
             returns = None
 
@@ -323,6 +335,7 @@ def gather(
             eids=eids,
             tokens=tokens,
             rewards=rewards,
+            token_level_rewards=token_level_rewards,
             advantages=advantages,
             returns=returns,
             attention_masks=attention_masks,
@@ -403,7 +416,12 @@ class Experiences:
 
     eids: List[EID]  # Experience IDs of each experience in the batch
    tokens: Tensor  # [batch_size, seq_length]
+
+    # At least one of `rewards` or `token_level_rewards` must be provided (not None).
+    # If both are provided, `token_level_rewards` will be used and `rewards` will be ignored.
     rewards: Tensor  # [batch_size]
+    token_level_rewards: Tensor  # [batch_size, response_length]
+
     advantages: Optional[Tensor]  # [batch_size, response_length]
     returns: Optional[Tensor]  # [batch_size, response_length]
     attention_masks: Tensor  # [batch_size, sequence_length]
@@ -447,6 +465,7 @@ def empty_experiences(custom_fields: Optional[List[CustomField]]) -> Experiences
     exps = Experiences(
         tokens=torch.empty(0, dtype=torch.int32),
         rewards=torch.empty(0, dtype=torch.float32),
+        token_level_rewards=torch.empty(0, dtype=torch.float32),
         advantages=torch.empty(0, dtype=torch.float32),
         returns=torch.empty(0, dtype=torch.float32),
         attention_masks=torch.empty(0, dtype=torch.bool),
@@ -522,59 +541,20 @@ def gather_attention_masks(experiences, max_prompt_length: int, max_response_len
     return attention_masks
 
 
-def gather_logprobs(experiences, max_response_length: int) -> Tensor:
-    logprob_dtype = experiences[0].logprobs.dtype  # type: ignore [union-attr]
-    return torch.stack(
-        [
-            torch.cat(
-                [
-                    exp.logprobs,
-                    torch.full(
-                        (max_response_length - len(exp.logprobs),),
-                        0.0,
-                        dtype=logprob_dtype,
-                    ),
-                ]
-            )
-            for exp in experiences
-        ]
-    )
-
-
-def gather_advantages(experiences, max_response_length: int) -> Optional[Tensor]:
-    if experiences[0].advantages is None:
-        return None
-    advantages_dtype = experiences[0].advantages.dtype
-    return torch.stack(
-        [
-            torch.cat(
-                [
-                    exp.advantages,
-                    torch.full(
-                        (max_response_length - len(exp.advantages),),
-                        0.0,
-                        dtype=advantages_dtype,
-                    ),
-                ]
-            )
-            for exp in experiences
-        ]
-    )
-
-
-def gather_returns(experiences, max_response_length: int) -> Optional[dict[str, List[Tensor]]]:
-    if experiences[0].returns is None:
-        return None
-    returns_dtype = experiences[0].returns.dtype
+def gather_response_attrs(
+    experiences, attr_name: str, max_response_length: int, pad_value: int = 0
+) -> Tensor:
+    dtype = getattr(experiences[0], attr_name).dtype
+    pad_value = torch.tensor(pad_value, dtype=dtype)
     return torch.stack(
         [
             torch.cat(
                 [
-                    exp.returns,
+                    getattr(exp, attr_name),
                     torch.full(
-                        (max_response_length - len(exp.returns),),
-                        0.0,
-                        dtype=returns_dtype,
+                        (max_response_length - len(getattr(exp, attr_name)),),
+                        pad_value,
+                        dtype=dtype,
                     ),
                 ]
            )
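The three near-duplicate helpers (gather_logprobs, gather_advantages, gather_returns) are folded into a single gather_response_attrs that pads any per-response attribute looked up by name. A hedged usage sketch follows, assuming gather_response_attrs and Experience can be imported from trinity.common.experience (where the diff defines them); all values are made up.

import torch

from trinity.common.experience import Experience, gather_response_attrs

# Two experiences whose responses are 1 and 2 tokens long (hypothetical values).
exps = [
    Experience(tokens=torch.tensor([1, 2]), logprobs=torch.tensor([0.5]), prompt_length=1),
    Experience(tokens=torch.tensor([3, 4, 5]), logprobs=torch.tensor([0.6, 0.7]), prompt_length=1),
]

# Pad each experience's `logprobs` to the longest response length (2 here) and stack them.
padded = gather_response_attrs(exps, "logprobs", max_response_length=2)
# padded is a [2, 2] tensor: [[0.5, 0.0], [0.6, 0.7]]

The same call with attr_name="token_level_reward", "advantages", or "returns" reproduces what the gather() method now does for each of those fields.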

trinity/trainer/verl/utils.py

Lines changed: 18 additions & 8 deletions
@@ -1,6 +1,7 @@
 """Utils for ccompatibility issues with verl."""
 
 import os
+from logging import Logger
 
 import numpy as np
 import torch
@@ -12,7 +13,7 @@
 from trinity.common.experience import Experiences
 
 
-def to_data_proto(experiences: Experiences) -> DataProto:  # noqa: C901
+def to_data_proto(experiences: Experiences, logger: Logger) -> DataProto:  # noqa: C901
     """Convert Experiences to verl DataProto."""
     attention_mask = experiences.attention_masks
     cumsum = torch.cumsum(attention_mask, dim=-1)
@@ -31,13 +32,22 @@ def to_data_proto(experiences: Experiences) -> DataProto:  # noqa: C901
         ),
     }
 
-    if experiences.rewards is not None:
-        token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype)
-        eos_mask_idx = cumsum.argmax(dim=-1)
-        token_level_rewards[
-            torch.arange(experiences.batch_size), eos_mask_idx
-        ] = experiences.rewards
-        token_level_rewards = token_level_rewards[:, experiences.prompt_length :]
+    if experiences.rewards is not None or experiences.token_level_rewards is not None:
+        assert experiences.logprobs is not None
+        if experiences.token_level_rewards is not None:
+            if experiences.rewards is not None:
+                logger.warning(
+                    "Both experiences.rewards and experiences.token_level_rewards are provided. "
+                    "Using experiences.token_level_rewards."
+                )
+            token_level_rewards = experiences.token_level_rewards
+        else:
+            token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype)
+            eos_mask_idx = cumsum.argmax(dim=-1)
+            token_level_rewards[
+                torch.arange(experiences.batch_size), eos_mask_idx
+            ] = experiences.rewards
+            token_level_rewards = token_level_rewards[:, experiences.prompt_length :]
         batch_dict.update(
             {
                 "token_level_scores": token_level_rewards,

trinity/trainer/verl_trainer.py

Lines changed: 3 additions & 2 deletions
@@ -413,7 +413,7 @@ def upload_state_dict(self):  # state dict sync
         self.actor_rollout_wg.upload_state_dict(self.global_steps)
 
     def train_step(self, batch: Experiences) -> Dict:  # noqa C901
-        batch = to_data_proto(batch)
+        batch = to_data_proto(batch, self.logger)
         batch = self.post_process_batch(batch)
         metrics = {}
         self.global_steps += 1
@@ -454,7 +454,8 @@ def train_step(self, batch: Experiences) -> Dict:  # noqa C901
         else:
             # skip token_level_scores for sft/dpo
             if "token_level_scores" in batch.batch.keys():
-                batch.batch["token_level_scores"] = batch.batch["token_level_scores"]
+                assert "token_level_rewards" not in batch.batch.keys()
+                batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
 
         # update critic
         if self.algorithm.use_critic:
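The replaced line was a self-assignment (a no-op); the fix copies token_level_scores into a token_level_rewards entry, guarded by an assert so an existing entry is never clobbered. A toy illustration of that guard using a plain dict rather than verl's batch object, with hypothetical values:

import torch

batch = {"token_level_scores": torch.tensor([[0.0, 1.0, 0.5]])}

# Mirror token_level_scores into token_level_rewards, refusing to overwrite an existing entry.
if "token_level_scores" in batch.keys():
    assert "token_level_rewards" not in batch.keys()
    batch["token_level_rewards"] = batch["token_level_scores"]

print(batch["token_level_rewards"])  # tensor([[0.0000, 1.0000, 0.5000]])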
