@@ -1,3 +1,2 @@
-import warnings
 from contextlib import nullcontext
 from typing import Any, Optional
@@ -10,7 +9,7 @@
 from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from coati.distributed.utils import calc_action_log_probs
-from coati.trainer.utils import all_reduce_mean, all_reduce_sum
+from coati.trainer.utils import all_gather_tensors, all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
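Note: the new `all_gather_tensors` import is what later lets each DP rank share the indices of its surplus prompts (see the final hunk). As a rough sketch of the behavior this change relies on, assuming an initialized `torch.distributed` process group (the real helper lives in `coati.trainer.utils` and is called with the booster plugin instead of a group; the name `all_gather_tensors_sketch` is hypothetical):

```python
import torch.distributed as dist

def all_gather_tensors_sketch(local_values: list, group=None) -> list:
    """Gather a small per-rank Python list from every rank and flatten the result."""
    world_size = dist.get_world_size(group=group)
    gathered = [None] * world_size
    # Collect each rank's list into `gathered`, ordered by rank.
    dist.all_gather_object(gathered, local_values, group=group)
    # Flatten [[rank0 values], [rank1 values], ...] into one list.
    return [v for per_rank in gathered for v in per_rank]
```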
@@ -42,13 +41,6 @@ def __init__(
         save_dir="./model",
     ):
         print(f"Using GRPO config: {grpo_config}")
-        if grpo_config.get("loss_variation", "sample_level") == "token_level":
-            if batch_size != minibatch_size:
-                warnings.warn(
-                    f"Applied token_level loss, force overwrite mini-batch-size with batch-size: mini-batch-size: {minibatch_size}->{batch_size}",
-                    UserWarning,
-                )
-                minibatch_size = batch_size
         if (
             plugin_config.get("pp_size", 1) > 1
             and "num_microbatches" not in plugin_config
@@ -90,6 +82,7 @@ def __init__(
         self.grpo_config = grpo_config
         self.project_name = project_name
         self.effective_sample_count = 0
+        self.effective_prompt_count = 0
         self.total_sample_count = 0

         self.policy_loss_fn = PolicyLoss(
@@ -213,70 +206,66 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
         group_ans_acc = (
             ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=0)
         )
+        # [minibatch_size x num_of_generation]
         loss_mask = (
             torch.ones(action_mask.size(0), device=action_mask.device).bool()
             if self.filter_range is None
             else torch.logical_and(group_ans_acc > self.filter_range[0], group_ans_acc < self.filter_range[1])
         )
+
         # filter out overlength samples
         if self.filter_truncated_response and action_mask.size(1) == self.max_length:
             loss_mask = torch.logical_and(
                 loss_mask,
                 action_mask[:, -1] == False,
             )
-        effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask
-        effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin)
-        total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
-        total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
-        self.effective_sample_count += effective_samples.item()
-        self.total_sample_count += total_samples.item()
+        prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)
+
+        # [minibatch_size] -> calculate the number of effective prompts
+        effective_prompts_mask = prompt_level_mask.any(dim=1)
+        effective_prompts = all_reduce_sum(torch.sum(effective_prompts_mask), self.plugin)
+        self.effective_prompt_count += effective_prompts.item()
+        excessive_prompts_idx = None
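Note: a small, self-contained illustration of the prompt-level bookkeeping added above, with made-up shapes (2 prompts x 3 generations) rather than the trainer's real tensors. Reshaping the per-sample `loss_mask` to `[minibatch_size, num_generations]` and reducing with `any(dim=1)` marks a prompt as effective as soon as at least one of its generations survives filtering.

```python
import torch

minibatch_size, num_generations = 2, 3
# Per-sample mask after accuracy/truncation filtering (toy values).
loss_mask = torch.tensor([True, False, False,    # prompt 0: one usable generation
                          False, False, False])  # prompt 1: fully filtered out

prompt_level_mask = loss_mask.view(minibatch_size, num_generations)
effective_prompts_mask = prompt_level_mask.any(dim=1)       # tensor([True, False])
effective_prompt_count = int(effective_prompts_mask.sum())  # 1
print(effective_prompts_mask, effective_prompt_count)
```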
 
         mean_kl, mean_loss = [], []
 
         if self.grpo_config.get("dynamic_batching", True):
-            need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations
-            # to exclude the excessive samples from the last batch, the last num_excessive_samples samples are not used for training and will be kept in buffer for the next iteration.
-            num_excessive_samples = (
-                int(
-                    (self.effective_sample_count - self.batch_size * self.dp_size * self.num_generations)
-                    / self.num_generations
-                    / self.dp_size
-                )
-                * self.num_generations
-            )
-            if num_excessive_samples > 0:
-                data = {
-                    k: (
-                        v[: -num_excessive_samples if num_excessive_samples != 0 else v.size(0)]
-                        if k
-                        in [
-                            "input_ids",
-                            "attention_mask",
-                            "action_log_probs",
-                            "action_mask",
-                            "response_idx",
-                            "gt_answer",
-                        ]
-                        else v
-                    )
-                    for k, v in data.items()
-                }
-                action_mask = action_mask[
-                    : -num_excessive_samples if num_excessive_samples != 0 else action_mask.size(0)
-                ]
-                loss_mask = loss_mask[: -num_excessive_samples if num_excessive_samples != 0 else loss_mask.size(0)]
-                advantages = advantages[: -num_excessive_samples if num_excessive_samples != 0 else advantages.size(0)]
-            else:
-                num_excessive_samples = 0
+            need_update = self.effective_prompt_count >= self.batch_size * self.dp_size
+            excessive_prompts = self.effective_prompt_count - self.batch_size * self.dp_size
+
+            if excessive_prompts > 0:
+                excessive_prompts_per_rank = excessive_prompts // self.dp_size
+                # Only count excessive prompts if they are greater than 1 per rank.
+                # TODO: customize excessive prompts calculation.
+                if excessive_prompts_per_rank != 0:
+                    # Mask excessive prompts to False
+                    true_indices = torch.nonzero(effective_prompts_mask).squeeze()
+                    if excessive_prompts_per_rank <= len(true_indices):
+                        excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
+                    else:
+                        excessive_prompts_idx = true_indices
+                    effective_prompts_mask[excessive_prompts_idx] = False
+
+                    for mask_idx in range(len(effective_prompts_mask)):
+                        if effective_prompts_mask[mask_idx] == False:
+                            # Update loss mask.
+                            loss_mask[mask_idx] = False
         else:
             # If dynamic batching is disabled, we need to use all samples for training.
             need_update = (step_idx + 1) % self.num_microbatches == 0
-            num_excessive_samples = 0
+
+        effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin)
+        effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask
+        total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
+        total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
+        self.effective_sample_count += effective_samples.item()
+        self.total_sample_count += total_samples.item()
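Note: to make the new trimming rule concrete, here is a toy calculation with assumed values (batch_size=8, dp_size=2), not taken from any real run. Once the globally accumulated effective prompt count reaches `batch_size * dp_size`, an update is triggered; any surplus is split evenly across ranks, and the last `excessive_prompts_per_rank` effective prompts on each rank are masked out of `loss_mask` and handed back for the next iteration.

```python
# Toy numbers only; assumed, not from a real run.
batch_size, dp_size = 8, 2
effective_prompt_count = 19  # accumulated across DP ranks

need_update = effective_prompt_count >= batch_size * dp_size        # 19 >= 16 -> True
excessive_prompts = effective_prompt_count - batch_size * dp_size   # 3
excessive_prompts_per_rank = excessive_prompts // dp_size           # 1
# Each rank would drop its last effective prompt from this step's loss
# and return its index so it can be requeued for the next iteration.
```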
 
         pbar.set_postfix(
             {
-                "Step": self.global_step + 1,
-                "Status": f"Collecting: {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
+                "Global Step": self.global_step,
+                "Effective prompts": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size}",
+                "Effective samples": f"{self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
             }
         )
 
@@ -375,7 +364,7 @@ def _criterion(outputs, inputs):
                     kl.append(appox_kl.mean())
                 else:
                     per_token_kl = 0.0
-                    kl.append(0.0)
+                    kl.append(torch.tensor(0.0))
 
                 loss, _ = self.policy_loss_fn(
                     action_log_probs,
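Note: the change from `kl.append(0.0)` to `kl.append(torch.tensor(0.0))` matters because the collected KL terms are presumably stacked and averaged as tensors downstream; mixing in a bare Python float breaks `torch.stack`. A minimal reproduction, illustrative only and outside the trainer:

```python
import torch

kl = [torch.tensor(0.12), 0.0]                # one tensor, one bare float
# torch.stack(kl)  # raises TypeError: expected Tensor as element 1 in argument 0

kl = [torch.tensor(0.12), torch.tensor(0.0)]  # homogeneous tensors, as in the fix
print(torch.mean(torch.stack(kl)))            # tensor(0.0600)
```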
@@ -479,6 +468,7 @@ def _criterion(outputs, inputs):
             self.optimizer.zero_grad()
             self.global_step += 1
             sample_utilization = self.effective_sample_count / self.total_sample_count
+            self.effective_prompt_count = 0
             self.effective_sample_count = 0
             self.total_sample_count = 0
             loss_scalar = self.accum_loss.item()
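Note: the utilization ratio logged below remains a plain fraction of per-sample counts, and the new prompt counter is reset alongside the existing ones once the optimizer steps. A toy calculation with assumed numbers:

```python
# Assumed example values, not from a real run.
effective_sample_count = 96   # samples that survived filtering (all ranks)
total_sample_count = 128      # every generated sample (all ranks)

sample_utilization = effective_sample_count / total_sample_count  # 0.75
# After the update, effective_prompt_count, effective_sample_count and
# total_sample_count are all reset to 0 for the next accumulation window.
```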
@@ -495,6 +485,7 @@ def _criterion(outputs, inputs):
                     f"Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f}",
                     f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
                     f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}",
+                    f"Sample_utilization: {sample_utilization:.4f}",
                 ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
                 print("\n".join(to_log_msg))
                 metrics = {
@@ -520,9 +511,15 @@ def _criterion(outputs, inputs):
             self.accum_advantages.zero_()
             self.accum_response_length.zero_()
             self.accum_count = 0
-            return loss_scalar, num_excessive_samples // self.num_generations
+
+            if excessive_prompts_idx is not None:
+                # All gather excessive prompts index across DP ranks.
+                excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx]
+                excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin)
+
+            return loss_scalar, excessive_prompts_idx
         else:
-            return None, num_excessive_samples // self.num_generations
+            return None, excessive_prompts_idx
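Note: before being returned, each rank's local excessive-prompt indices are shifted into a global numbering by adding `dp_rank * minibatch_size`, then gathered so the caller sees the full list. A small sketch of that index arithmetic with assumed values (minibatch_size=4, two DP ranks); the gather itself is left to `all_gather_tensors`:

```python
# Assumed setup: 2 DP ranks, each holding a minibatch of 4 prompts.
minibatch_size = 4
local_excessive = {0: [3], 1: [2, 3]}  # local indices kept back, keyed by dp_rank

global_excessive = [
    idx + dp_rank * minibatch_size
    for dp_rank, idxs in local_excessive.items()
    for idx in idxs
]
print(global_excessive)  # [3, 6, 7] -> positions in the rank-concatenated batch
```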
 
     def state_dict(self):
         self.policy_model._force_wait_all_gather()