@@ -1,4 +1,5 @@
 import os
+import threading
 import time
 from typing import Any, Dict, Optional
 
@@ -54,6 +55,7 @@ def __init__(
         assert batch_size % minibatch_size == 0, "batch_size should be divisible by minibatch_size"
         self.num_microbatches = batch_size // minibatch_size
         self.data_uid = 0
+        self.sync_model_thread_started = False
 
         self.model_config = model_config
         self.plugin_config = plugin_config
@@ -64,7 +66,6 @@ def __init__(
         self.shared_sync_data_actor = shared_sync_data_actor
         self.shared_signal_actor = shared_signal_actor
         self.state_dict_cpu = {}
-        self.next_data_source = 0  # used to track which producer to get data from next
 
     def setup(self) -> None:
         launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
@@ -183,7 +184,6 @@ def loop(self) -> None:
                     raw_batch = ray.get(self.shared_sync_data_actor.get_data.remote(self.data_uid))
                     continue
                 self.data_uid += 1
-                self.next_data_source = (self.next_data_source + 1) % self.num_producers
                 raw_batch = {k: v.to(self.device) for k, v in raw_batch.items()}
                 # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
                 # we need to calculate the metrics before filtering here for logging
@@ -253,6 +253,7 @@ def loop(self) -> None:
                 if loss is not None:
                     pbar.set_postfix({"loss": loss})
                     need_sync_model = True
+                    ray.get(self.shared_signal_actor.set_signal.remote("global_step", self.global_step + 1))
                 if need_sync_model and (
                     (self.global_step + 1) % self.save_interval == 0
                     or self.received_prompts >= self.train_dataset_size
@@ -269,49 +270,76 @@ def loop(self) -> None:
                 if need_sync_model and (
                     episode != self.num_episodes - 1 or self.received_prompts != self.train_dataset_size
                 ):
-                    # sync model weights to all producers, if no model update or it is the last training step, skip syncing
-                    if self.pp_size > 1:
-                        print(
-                            f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {self.global_step}"
-                        )
-                    else:
-                        print(f"[T{dist.get_rank()}] Sync model episode {episode} step {self.global_step}")
-                    torch.cuda.empty_cache()
-                    self.state_dict_cpu = {k: v.cpu() for k, v in self.state_dict().items()}
-                    cc.barrier(group_name="consumer_pg")
-                    if self.pp_size > 1:
-                        if self.tp_rank == 0 and self.dp_rank == 0:
-                            self.profiler.enter("sync_model")
-                            ray.get(
-                                self.shared_signal_actor.set_signal.remote(
-                                    f"consumer_pp_{self.pp_rank}", "ready_sync_model"
-                                )
-                            )
+
+                    def sync_model_thread():
+                        # sync model weights to all producers, if no model update or it is the last training step, skip syncing
+                        if self.pp_size > 1:
                             print(
                                 f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {self.global_step}"
                             )
-                            ray_broadcast_tensor_dict(
-                                self.state_dict_cpu,
-                                src=0,
-                                device=torch.device("cpu"),
-                                group_name=f"sync_model_consumer_pp_{self.pp_rank}",
-                                backend="gloo",
-                            )
-                            self.profiler.exit("sync_model")
-                    else:
-                        if self.rank == 0:
-                            self.profiler.enter("sync_model")
-                            ray.get(self.shared_signal_actor.set_signal.remote("consumer", "ready_sync_model"))
+                        else:
                             print(f"[T{dist.get_rank()}] Sync model episode {episode} step {self.global_step}")
-                            ray_broadcast_tensor_dict(
-                                self.state_dict_cpu,
-                                src=0,
-                                device=torch.device("cpu"),
-                                group_name="sync_model_consumer",
-                                backend="gloo",
-                            )
-                            self.profiler.exit("sync_model")
+                        torch.cuda.empty_cache()
+                        if self.pp_size > 1:
+                            if self.tp_rank == 0 and self.dp_rank == 0:
+                                self.profiler.enter("sync_model")
+                                ray.get(
+                                    self.shared_signal_actor.set_signal.remote(
+                                        f"consumer_pp_{self.pp_rank}", "ready_sync_model"
+                                    )
+                                )
+                                print(
+                                    f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {self.global_step}"
+                                )
+                                ray_broadcast_tensor_dict(
+                                    self.state_dict_cpu,
+                                    src=0,
+                                    device=torch.device("cpu"),
+                                    group_name=f"sync_model_consumer_pp_{self.pp_rank}",
+                                    backend="gloo",
+                                )
+                                self.profiler.exit("sync_model")
+                        else:
+                            if self.rank == 0:
+                                self.profiler.enter("sync_model")
+                                ray.get(self.shared_signal_actor.set_signal.remote("consumer", "ready_sync_model"))
+                                print(f"[T{dist.get_rank()}] Sync model episode {episode} step {self.global_step}")
+                                ray_broadcast_tensor_dict(
+                                    self.state_dict_cpu,
+                                    src=0,
+                                    device=torch.device("cpu"),
+                                    group_name="sync_model_consumer",
+                                    backend="gloo",
+                                )
+                                self.profiler.exit("sync_model")
+
+                    if not self.sync_model_thread_started:
+                        # only sync model when the thread is not started and no other thread is broadcasting
+                        self.sync_model_thread_started = True
+                        state_dict_ = self.state_dict()
+                        if (self.pp_size > 1 and self.tp_rank == 0 and self.dp_rank == 0) or (
+                            self.pp_size == 1 and self.rank == 0
+                        ):
+                            if len(self.state_dict_cpu) == 0:
+                                # use pinned memory to speed up the transfer
+                                self.state_dict_cpu = {k: v.cpu().pin_memory() for k, v in state_dict_.items()}
+                            torch.cuda.synchronize()
+                            for k, v in state_dict_.items():
+                                self.state_dict_cpu[k].copy_(v, non_blocking=True)
+                            torch.cuda.synchronize()
+                        cc.barrier(
+                            group_name="consumer_pg"
+                        )  # to make sure all ranks have state dict offloaded to CPU before starting the thread
+                        time_before_starting_thread = time.time()
+                        threading.Thread(target=sync_model_thread).start()
+                        # sync_model_thread()
+                        self.profiler.log(
+                            f"Sync model, took {time.time() - time_before_starting_thread:.2f} seconds"
+                        )
+                        self.sync_model_thread_started = False
+                        # ray.get(self.shared_signal_actor.release_process_lock.remote("broadcasting_lock"))
             self.profiler.log(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
+            self.received_prompts = 0
         ray.get(self.shared_signal_actor.set_signal.remote("consumer", "terminate"))
 
     def __del__(self):
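
The patch above replaces the blocking in-loop weight broadcast with a two-phase handoff: the broadcasting rank snapshots its GPU state dict into reusable pinned CPU buffers, all ranks barrier on consumer_pg, and a background thread then performs the CPU-side (gloo) broadcast while training continues. The sketch below isolates that pattern using only torch, threading, and the standard library; the helper names (snapshot_to_pinned_cpu, sync_in_background) and the broadcast_fn callback are hypothetical stand-ins, not this PR's API, whose actual broadcast goes through ray_broadcast_tensor_dict on the sync_model_consumer* gloo groups.

# Minimal, self-contained sketch of the snapshot-then-broadcast-in-a-thread pattern,
# assuming only torch and the standard library. Helper and callback names are
# illustrative, not the PR's API.
import threading
import time
from typing import Callable, Dict

import torch


def snapshot_to_pinned_cpu(
    state_dict: Dict[str, torch.Tensor], pinned: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
    """Copy tensors into reusable pinned CPU buffers (allocated on first use)."""
    if not pinned:
        # Pinned (page-locked) memory allows fast, asynchronous device-to-host copies.
        pinned.update({k: torch.empty_like(v, device="cpu").pin_memory() for k, v in state_dict.items()})
    for k, v in state_dict.items():
        pinned[k].copy_(v, non_blocking=True)
    if torch.cuda.is_available():
        # Wait for the async copies to land before another thread reads the buffers.
        torch.cuda.synchronize()
    return pinned


def sync_in_background(
    state_dict: Dict[str, torch.Tensor],
    pinned: Dict[str, torch.Tensor],
    broadcast_fn: Callable[[Dict[str, torch.Tensor]], None],
) -> threading.Thread:
    """Snapshot the weights, then hand the CPU-side broadcast to a daemon thread so the
    training loop is only blocked for the device-to-host copy, not the broadcast itself."""
    snapshot_to_pinned_cpu(state_dict, pinned)
    handoff_start = time.time()
    thread = threading.Thread(target=broadcast_fn, args=(pinned,), daemon=True)
    thread.start()
    print(f"sync handed off after {time.time() - handoff_start:.4f}s")
    return thread

On the broadcasting rank, broadcast_fn would wrap whatever collective the system uses; in this PR that is ray_broadcast_tensor_dict with device=torch.device("cpu") and backend="gloo", and the cc.barrier(group_name="consumer_pg") plus the sync_model_thread_started flag are there so the pinned snapshot is complete on every rank before the thread starts.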