@@ -1,4 +1,3 @@
-import warnings
 from contextlib import nullcontext
 from typing import Any, Optional

@@ -10,7 +9,7 @@
 from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from coati.distributed.utils import calc_action_log_probs
-from coati.trainer.utils import all_reduce_mean, all_reduce_sum
+from coati.trainer.utils import all_gather_tensors, all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -43,13 +42,6 @@ def __init__(
         wandb_group_name: str = None,
     ):
         print(f"Using GRPO config: {grpo_config}")
-        if grpo_config.get("loss_variation", "sample_level") == "token_level":
-            if batch_size != minibatch_size:
-                warnings.warn(
-                    f"Applied token_level loss, force overwrite mini-batch-size with batch-size: mini-batch-size: {minibatch_size}->{batch_size}",
-                    UserWarning,
-                )
-                minibatch_size = batch_size
         if (
             plugin_config.get("pp_size", 1) > 1
             and "num_microbatches" not in plugin_config
@@ -91,6 +83,7 @@ def __init__(
         self.grpo_config = grpo_config
         self.project_name = project_name
         self.effective_sample_count = 0
+        self.effective_prompt_count = 0
         self.total_sample_count = 0
         self.project_name = project_name
         self.run_name = run_name
@@ -219,70 +212,66 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
         group_ans_acc = (
             ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=0)
         )
+        # [minibatch_size x num_of_generation]
         loss_mask = (
             torch.ones(action_mask.size(0), device=action_mask.device).bool()
             if self.filter_range is None
             else torch.logical_and(group_ans_acc > self.filter_range[0], group_ans_acc < self.filter_range[1])
         )
+
         # filter out overlength samples
         if self.filter_truncated_response and action_mask.size(1) == self.max_length:
             loss_mask = torch.logical_and(
                 loss_mask,
                 action_mask[:, -1] == False,
             )
-        effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask
-        effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin)
-        total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
-        total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
-        self.effective_sample_count += effective_samples.item()
-        self.total_sample_count += total_samples.item()
+        prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)
+
+        # [minibatch_size] -> calculate the number of effective prompts
+        effective_prompts_mask = prompt_level_mask.any(dim=1)
+        effective_prompts = all_reduce_sum(torch.sum(effective_prompts_mask), self.plugin)
+        self.effective_prompt_count += effective_prompts.item()
+        excessive_prompts_idx = None

         mean_kl, mean_loss = [], []

         if self.grpo_config.get("dynamic_batching", True):
-            need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations
-            # to exclude the excessive samples from the last batch, the last num_excessive_samples samples are not used for training and will be kept in buffer for the next iteration.
-            num_excessive_samples = (
-                int(
-                    (self.effective_sample_count - self.batch_size * self.dp_size * self.num_generations)
-                    / self.num_generations
-                    / self.dp_size
-                )
-                * self.num_generations
-            )
-            if num_excessive_samples > 0:
-                data = {
-                    k: (
-                        v[: -num_excessive_samples if num_excessive_samples != 0 else v.size(0)]
-                        if k
-                        in [
-                            "input_ids",
-                            "attention_mask",
-                            "action_log_probs",
-                            "action_mask",
-                            "response_idx",
-                            "gt_answer",
-                        ]
-                        else v
-                    )
-                    for k, v in data.items()
-                }
-                action_mask = action_mask[
-                    : -num_excessive_samples if num_excessive_samples != 0 else action_mask.size(0)
-                ]
-                loss_mask = loss_mask[: -num_excessive_samples if num_excessive_samples != 0 else loss_mask.size(0)]
-                advantages = advantages[: -num_excessive_samples if num_excessive_samples != 0 else advantages.size(0)]
-            else:
-                num_excessive_samples = 0
+            need_update = self.effective_prompt_count >= self.batch_size * self.dp_size
+            excessive_prompts = self.effective_prompt_count - self.batch_size * self.dp_size
+
+            if excessive_prompts > 0:
+                excessive_prompts_per_rank = excessive_prompts // self.dp_size
+                # Only count excessive prompts if they are greater than 1 per rank.
+                # TODO: customize excessive prompts calculation.
+                if excessive_prompts_per_rank != 0:
+                    # Mask excessive prompts to False
+                    true_indices = torch.nonzero(effective_prompts_mask).squeeze()
+                    if excessive_prompts_per_rank <= len(true_indices):
+                        excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
+                    else:
+                        excessive_prompts_idx = true_indices
+                    effective_prompts_mask[excessive_prompts_idx] = False
+
+                    for mask_idx in range(len(effective_prompts_mask)):
+                        if effective_prompts_mask[mask_idx] == False:
+                            # Update loss mask.
+                            loss_mask[mask_idx] = False
         else:
             # If dynamic batching is disabled, we need to use all samples for training.
             need_update = (step_idx + 1) % self.num_microbatches == 0
-            num_excessive_samples = 0
+
+        effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin)
+        effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask
+        total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
+        total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
+        self.effective_sample_count += effective_samples.item()
+        self.total_sample_count += total_samples.item()

         pbar.set_postfix(
             {
-                "Step": self.global_step + 1,
-                "Status": f"Collecting: {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
+                "Global Step": self.global_step,
+                "Effective prompts": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size}",
+                "Effective samples": f"{self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
             }
         )
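Note on the hunk above: the reworked bookkeeping counts prompts rather than individual samples. A prompt is "effective" if at least one of its generations survives the accuracy-range and overlength filters, and any surplus beyond batch_size * dp_size is masked out and remembered so it can be carried over to the next iteration. A minimal standalone sketch of that selection logic, with illustrative sizes and names (the real trainer works on self attributes and all-reduces the counts over DP ranks):

import torch

# Illustrative sizes: 4 prompts on this rank, 8 generations per prompt,
# one DP rank, and a target of 3 effective prompts per optimizer step.
minibatch_size, num_generations, dp_size, target_prompts = 4, 8, 1, 3

# Per-sample mask after accuracy-range and overlength filtering.
loss_mask = torch.rand(minibatch_size * num_generations) > 0.3

# A prompt counts as effective if any of its generations survived.
prompt_level_mask = loss_mask.view(minibatch_size, num_generations)
effective_prompts_mask = prompt_level_mask.any(dim=1)
effective_prompt_count = int(effective_prompts_mask.sum())  # all-reduced across DP ranks in the real code

# Drop surplus prompts from the tail; their indices are reported back for reuse.
excessive_prompts = effective_prompt_count - target_prompts * dp_size
excessive_prompts_idx = None
if excessive_prompts > 0:
    per_rank = excessive_prompts // dp_size
    if per_rank != 0:
        true_indices = torch.nonzero(effective_prompts_mask).squeeze(-1)
        excessive_prompts_idx = true_indices[-per_rank:]
        effective_prompts_mask[excessive_prompts_idx] = False
        # Exclude every generation of a dropped prompt from the loss.
        loss_mask = (prompt_level_mask & effective_prompts_mask.unsqueeze(1)).view(-1)

print(effective_prompt_count, excessive_prompts_idx, loss_mask.sum().item())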
@@ -381,7 +370,7 @@ def _criterion(outputs, inputs):
                     kl.append(appox_kl.mean())
                 else:
                     per_token_kl = 0.0
-                    kl.append(0.0)
+                    kl.append(torch.tensor(0.0))

                 loss, _ = self.policy_loss_fn(
                     action_log_probs,
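Note on the kl.append change: keeping every entry of kl as a tensor means the later aggregation of the collected KL values (presumably a torch.stack followed by a mean) also works on the no-KL branch; a bare Python float in the list would raise. Quick standalone illustration, not the trainer's code:

import torch

kl = [torch.tensor(0.2), 0.0]
# torch.stack(kl)  # TypeError: expected Tensor as element 1 in argument 0

kl = [torch.tensor(0.2), torch.tensor(0.0)]
print(torch.stack(kl).mean())  # tensor(0.1000)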
@@ -485,6 +474,7 @@ def _criterion(outputs, inputs):
             self.optimizer.zero_grad()
             self.global_step += 1
             sample_utilization = self.effective_sample_count / self.total_sample_count
+            self.effective_prompt_count = 0
             self.effective_sample_count = 0
             self.total_sample_count = 0
             loss_scalar = self.accum_loss.item()
@@ -501,6 +491,7 @@ def _criterion(outputs, inputs):
                 f"Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f}",
                 f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
                 f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}",
+                f"Sample_utilization: {sample_utilization:.4f}",
             ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
             print("\n".join(to_log_msg))
             metrics = {
@@ -526,9 +517,15 @@ def _criterion(outputs, inputs):
             self.accum_advantages.zero_()
             self.accum_response_length.zero_()
             self.accum_count = 0
-            return loss_scalar, num_excessive_samples // self.num_generations
+
+            if excessive_prompts_idx is not None:
+                # All gather excessive prompts index across DP ranks.
+                excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx]
+                excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin)
+
+            return loss_scalar, excessive_prompts_idx
         else:
-            return None, num_excessive_samples // self.num_generations
+            return None, excessive_prompts_idx

     def state_dict(self):
         self.policy_model._force_wait_all_gather()
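Note on the return value: excessive_prompts_idx holds rank-local positions, so each rank first offsets its indices by dp_rank * minibatch_size and then all_gather_tensors distributes them, so every DP rank ends up with the same global list of prompts to carry over. A rough standalone equivalent using plain torch.distributed (all_gather_object stands in here for the plugin-aware all_gather_tensors helper, which is not reproduced):

import torch
import torch.distributed as dist

def gather_excessive_prompts(local_idx: torch.Tensor, dp_rank: int, minibatch_size: int) -> list:
    # Convert rank-local prompt positions to global positions.
    global_idx = [int(i) + dp_rank * minibatch_size for i in local_idx]
    # Every rank contributes its list and receives everyone else's.
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, global_idx)
    # Flatten into one list of prompt indices kept for the next iteration.
    return [i for per_rank in gathered for i in per_rank]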