@@ -218,8 +218,6 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
         if self.grpo_config.get("dynamic_batching", True):
             need_update = self.effective_prompt_count >= self.batch_size * self.dp_size
-            excessive_prompts = self.effective_prompt_count - self.batch_size * self.dp_size
-            assert excessive_prompts <= 0, "Debug: Excessive prompts should always be less than 0. Bug!!!!"
         else:
             # If dynamic batching is disabled, we need to use all samples for training.
             need_update = (step_idx + 1) % self.num_microbatches == 0
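For context on the surviving branch: with dynamic batching enabled, an optimizer update fires once the prompts that survive filtering fill a full global batch across data-parallel ranks. A minimal standalone sketch of that trigger, with scalar stand-ins for the trainer's attributes (the values here are illustrative, not taken from the repository):

```python
# Sketch of the update trigger, assuming the trainer accumulates
# filtered ("effective") prompts until a global batch is full.
batch_size = 4               # prompts per data-parallel rank per update
dp_size = 2                  # data-parallel world size
num_microbatches = 8         # cadence used when dynamic batching is off

dynamic_batching = True
effective_prompt_count = 8   # prompts collected so far after filtering
step_idx = 7                 # zero-based rollout step index

if dynamic_batching:
    # Update once a full global batch of effective prompts is available.
    need_update = effective_prompt_count >= batch_size * dp_size
else:
    # Otherwise every sample is used, so update on a fixed microbatch cadence.
    need_update = (step_idx + 1) % num_microbatches == 0

print(need_update)  # True: 8 >= 4 * 2
```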
@@ -488,7 +486,7 @@ def _criterion(outputs, inputs):
         else:
             return None

-    def calculate_group_reward(self, rollout_group: Dict[str, Any]) -> Dict[str, Any]:
+    def calculate_reward(self, rollout: Dict[str, Any]) -> Dict[str, Any]:
         """
         Calculate the group reward for the given rollout group.
@@ -507,20 +505,20 @@ def calculate_group_reward(self, rollout_group: Dict[str, Any]) -> Dict[str, Any]:
         Returns:
             Dict[str, Any]: The new group data with calculated reward.
         """
-        reward_group = self.reward_model(
-            rollout_group["input_ids"],
-            gt_answer=rollout_group["gt_answer"],
-            response_idx=rollout_group["response_idx"],
+        reward_model_output = self.reward_model(
+            rollout["input_ids"],
+            gt_answer=rollout["gt_answer"],
+            response_idx=rollout["response_idx"],
         )
         # [num_of_generation]
-        reward = torch.tensor([value[0] for value in reward_group]).to(rollout_group["input_ids"].device)
-        format_acc = torch.tensor([value[1] for value in reward_group]).to(rollout_group["input_ids"].device)
-        ans_acc = torch.tensor([value[2] for value in reward_group]).to(rollout_group["input_ids"].device)
-
-        rollout_group["reward"] = reward.view((-1, 1))
-        rollout_group["format_acc"] = format_acc.view((-1, 1))
-        rollout_group["ans_acc"] = ans_acc.view((-1, 1))
-        return rollout_group
+        reward = torch.tensor([value[0] for value in reward_model_output]).to(rollout["input_ids"].device)
+        format_acc = torch.tensor([value[1] for value in reward_model_output]).to(rollout["input_ids"].device)
+        ans_acc = torch.tensor([value[2] for value in reward_model_output]).to(rollout["input_ids"].device)
+
+        rollout["reward"] = reward.view((-1, 1))
+        rollout["format_acc"] = format_acc.view((-1, 1))
+        rollout["ans_acc"] = ans_acc.view((-1, 1))
+        return rollout

     def prompt_level_filtering(self, rollout_group: Dict[str, Any]) -> Dict[str, Any]:
         """