9
9
import tqdm
10
10
import wandb
11
11
from coati .dataset .loader import RawConversationDataset , collate_fn_grpo
12
+ from coati .distributed .profiling_utils import CustomProfiler
12
13
from coati .distributed .reward .reward_fn import boxed_math_reward_fn , code_reward_fn , math_reward_fn
13
14
from coati .distributed .reward .verifiable_reward import VerifiableReward
14
15
from ray .util .collective import allreduce
@@ -52,6 +53,8 @@ def __init__(
52
53
wandb_group_name : str = None ,
53
54
log_rollout_interval : int = 20 ,
54
55
rollout_log_file : str = "./rollout_log.jsonl" ,
56
+ enable_profiling : bool = False ,
57
+ n_behind : int = 0 ,
55
58
):
56
59
self .producer_idx = producer_idx
57
60
self .num_producers = num_producers
@@ -62,6 +65,7 @@ def __init__(
62
65
assert batch_size % microbatch_size == 0
63
66
self .num_microbatches = batch_size // microbatch_size
64
67
self .latest_eval_step = - 1
68
+ self .profiler = CustomProfiler (f"P{ self .producer_idx } " , disabled = not enable_profiling )
65
69
66
70
self .train_dataset_config = train_dataset_config
67
71
self .model_config = model_config
@@ -75,6 +79,7 @@ def __init__(
75
79
self .log_rollout_interval = log_rollout_interval
76
80
self .latest_rollout_log_step = - 1
77
81
self .grpo_config = grpo_config
82
+ self .n_behind = n_behind
78
83
reward_model_kwargs = {
79
84
k : v
80
85
for k , v in grpo_config .items ()
@@ -268,11 +273,14 @@ def loop(self) -> None:
268
273
self .wandb_run .log (to_log_msg , step = self .consumer_global_step )
269
274
self .eval_mode = False
270
275
self .latest_eval_step = self .consumer_global_step
276
+ self .profiler .enter ("rollout" )
271
277
outputs = self .rollout (** batch )
278
+ self .profiler .exit ("rollout" )
272
279
outputs ["temperature" ] = torch .tensor (
273
280
[self .model .generate_config ["temperature" ]] * outputs ["input_ids" ].size (0 )
274
281
).to (outputs ["input_ids" ].device )
275
282
bs , num_gen = outputs ["input_ids" ].size (0 ), outputs ["input_ids" ].size (1 )
283
+ self .profiler .enter ("calculate_reward" )
276
284
if self .grpo_config ["reward_fn_type" ] == "code" :
277
285
test_cases = []
278
286
for prompt_id in range (bs ):
@@ -310,20 +318,26 @@ def loop(self) -> None:
310
318
outputs .pop ("gt_answer" )
311
319
if "test_cases" in outputs :
312
320
outputs .pop ("test_cases" )
321
+ self .profiler .exit ("calculate_reward" )
313
322
314
323
print (f"[P{ self .producer_idx } ] Send data { [(k , v .shape ) for k , v in outputs .items ()]} " )
315
324
outputs = pre_send (outputs )
325
+ self .profiler .enter ("send_broadcast_data" )
316
326
ray_broadcast_tensor_dict (
317
327
outputs , src = 0 , device = self .device , group_name = f"sync_data_{ self .producer_idx } "
318
328
)
319
- if (i + 1 ) % self .num_microbatches == 0 and (
320
- episode != self .num_episodes - 1 or i != num_valid_microbatches - 1
329
+ self .profiler .exit ("send_broadcast_data" )
330
+ if (
331
+ (i + 1 ) % self .num_microbatches == 0
332
+ and (episode != self .num_episodes - 1 or i != num_valid_microbatches - 1 )
333
+ and (episode != 0 or (i + 1 ) > self .n_behind * self .num_microbatches )
321
334
):
322
335
if isinstance (self .model , BACKEND_MAP ["vllm" ]) and self .model .model_config .get (
323
336
"enable_sleep_mode" , False
324
337
):
325
338
self .model .llm .sleep () # revict KV_cache to avoid OOM
326
339
# don't sync model for last iteration
340
+ self .profiler .enter ("sync_model" )
327
341
torch .cuda .empty_cache ()
328
342
329
343
if self .consumer_pp_size > 1 :
@@ -349,6 +363,7 @@ def loop(self) -> None:
349
363
self .load_state_dict (state_dict )
350
364
del state_dict
351
365
torch .cuda .empty_cache ()
366
+ self .profiler .exit ("sync_model" )
352
367
if isinstance (self .model , BACKEND_MAP ["vllm" ]) and self .model .model_config .get (
353
368
"enable_sleep_mode" , False
354
369
):
@@ -364,6 +379,9 @@ def loop(self) -> None:
364
379
"temperature"
365
380
] + ratio * 0.9
366
381
382
def __del__(self):
    """Close the profiler when the producer object is finalized.

    Python also invokes ``__del__`` on instances whose ``__init__`` raised
    part-way through. In this class the profiler is assigned only after the
    ``batch_size % microbatch_size`` assertion, so a failed construction
    leaves ``self.profiler`` unset. The attribute is therefore looked up
    defensively: a missing profiler is simply skipped instead of raising
    ``AttributeError`` from the finalizer and cluttering the log next to the
    real construction error.
    """
    profiler = getattr(self, "profiler", None)
    if profiler is not None:
        profiler.close()
384
+
367
385
368
386
@ray .remote
369
387
class SimpleProducer (BaseProducer ):
@@ -392,6 +410,8 @@ def __init__(
392
410
wandb_group_name : str = None ,
393
411
log_rollout_interval : int = 20 ,
394
412
rollout_log_file : str = "./rollout_log.jsonl" ,
413
+ enable_profiling : bool = False ,
414
+ n_behind : int = 0 ,
395
415
):
396
416
super ().__init__ (
397
417
producer_idx ,
@@ -415,6 +435,8 @@ def __init__(
415
435
wandb_group_name = wandb_group_name ,
416
436
log_rollout_interval = log_rollout_interval ,
417
437
rollout_log_file = rollout_log_file ,
438
+ enable_profiling = enable_profiling ,
439
+ n_behind = n_behind ,
418
440
)
419
441
self .model = self .backend_cls (model_config , generate_config , self .tokenizer , num_generations )
420
442
self .eval_generation_config = copy .deepcopy (self .model .generate_config )
0 commit comments