@@ -101,6 +101,7 @@ class Experience:
     prompt_length: int = 1  # Length of the prompt in tokens, used for generating attention masks
     logprobs: Optional[Tensor] = None  # [resp_length]
     reward: Optional[float] = None
+    token_level_reward: Optional[Tensor] = None  # [resp_length]
     advantages: Optional[Tensor] = None  # [resp_length]
     returns: Optional[Tensor] = None  # [resp_length]
     info: dict = field(
@@ -136,6 +137,7 @@ def __init__( # noqa: C901
         tokens,
         logprobs=None,
         reward=None,
+        token_level_reward=None,
         advantages=None,
         returns=None,
         info=None,
@@ -182,6 +184,9 @@ def __init__( # noqa: C901
             logprobs = torch.tensor(logprobs, dtype=torch.float32)
         self.logprobs = logprobs
         self.reward = reward
+        if isinstance(token_level_reward, list):
+            token_level_reward = torch.tensor(token_level_reward, dtype=torch.float32)
+        self.token_level_reward = token_level_reward
         if isinstance(advantages, list):
             advantages = torch.tensor(advantages, dtype=torch.float32)
         self.advantages = advantages
@@ -286,6 +291,14 @@ def gather(
         else:
             rewards = None
 
+        # Gather token level rewards
+        if all(exp.token_level_reward is not None for exp in experiences):
+            token_level_rewards = gather_response_attrs(
+                experiences, "token_level_reward", max_response_length
+            )
+        else:
+            token_level_rewards = None
+
         # gather action_masks
         action_masks = gather_action_masks(experiences, max_response_length)
 
@@ -295,21 +308,20 @@ def gather(
         )
 
         # gather logprobs
-
         if all(exp.logprobs is not None for exp in experiences):
-            logprobs = gather_logprobs(experiences, max_response_length)
+            logprobs = gather_response_attrs(experiences, "logprobs", max_response_length)
         else:
             logprobs = None
 
         # gather advantages
         if all(exp.advantages is not None for exp in experiences):
-            advantages = gather_advantages(experiences, max_response_length)
+            advantages = gather_response_attrs(experiences, "advantages", max_response_length)
         else:
             advantages = None
 
         # gather returns
         if all(exp.returns is not None for exp in experiences):
-            returns = gather_returns(experiences, max_response_length)
+            returns = gather_response_attrs(experiences, "returns", max_response_length)
         else:
             returns = None
 
@@ -323,6 +335,7 @@ def gather(
             eids=eids,
             tokens=tokens,
             rewards=rewards,
+            token_level_rewards=token_level_rewards,
             advantages=advantages,
             returns=returns,
             attention_masks=attention_masks,
@@ -403,7 +416,12 @@ class Experiences:
 
     eids: List[EID]  # Experience IDs of each experience in the batch
     tokens: Tensor  # [batch_size, seq_length]
+
+    # At least one of `rewards` or `token_level_rewards` must be provided (not None).
+    # If both are provided, `token_level_rewards` will be used and `rewards` will be ignored.
     rewards: Tensor  # [batch_size]
+    token_level_rewards: Tensor  # [batch_size, response_length]
+
     advantages: Optional[Tensor]  # [batch_size, response_length]
     returns: Optional[Tensor]  # [batch_size, response_length]
     attention_masks: Tensor  # [batch_size, sequence_length]
@@ -447,6 +465,7 @@ def empty_experiences(custom_fields: Optional[List[CustomField]]) -> Experiences
     exps = Experiences(
         tokens=torch.empty(0, dtype=torch.int32),
         rewards=torch.empty(0, dtype=torch.float32),
+        token_level_rewards=torch.empty(0, dtype=torch.float32),
         advantages=torch.empty(0, dtype=torch.float32),
         returns=torch.empty(0, dtype=torch.float32),
         attention_masks=torch.empty(0, dtype=torch.bool),
@@ -522,59 +541,20 @@ def gather_attention_masks(experiences, max_prompt_length: int, max_response_len
     return attention_masks
 
 
-def gather_logprobs(experiences, max_response_length: int) -> Tensor:
-    logprob_dtype = experiences[0].logprobs.dtype  # type: ignore [union-attr]
-    return torch.stack(
-        [
-            torch.cat(
-                [
-                    exp.logprobs,
-                    torch.full(
-                        (max_response_length - len(exp.logprobs),),
-                        0.0,
-                        dtype=logprob_dtype,
-                    ),
-                ]
-            )
-            for exp in experiences
-        ]
-    )
-
-
-def gather_advantages(experiences, max_response_length: int) -> Optional[Tensor]:
-    if experiences[0].advantages is None:
-        return None
-    advantages_dtype = experiences[0].advantages.dtype
-    return torch.stack(
-        [
-            torch.cat(
-                [
-                    exp.advantages,
-                    torch.full(
-                        (max_response_length - len(exp.advantages),),
-                        0.0,
-                        dtype=advantages_dtype,
-                    ),
-                ]
-            )
-            for exp in experiences
-        ]
-    )
-
-
-def gather_returns(experiences, max_response_length: int) -> Optional[dict[str, List[Tensor]]]:
-    if experiences[0].returns is None:
-        return None
-    returns_dtype = experiences[0].returns.dtype
+def gather_response_attrs(
+    experiences, attr_name: str, max_response_length: int, pad_value: int = 0
+) -> Tensor:
+    dtype = getattr(experiences[0], attr_name).dtype
+    pad_value = torch.tensor(pad_value, dtype=dtype)
     return torch.stack(
         [
             torch.cat(
                 [
-                    exp.returns,
+                    getattr(exp, attr_name),
                     torch.full(
-                        (max_response_length - len(exp.returns),),
-                        0.0,
-                        dtype=returns_dtype,
+                        (max_response_length - len(getattr(exp, attr_name)),),
+                        pad_value,
+                        dtype=dtype,
                     ),
                 ]
             )
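
For readers skimming the diff, here is a minimal, self-contained sketch of what the consolidated padding helper does: it right-pads each per-response tensor attribute (logprobs, advantages, returns, or the new token_level_reward) to max_response_length and stacks the results into a [batch_size, max_response_length] tensor. The FakeExperience class and the pad_and_stack name below are hypothetical stand-ins used only for illustration; the pad value is passed as a plain float rather than a tensor to keep the example minimal.

import torch


class FakeExperience:
    """Hypothetical stand-in for the real Experience class, illustration only."""

    def __init__(self, logprobs, token_level_reward):
        self.logprobs = torch.tensor(logprobs, dtype=torch.float32)
        self.token_level_reward = torch.tensor(token_level_reward, dtype=torch.float32)


def pad_and_stack(experiences, attr_name, max_response_length, pad_value=0.0):
    # Mirrors the gather_response_attrs helper in the diff: pad each
    # per-response tensor on the right, then stack into one batch tensor.
    dtype = getattr(experiences[0], attr_name).dtype
    return torch.stack(
        [
            torch.cat(
                [
                    getattr(exp, attr_name),
                    torch.full(
                        (max_response_length - len(getattr(exp, attr_name)),),
                        pad_value,
                        dtype=dtype,
                    ),
                ]
            )
            for exp in experiences
        ]
    )


exps = [
    FakeExperience(logprobs=[-0.1, -0.2], token_level_reward=[0.0, 1.0]),
    FakeExperience(logprobs=[-0.3], token_level_reward=[0.5]),
]
print(pad_and_stack(exps, "token_level_reward", max_response_length=3))
# Expected (up to print formatting): tensor([[0.0, 1.0, 0.0],
#                                            [0.5, 0.0, 0.0]])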