@@ -218,30 +218,6 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
 
         if self.grpo_config.get("dynamic_batching", True):
             need_update = self.effective_prompt_count >= self.batch_size * self.dp_size
-            excessive_prompts = self.effective_prompt_count - self.batch_size * self.dp_size
-
-            if excessive_prompts > 0:
-                excessive_prompts_per_rank = excessive_prompts // self.dp_size
-                # Only count excessive prompts if they are greater than 1 per rank.
-                # TODO: customize excessive prompts calculation.
-                if excessive_prompts_per_rank != 0:
-                    # Mask excessive prompts to False
-                    true_indices = torch.nonzero(effective_prompts_mask)
-                    # Make sure the indices are not empty.
-                    if true_indices.numel() > 0:
-                        true_indices = true_indices.squeeze(-1)
-                        if excessive_prompts_per_rank <= len(true_indices):
-                            excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
-                        else:
-                            excessive_prompts_idx = true_indices
-                        effective_prompts_mask[excessive_prompts_idx] = False
-
-                        for mask_idx in range(len(effective_prompts_mask)):
-                            if effective_prompts_mask[mask_idx] == False:
-                                # Update loss mask.
-                                loss_mask[mask_idx] = False
-            else:
-                excessive_prompts_idx = torch.empty([0])
         else:
             # If dynamic batching is disabled, we need to use all samples for training.
             need_update = (step_idx + 1) % self.num_microbatches == 0
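After this change, the dynamic-batching branch only checks whether enough effective prompts have accumulated; the per-rank masking of excessive prompts is dropped. Below is a minimal sketch of the resulting update condition, assuming `effective_prompt_count`, `batch_size`, `dp_size`, and `num_microbatches` behave as in the diff (the `UpdateGate` class itself is illustrative, not part of the codebase):

```python
class UpdateGate:
    """Illustrative stand-in for the consumer's bookkeeping; not the project's API."""

    def __init__(self, batch_size: int, dp_size: int, num_microbatches: int, dynamic_batching: bool = True):
        self.batch_size = batch_size
        self.dp_size = dp_size
        self.num_microbatches = num_microbatches
        self.dynamic_batching = dynamic_batching
        self.effective_prompt_count = 0  # incremented as filtered prompts arrive

    def need_update(self, step_idx: int) -> bool:
        if self.dynamic_batching:
            # Trigger an optimizer step once a full global batch of effective prompts is available.
            return self.effective_prompt_count >= self.batch_size * self.dp_size
        # Without dynamic batching, fall back to fixed gradient accumulation.
        return (step_idx + 1) % self.num_microbatches == 0


gate = UpdateGate(batch_size=4, dp_size=2, num_microbatches=8)
gate.effective_prompt_count = 8
assert gate.need_update(step_idx=0)  # 8 >= 4 * 2
```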
@@ -510,7 +486,7 @@ def _criterion(outputs, inputs):
         else:
             return None
 
-    def calculate_group_reward(self, rollout_group: Dict[str, Any]) -> Dict[str, Any]:
+    def calculate_reward(self, rollout: Dict[str, Any]) -> Dict[str, Any]:
         """
         Calculate the group reward for the given rollout group.
 
@@ -529,20 +505,20 @@ def calculate_group_reward(self, rollout_group: Dict[str, Any]) -> Dict[str, Any]:
         Returns:
             Dict[str, Any]: The new group data with calculated reward.
         """
-        reward_group = self.reward_model(
-            rollout_group["input_ids"],
-            gt_answer=rollout_group["gt_answer"],
-            response_idx=rollout_group["response_idx"],
+        reward_model_output = self.reward_model(
+            rollout["input_ids"],
+            gt_answer=rollout["gt_answer"],
+            response_idx=rollout["response_idx"],
         )
         # [num_of_generation]
-        reward = torch.tensor([value[0] for value in reward_group]).to(rollout_group["input_ids"].device)
-        format_acc = torch.tensor([value[1] for value in reward_group]).to(rollout_group["input_ids"].device)
-        ans_acc = torch.tensor([value[2] for value in reward_group]).to(rollout_group["input_ids"].device)
-
-        rollout_group["reward"] = reward.view((-1, 1))
-        rollout_group["format_acc"] = format_acc.view((-1, 1))
-        rollout_group["ans_acc"] = ans_acc.view((-1, 1))
-        return rollout_group
+        reward = torch.tensor([value[0] for value in reward_model_output]).to(rollout["input_ids"].device)
+        format_acc = torch.tensor([value[1] for value in reward_model_output]).to(rollout["input_ids"].device)
+        ans_acc = torch.tensor([value[2] for value in reward_model_output]).to(rollout["input_ids"].device)
+
+        rollout["reward"] = reward.view((-1, 1))
+        rollout["format_acc"] = format_acc.view((-1, 1))
+        rollout["ans_acc"] = ans_acc.view((-1, 1))
+        return rollout
 
     def prompt_level_filtering(self, rollout_group: Dict[str, Any]) -> Dict[str, Any]:
         """