from contextlib import nullcontext
- from typing import Any, Optional
+ from typing import Any, Dict, Optional

import ray
import torch
@@ -179,7 +179,6 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
        Format:
            [minibatch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
        """
-
        # Reshape to [minibatch_size x num_of_generation, prompt_length + response_length]
        data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
        action_mask = data["action_mask"]
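For readers skimming the diff, here is a tiny standalone sketch of the reshape above, with made-up sizes (it is not part of the patch): each per-prompt group of generations is flattened so that every generation becomes its own row.

import torch

minibatch_size, num_of_generation, seq_len = 2, 4, 16  # assumed toy sizes
kwargs = {"input_ids": torch.zeros(minibatch_size, num_of_generation, seq_len, dtype=torch.long)}
data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
print(data["input_ids"].shape)  # torch.Size([8, 16]), i.e. [minibatch_size x num_of_generation, seq_len]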
@@ -188,15 +187,9 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
        response_length = torch.sum(action_mask, dim=1).to(torch.float32)
        train_microbatch_size = self.grpo_config.get("train_microbatch_size", data["input_ids"].size(0))

-         reward_group = self.reward_model(
-             data["input_ids"],
-             gt_answer=data["gt_answer"],
-             response_idx=data["response_idx"],
-         )
-
-         reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device)
-         format_acc = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device)
-         ans_acc = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
+         reward = data["reward"].view((-1))
+         format_acc = data["format_acc"].view((-1))
+         ans_acc = data["ans_acc"].view((-1))

        # [minibatch_size, num_generations]
@@ -213,11 +206,7 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
            ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=0)
        )
        # [minibatch_size x num_of_generation]
-         loss_mask = (
-             torch.ones(action_mask.size(0), device=action_mask.device).bool()
-             if self.filter_range is None
-             else torch.logical_and(group_ans_acc > self.filter_range[0], group_ans_acc < self.filter_range[1])
-         )
+         loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool()

        # filter out overlength samples
        if self.filter_truncated_response and action_mask.size(1) == self.max_length:
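The group_ans_acc expression in the hunk above averages answer accuracy per prompt and repeats that mean for every generation of the prompt. A minimal sketch of the pattern, assuming num_generations = 4 (illustrative values, not part of the patch):

import torch

num_generations = 4  # assumed value for illustration
ans_acc = torch.tensor([1.0, 0.0, 1.0, 0.0,   # prompt 0 -> mean 0.5
                        1.0, 1.0, 1.0, 1.0])  # prompt 1 -> mean 1.0
group_ans_acc = (
    ans_acc.view(-1, num_generations).mean(dim=1).repeat_interleave(num_generations, dim=0)
)
print(group_ans_acc)  # tensor([0.5000, 0.5000, 0.5000, 0.5000, 1.0000, 1.0000, 1.0000, 1.0000])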
@@ -525,6 +514,68 @@ def _criterion(outputs, inputs):
        else:
            return None, excessive_prompts_idx

+     def calculate_group_reward(self, rollout_group: Dict[str, Any]) -> Dict[str, Any]:
+         """
+         Calculate the group reward for the given rollout group.
+
+         Args:
+             rollout_group (Dict[str, Any]):
+                 a group of samples generated by the model from the same prompt,
+                 containing the following keys:
+                 "input_ids": torch.Tensor, [num_of_generation, prompt_length + response_length]
+                 "attention_mask": torch.Tensor, [num_of_generation, prompt_length + response_length]
+                 "action_mask": torch.Tensor, [num_of_generation, response_length]
+                 "action_log_probs": torch.Tensor, [num_of_generation, response_length]
+                 "response_idx": int, torch.Tensor, [num_of_generation, 2]
+                 "gt_answer": torch.Tensor, [num_of_generation, 128]
+                 "temperature": torch.Tensor, [] (scalar)
+
+         Returns:
+             Dict[str, Any]: The new group data with calculated reward.
+         """
+         reward_group = self.reward_model(
+             rollout_group["input_ids"],
+             gt_answer=rollout_group["gt_answer"],
+             response_idx=rollout_group["response_idx"],
+         )
+         # [num_of_generation]
+         reward = torch.tensor([value[0] for value in reward_group]).to(rollout_group["input_ids"].device)
+         format_acc = torch.tensor([value[1] for value in reward_group]).to(rollout_group["input_ids"].device)
+         ans_acc = torch.tensor([value[2] for value in reward_group]).to(rollout_group["input_ids"].device)
+
+         rollout_group["reward"] = reward.view((-1, 1))
+         rollout_group["format_acc"] = format_acc.view((-1, 1))
+         rollout_group["ans_acc"] = ans_acc.view((-1, 1))
+         return rollout_group
+
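To illustrate what calculate_group_reward does with the reward model's output, here is a toy sketch; fake_reward_model is a hypothetical stand-in (not the repository's reward function) that returns one (reward, format_acc, ans_acc) triple per generation, which is the shape of output the method above indexes into.

import torch

def fake_reward_model(input_ids, gt_answer=None, response_idx=None):
    # hypothetical stand-in: one (reward, format_acc, ans_acc) triple per generation
    return [(1.0, 1.0, 1.0) for _ in range(input_ids.size(0))]

rollout_group = {"input_ids": torch.zeros(4, 16, dtype=torch.long)}  # 4 generations of one prompt
reward_group = fake_reward_model(rollout_group["input_ids"])
rollout_group["reward"] = torch.tensor([v[0] for v in reward_group]).view((-1, 1))      # shape [4, 1]
rollout_group["format_acc"] = torch.tensor([v[1] for v in reward_group]).view((-1, 1))  # shape [4, 1]
rollout_group["ans_acc"] = torch.tensor([v[2] for v in reward_group]).view((-1, 1))     # shape [4, 1]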
+     def prompt_level_filtering(self, rollout_group: Dict[str, Any]) -> Optional[Dict[str, Any]]:
+         """
+         rollout_group: Dict[str, Any]
+             a group of samples generated by the model from the same prompt,
+             containing the following keys:
+             "input_ids": torch.Tensor, [num_of_generation, prompt_length + response_length]
+             "attention_mask": torch.Tensor, [num_of_generation, prompt_length + response_length]
+             "action_mask": torch.Tensor, [num_of_generation, response_length]
+             "action_log_probs": torch.Tensor, [num_of_generation, response_length]
+             "response_idx": int, torch.Tensor, [num_of_generation, 2]
+             "gt_answer": torch.Tensor, [num_of_generation, 128]
+             "temperature": torch.Tensor, [] (scalar)
+             "reward": torch.Tensor, [num_of_generation]
+             "format_acc": torch.Tensor, [num_of_generation]
+             "ans_acc": torch.Tensor, [num_of_generation]
+         """
+         if self.filter_range is not None:
+             # filter out prompts whose accuracy is too high or too low (out of range)
+             group_ans_acc = torch.mean(rollout_group["ans_acc"])
+             if group_ans_acc < self.filter_range[0] or group_ans_acc > self.filter_range[1]:
+                 # filter out the prompt
+                 return None
+             else:
+                 return rollout_group
+         else:
+             # no filter
+             return rollout_group
+
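And a short sketch of the filtering rule in prompt_level_filtering: a group whose mean answer accuracy falls outside filter_range is dropped. The [0.1, 0.9] bounds below are assumed purely for illustration and are not taken from the patch.

import torch

filter_range = [0.1, 0.9]  # assumed bounds, for illustration only
rollout_group = {"ans_acc": torch.tensor([[1.0], [1.0], [1.0], [1.0]])}  # every generation already correct

group_ans_acc = torch.mean(rollout_group["ans_acc"])
if group_ans_acc < filter_range[0] or group_ans_acc > filter_range[1]:
    filtered = None  # prompt dropped: too easy (or too hard) to carry a useful learning signal
else:
    filtered = rollout_group
print(filtered)  # None, since the group's mean accuracy 1.0 exceeds 0.9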
    def state_dict(self):
        self.policy_model._force_wait_all_gather()
        model = self.policy_model.unwrap()