@@ -218,8 +218,6 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
 
         if self.grpo_config.get("dynamic_batching", True):
             need_update = self.effective_prompt_count >= self.batch_size * self.dp_size
-            excessive_prompts = self.effective_prompt_count - self.batch_size * self.dp_size
-            assert excessive_prompts <= 0, "Debug: Excessive prompts should always be less than 0. Bug!!!!"
         else:
             # If dynamic batching is disabled, we need to use all samples for training.
             need_update = (step_idx + 1) % self.num_microbatches == 0
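For context, the gate above only triggers an optimizer update once the prompts surviving filtering amount to a full global batch (`batch_size * dp_size`). The removed assert treated any overshoot of that threshold as a bug, but with dynamic batching the accumulated count can legitimately exceed one global batch, which is presumably why it was dropped. A minimal sketch of the accumulation (the driver loop and sample counts here are hypothetical; only the threshold comparison mirrors the diff):

    # Hypothetical driver loop illustrating the dynamic-batching gate.
    batch_size = 4  # prompts per rank per optimizer update
    dp_size = 2     # data-parallel world size
    effective_prompt_count = 0

    for step_idx, kept in enumerate([3, 2, 5, 1]):  # prompts surviving filtering each step
        effective_prompt_count += kept
        need_update = effective_prompt_count >= batch_size * dp_size
        if need_update:
            # Note the overshoot at step 2 (10 > 8): exactly the case the
            # removed assert would have flagged as a "bug".
            print(f"step {step_idx}: update with {effective_prompt_count} prompts")
            effective_prompt_count = 0  # reset after the optimizer step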
@@ -488,7 +486,7 @@ def _criterion(outputs, inputs):
         else:
             return None
 
-    def calculate_group_reward(self, rollout_group: Dict[str, Any]) -> Dict[str, Any]:
+    def calculate_reward(self, rollout: Dict[str, Any]) -> Dict[str, Any]:
         """
         Calculate the group reward for the given rollout group.
 
@@ -507,20 +505,20 @@ def calculate_group_reward(self, rollout_group: Dict[str, Any]) -> Dict[str, Any]:
         Returns:
             Dict[str, Any]: The new group data with calculated reward.
         """
-        reward_group = self.reward_model(
-            rollout_group["input_ids"],
-            gt_answer=rollout_group["gt_answer"],
-            response_idx=rollout_group["response_idx"],
+        reward_model_output = self.reward_model(
+            rollout["input_ids"],
+            gt_answer=rollout["gt_answer"],
+            response_idx=rollout["response_idx"],
         )
         # [num_of_generation]
-        reward = torch.tensor([value[0] for value in reward_group]).to(rollout_group["input_ids"].device)
-        format_acc = torch.tensor([value[1] for value in reward_group]).to(rollout_group["input_ids"].device)
-        ans_acc = torch.tensor([value[2] for value in reward_group]).to(rollout_group["input_ids"].device)
-
-        rollout_group["reward"] = reward.view((-1, 1))
-        rollout_group["format_acc"] = format_acc.view((-1, 1))
-        rollout_group["ans_acc"] = ans_acc.view((-1, 1))
-        return rollout_group
+        reward = torch.tensor([value[0] for value in reward_model_output]).to(rollout["input_ids"].device)
+        format_acc = torch.tensor([value[1] for value in reward_model_output]).to(rollout["input_ids"].device)
+        ans_acc = torch.tensor([value[2] for value in reward_model_output]).to(rollout["input_ids"].device)
+
+        rollout["reward"] = reward.view((-1, 1))
+        rollout["format_acc"] = format_acc.view((-1, 1))
+        rollout["ans_acc"] = ans_acc.view((-1, 1))
+        return rollout
 
     def prompt_level_filtering(self, rollout_group: Dict[str, Any]) -> Dict[str, Any]:
         """
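As a usage sketch of the renamed method's contract (the stub reward model and shapes below are assumptions, not part of the PR): `self.reward_model` is expected to return one `[reward, format_acc, ans_acc]` triple per generation, which `calculate_reward` unpacks into `(num_generations, 1)` column tensors on the rollout's device.

    import torch

    def stub_reward_model(input_ids, gt_answer=None, response_idx=None):
        # One [reward, format_acc, ans_acc] triple per generation (hypothetical values).
        return [[1.0, 1.0, 0.0] for _ in range(input_ids.size(0))]

    rollout = {"input_ids": torch.zeros((8, 16), dtype=torch.long)}  # 8 generations
    out = stub_reward_model(rollout["input_ids"])
    reward = torch.tensor([v[0] for v in out]).to(rollout["input_ids"].device)
    assert reward.view((-1, 1)).shape == (8, 1)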