 from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from coati.distributed.utils import calc_action_log_probs
-from coati.trainer.utils import all_gather_tensors, all_reduce_mean, all_reduce_sum
+from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -201,10 +201,7 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
         reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
         # [minibatch_size x num_generations]
         advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
-        # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
-        group_ans_acc = (
-            ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=0)
-        )
+
         # [minibatch_size x num_of_generation]
         loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool()
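
For reference, a minimal standalone sketch of the group-normalized advantage kept by this hunk, assuming one scalar reward per generated sample in a flat tensor; the example values and variable names below are illustrative and not taken from the repository:

import torch

minibatch_size, num_generations = 2, 4
# One reward per generated sample: [minibatch_size * num_generations]
reward = torch.tensor([1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0])

# Group the rewards per prompt: [minibatch_size, num_generations]
group_reward = reward.view(minibatch_size, num_generations)

# Broadcast the per-group mean/std back to every sample in the group
reward_mean = group_reward.mean(dim=1).repeat_interleave(num_generations, dim=0)
reward_std = group_reward.std(dim=1).repeat_interleave(num_generations, dim=0)

# Group-normalized advantage, one value per sample: [minibatch_size * num_generations, 1]
advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
print(advantages.shape)  # torch.Size([8, 1])
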
@@ -214,37 +211,14 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
             loss_mask,
             action_mask[:, -1] == False,
         )
-        prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)
-
-        # [minibatch_size] -> calculate the number of effective prompts
-        effective_prompts_mask = prompt_level_mask.any(dim=1)
-        effective_prompts = all_reduce_sum(torch.sum(effective_prompts_mask), self.plugin)
-        self.effective_prompt_count += effective_prompts.item()
-        excessive_prompts_idx = None
+        self.effective_prompt_count += group_reward.size(0) * self.dp_size
 
         mean_kl, mean_loss = [], []
 
         if self.grpo_config.get("dynamic_batching", True):
             need_update = self.effective_prompt_count >= self.batch_size * self.dp_size
             excessive_prompts = self.effective_prompt_count - self.batch_size * self.dp_size
-
-            if excessive_prompts > 0:
-                excessive_prompts_per_rank = excessive_prompts // self.dp_size
-                # Only count excessive prompts if they are greater than 1 per rank.
-                # TODO: customize excessive prompts calculation.
-                if excessive_prompts_per_rank != 0:
-                    # Mask excessive prompts to False
-                    true_indices = torch.nonzero(effective_prompts_mask).squeeze()
-                    if excessive_prompts_per_rank <= len(true_indices):
-                        excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
-                    else:
-                        excessive_prompts_idx = true_indices
-                    effective_prompts_mask[excessive_prompts_idx] = False
-
-                for mask_idx in range(len(effective_prompts_mask)):
-                    if effective_prompts_mask[mask_idx] == False:
-                        # Update loss mask.
-                        loss_mask[mask_idx] = False
+            assert excessive_prompts <= 0, "Excessive prompts should always be less than or equal to 0; a positive value indicates a bug."
         else:
             # If dynamic batching is disabled, we need to use all samples for training.
             need_update = (step_idx + 1) % self.num_microbatches == 0
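
With the prompt-filtering branch removed, the dynamic-batching path above reduces to a simple counter check; a minimal sketch of that accounting, under the assumption that every data-parallel rank receives the complete minibatch from all producers each step (all values below are made up for illustration):

# Illustrative values only; in the trainer these come from the GRPO config.
batch_size, dp_size, minibatch_size = 32, 4, 8

effective_prompt_count = 0
while True:
    # Every dp rank counts the whole global minibatch, so no all-reduce is needed.
    effective_prompt_count += minibatch_size * dp_size

    excessive_prompts = effective_prompt_count - batch_size * dp_size
    # Holds whenever the global minibatch size divides the global batch size.
    assert excessive_prompts <= 0

    if effective_prompt_count >= batch_size * dp_size:
        # Enough effective prompts accumulated: time for an optimizer step.
        effective_prompt_count = 0
        break
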
@@ -460,9 +434,7 @@ def _criterion(outputs, inputs):
                 self.optimizer.step()
                 self.optimizer.zero_grad()
                 self.global_step += 1
-                self.total_sample_count = all_reduce_sum(
-                    torch.tensor(self.total_sample_count).to(self.accum_loss.device), self.plugin
-                ).item()
+                # no need to run all_reduce_sum on total_sample_count, because all dp ranks receive a complete inference batch from all producers.
                 sample_utilization = self.effective_sample_count / self.total_sample_count
                 self.effective_prompt_count = 0
                 self.effective_sample_count = 0
@@ -507,14 +479,9 @@ def _criterion(outputs, inputs):
             self.accum_advantages.zero_()
             self.accum_response_length.zero_()
             self.accum_count = 0
-
-            if excessive_prompts_idx is not None:
-                # All gather excessive prompts index across DP ranks.
-                excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx]
-                excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin)
-            return loss_scalar, excessive_prompts_idx
+            return loss_scalar
         else:
-            return None, excessive_prompts_idx
+            return None
 
     def calculate_group_reward(self, rollout_group: Dict[str, Any]) -> Dict[str, Any]:
         """