import os
+ import threading
+ import time
from contextlib import nullcontext
from typing import Any, Dict, Optional

...

from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

- from .comm import ray_broadcast_tensor_dict
+ from .comm import SharedVariableActor, ray_broadcast_tensor_dict
from .utils import bind_batch, post_recv, unbind_batch


class BaseConsumer:
    def __init__(
        self,
+         shared_sync_data_actor: SharedVariableActor,
+         shared_sync_model_actor: SharedVariableActor,
        num_producers: int,
        num_episodes: int,
        rank: int,
@@ -63,6 +67,13 @@ def __init__(
        self.lr_scheduler = None
        self.n_behind = n_behind

+         # for running sync data and model in separate actors/threads
+         self.shared_sync_data_actor = shared_sync_data_actor
+         self.shared_sync_model_actor = shared_sync_model_actor
+         self.thread_started = False
+         self.model_sync_step = 0
+         self.state_dict_cpu = {}
+
    def setup(self) -> None:
        launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
@@ -85,6 +96,7 @@ def setup(self) -> None:
        self.pp_size = dist.get_world_size(self.plugin.pp_group)

        # Init Hybrid ray process group
+         cc.init_collective_group(self.world_size, self.rank, group_name="consumer_pg")
        for i in range(self.num_producers):
            cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}")
        if self.pp_size > 1:
@@ -152,44 +164,12 @@ def loop(self) -> None:
                    torch.cuda.reset_peak_memory_stats()
                    i = 0
                    for _ in range(self.num_recv_per_update):
-                         # after sync model, do not wait for more data to arrive as rollout takes time, use buffered data
-                         effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
-                         while len(effective_group_to_raw_group_mapping) > max(
-                             self.dp_size * self.batch_size
-                             - self.dp_size
-                             * self.minibatch_size
-                             * self.grpo_config.get("num_minibatch_during_rollout", 1),
-                             self.dp_size * self.minibatch_size,
-                         ):
-                             self.profiler.log(
-                                 f"Still have {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.minibatch_size}, start training"
-                             )
-                             batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
-                                 effective_group_to_raw_group_mapping
-                             )
-                             self.profiler.enter("step")
-                             loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
-                             self.profiler.exit("step")
-                             self.buffer = self.buffer[
-                                 effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
-                             ]
-                             # recalculate the effective group to raw group mapping
-                             effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
-                             effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
-                             assert (
-                                 len(effective_group_to_raw_group_mapping)
-                                 == effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
-                             )
-                             if loss is not None:
-                                 pbar.set_postfix({"loss": loss})
-                             i += 1
-
                        # receive data from producers
                        for r in range(self.num_producers):
                            print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
                            self.profiler.enter(f"recv_broadcast_data_P{r}")
                            raw_batch = ray_broadcast_tensor_dict(
-                                 None, src=0, device=self.device, group_name=f"sync_data_{r}"
+                                 None, src=0, device=self.device, group_name=f"sync_data_{r}", offload_to_cpu=False
                            )
                            self.profiler.exit(f"recv_broadcast_data_P{r}")
                            # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
@@ -238,10 +218,7 @@ def loop(self) -> None:
                                f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"
                            )

-                         while len(effective_group_to_raw_group_mapping) > self.dp_size * self.batch_size:
-                             self.profiler.log(
-                                 f"Received {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.batch_size}, start training after recv"
-                             )
+                         while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
                            # always keep at least dp_size * batch_size effective samples in the buffer for training during the rollout times after each sync model
                            # on each dp_rank, we use minibatch_size effective samples to form a batch
                            batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
@@ -273,34 +250,67 @@ def loop(self) -> None:
                        if self.rank == 0:
                            print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")

-                     if (episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1) and (
-                         episode != 0 or step >= self.n_behind
-                     ):
+                     if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
                        if self.pp_size > 1:
                            print(
                                f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
                            )
                        else:
                            print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
-                         self.profiler.enter("sync_model")
                        torch.cuda.empty_cache()
-                         state_dict = self.state_dict()
+                         self.state_dict_cpu = {k: v.cpu() for k, v in self.state_dict().items()}
+                         cc.barrier(group_name="consumer_pg")
                        if self.pp_size > 1:
                            if self.tp_rank == 0 and self.dp_rank == 0:
+                                 self.profiler.enter("sync_model")
                                ray_broadcast_tensor_dict(
-                                     state_dict,
+                                     self.state_dict_cpu,
                                    src=self.num_producers,
                                    device=self.device,
                                    group_name=f"sync_model_{self.pp_rank}",
+                                     offload_to_cpu=True,
                                )
+                                 self.profiler.exit("sync_model")
                        else:
                            if self.rank == 0:
-                                 ray_broadcast_tensor_dict(
-                                     state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
-                                 )
-                         del state_dict
-                         torch.cuda.empty_cache()
-                         self.profiler.exit("sync_model")
+                                 # ray_broadcast_tensor_dict(
+                                 #     self.state_dict_cpu, src=self.num_producers, device=self.device, group_name="sync_model", offload_to_cpu=True
+                                 # )
+                                 if not self.thread_started:
+
+                                     def broadcast_state_dict():
+                                         self.thread_started = True
+                                         self.profiler.enter("sync_model")
+                                         # lazy broadcast state_dict if and only if both consumer and all producers are idle (not broadcasting the last state_dict)
+                                         ray.get(
+                                             self.shared_sync_model_actor.increase_ready_process_count.remote(
+                                                 name=self.model_sync_step
+                                             )
+                                         )
+                                         ready_process_count = ray.get(
+                                             self.shared_sync_model_actor.get_ready_process_count.remote(
+                                                 name=self.model_sync_step
+                                             )
+                                         )
+                                         while ready_process_count != self.num_producers + 1:
+                                             time.sleep(0.5)
+                                             ready_process_count = ray.get(
+                                                 self.shared_sync_model_actor.get_ready_process_count.remote(
+                                                     name=self.model_sync_step
+                                                 )
+                                             )
+                                         ray_broadcast_tensor_dict(
+                                             self.state_dict_cpu,
+                                             src=self.num_producers,
+                                             device=self.device,
+                                             group_name="sync_model",
+                                             offload_to_cpu=True,
+                                         )
+                                         self.model_sync_step += 1
+                                         self.thread_started = False
+                                         self.profiler.exit("sync_model")
+
+                                     threading.Thread(target=broadcast_state_dict, daemon=True).start()
                    self.profiler.log(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")

    def __del__(self):
@@ -312,6 +322,8 @@ def __del__(self):
class SimpleConsumer(BaseConsumer):
    def __init__(
        self,
+         shared_sync_data_actor: SharedVariableActor,
+         shared_sync_model_actor: SharedVariableActor,
        num_producers,
        num_episodes,
        rank,
@@ -328,6 +340,8 @@ def __init__(
        save_dir="./model",
    ):
        super().__init__(
+             shared_sync_data_actor,
+             shared_sync_model_actor,
            num_producers,
            num_episodes,
            rank,
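
The background `broadcast_state_dict` thread added in this diff gates the weight broadcast on a shared ready counter: the consumer and every producer register for the current `model_sync_step`, and the broadcast only starts once `num_producers + 1` processes have checked in. `SharedVariableActor` lives in `.comm` and is not shown here; the snippet below is a minimal sketch of just the two methods the consumer calls (`increase_ready_process_count` / `get_ready_process_count`), assumed for illustration rather than the actual class.

```python
# Minimal sketch (assumed): the ready-count side of SharedVariableActor.
# The real class in .comm may hold additional shared state for data/model sync.
from collections import defaultdict

import ray


@ray.remote
class SharedVariableActor:
    def __init__(self):
        # one counter per sync step; "name" is the consumer's model_sync_step
        self.ready_process_count = defaultdict(int)

    def increase_ready_process_count(self, name):
        # called once by the consumer and once by each producer per sync step
        self.ready_process_count[name] += 1

    def get_ready_process_count(self, name):
        # polled by the consumer's daemon thread until it reaches num_producers + 1
        return self.ready_process_count[name]
```

With such a counter, the daemon thread can start the CPU-offloaded `ray_broadcast_tensor_dict` only after every producer has finished its rollout step and registered, so training on buffered data continues while the thread waits.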
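The wait loop expects `num_producers + 1` registrations, which implies each producer joins the same handshake before receiving weights. The producer side is not part of this diff; the sketch below is a plausible receive-side counterpart, assuming the producer tracks the same sync-step counter and uses the existing receive pattern of `ray_broadcast_tensor_dict` (pass `None` to receive). The function name and argument list are hypothetical.

```python
# Assumed producer-side counterpart (not in this diff): register as ready, then
# receive the CPU-offloaded state dict broadcast by the consumer's rank 0.
import ray

from .comm import ray_broadcast_tensor_dict  # same helper the consumer uses


def wait_and_receive_model(shared_sync_model_actor, model_sync_step, num_producers, device):
    # register this producer as ready for the current sync step
    ray.get(shared_sync_model_actor.increase_ready_process_count.remote(name=model_sync_step))
    # the consumer's daemon thread only broadcasts once num_producers + 1 processes registered;
    # the collective receive blocks until that broadcast fires
    return ray_broadcast_tensor_dict(
        None,
        src=num_producers,  # in the "sync_model" group the consumer broadcasts from rank num_producers
        device=device,
        group_name="sync_model",
        offload_to_cpu=True,
    )
```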