 import ray.util.collective as cc
 import torch
 import torch.distributed as dist
+from coati.distributed.profiling_utils import CustomProfiler
 from tqdm import tqdm
 from transformers import AutoModelForCausalLM
 
@@ -36,6 +37,8 @@ def __init__(
         minibatch_size: int = 1,
         save_interval: int = 100,
         save_dir: str = "./model",
+        enable_profiling: bool = False,
+        n_behind: int = 0,
     ):
         self.num_producers = num_producers
         self.num_episodes = num_episodes
@@ -49,6 +52,7 @@ def __init__(
         self.minibatch_size = minibatch_size
         self.save_interval = save_interval
         self.save_dir = save_dir
+        self.enable_profiling = enable_profiling
         assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size"
         self.num_microbatches = batch_size // minibatch_size
 
@@ -57,6 +61,7 @@ def __init__(
 
         self.device = get_current_device()
         self.lr_scheduler = None
+        self.n_behind = n_behind
 
     def setup(self) -> None:
         launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
@@ -94,13 +99,45 @@ def setup(self) -> None:
 
         self.buffer = []
         self.recv_cnt = 0
+        self.profiler = CustomProfiler(f"C{self.rank}", disabled=not self.enable_profiling)
 
     def state_dict(self) -> Dict[str, torch.Tensor]:
         raise NotImplementedError
 
     def step(self, step_idx: int, **kwargs) -> Optional[float]:
         raise NotImplementedError
 
+    def prepare_mini_batch(self, effective_group_to_raw_group_mapping: Dict[int, int]) -> Dict[str, torch.Tensor]:
+        """
+        Prepare a mini-batch for training from the effective-group-to-raw-group mapping.
+        Returns the bound mini-batch together with a dict of metrics over the consumed raw groups.
+        """
+        batches = [
+            self.buffer[effective_group_to_raw_group_mapping[i]]
+            for i in range(self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size)
+        ]
+        # every dp_rank will receive a complete mini-batch, no need to sync within step() later
+        # each mini-batch uses the first self.dp_size * minibatch_size effective samples
+        raw_mini_batches = self.buffer[
+            : effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
+        ]  # include the last effective sample
+        raw_mini_batches_metric_dict = {
+            "raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
+            "raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
+            "raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
+            "raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
+        }
+        batch = bind_batch([t[0] for t in batches])
+        batch = post_recv(batch)
+        return batch, raw_mini_batches_metric_dict
+
+    def calculate_effective_group_to_raw_group_mapping(self):
+        effective_group_to_raw_group_mapping = {}
+        for buffer_idx in range(len(self.buffer)):
+            if self.buffer[buffer_idx][0] is not None:
+                effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = buffer_idx
+        return effective_group_to_raw_group_mapping
+
     def loop(self) -> None:
         print(
             f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
@@ -112,14 +149,49 @@ def loop(self) -> None:
                 disable=self.rank != 0,
             ) as pbar:
                 for step in pbar:
+                    torch.cuda.reset_peak_memory_stats()
                     i = 0
                     for _ in range(self.num_recv_per_update):
+                        # after syncing the model, do not wait for more data to arrive (rollout takes time); train on buffered data first
+                        effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
+                        while len(effective_group_to_raw_group_mapping) > max(
+                            self.dp_size * self.batch_size
+                            - self.dp_size
+                            * self.minibatch_size
+                            * self.grpo_config.get("num_minibatch_during_rollout", 1),
+                            self.dp_size * self.minibatch_size,
+                        ):
+                            self.profiler.log(
+                                f"Still have {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.minibatch_size}, start training"
+                            )
+                            batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
+                                effective_group_to_raw_group_mapping
+                            )
+                            self.profiler.enter("step")
+                            loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
+                            self.profiler.exit("step")
+                            self.buffer = self.buffer[
+                                effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
+                            ]
+                            # recalculate the effective group to raw group mapping
+                            effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
+                            effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
+                            assert (
+                                len(effective_group_to_raw_group_mapping)
+                                == effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
+                            )
+                            if loss is not None:
+                                pbar.set_postfix({"loss": loss})
+                            i += 1
+
                         # receive data from producers
                         for r in range(self.num_producers):
                             print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
+                            self.profiler.enter(f"recv_broadcast_data_P{r}")
                             raw_batch = ray_broadcast_tensor_dict(
                                 None, src=0, device=self.device, group_name=f"sync_data_{r}"
                             )
+                            self.profiler.exit(f"recv_broadcast_data_P{r}")
                             # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
                             # we need to calculate the metrics before filtering here for logging
                             # [batch_size, num_generations, ...] -> [batch_size * num_generations, ...]
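A quick check of the buffered-training threshold introduced above, with hypothetical sizes (dp_size, batch_size, minibatch_size and the grpo_config value are illustrative, not taken from the PR):

# Hypothetical sizes, to make the max(...) threshold above concrete.
dp_size, batch_size, minibatch_size = 4, 8, 2
num_minibatch_during_rollout = 1  # grpo_config.get("num_minibatch_during_rollout", 1) default
threshold = max(
    dp_size * batch_size - dp_size * minibatch_size * num_minibatch_during_rollout,
    dp_size * minibatch_size,
)
print(threshold)  # 24 -> training drains the buffer only while more than 24 effective groups are queued,
                  #       each iteration consuming dp_size * minibatch_size = 8 of them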
@@ -161,49 +233,29 @@ def loop(self) -> None:
                                 f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
                             )
                             # mapping the effective group to the raw group for indexing
-                            effective_group_to_raw_group_mapping = {}
-                            for buffer_idx in range(len(self.buffer)):
-                                if self.buffer[buffer_idx][0] is not None:
-                                    effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
-                                        buffer_idx
-                                    )
+                            effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
                             print(
                                 f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"
                             )
 
-                            while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
+                            while len(effective_group_to_raw_group_mapping) > self.dp_size * self.batch_size:
+                                self.profiler.log(
+                                    f"Received {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.batch_size}, start training after recv"
+                                )
+                                # always keep at least dp_size * batch_size effective samples in the buffer for training during the rollouts that follow each model sync
                                 # on each dp_rank, we use minibatch_size effective samples to form a batch
-                                batches = [
-                                    self.buffer[effective_group_to_raw_group_mapping[i]]
-                                    for i in range(
-                                        self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size
-                                    )
-                                ]
-                                # every dp_rank will receive a complete mini-batch, no need to sync within step() later
-                                # each mini-batch use the first self.dp_size * minibatch_size effective samples
-                                raw_mini_batches = self.buffer[
-                                    : effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
-                                ]  # include the last effective sample
-                                raw_mini_batches_metric_dict = {
-                                    "raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
-                                    "raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
-                                    "raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
-                                    "raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
-                                }
-                                batch = bind_batch([t[0] for t in batches])
-                                batch = post_recv(batch)
+                                batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
+                                    effective_group_to_raw_group_mapping
+                                )
+                                self.profiler.enter("step")
                                 loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
+                                self.profiler.exit("step")
                                 self.buffer = self.buffer[
                                     effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
                                 ]
                                 # recalculate the effective group to raw group mapping
                                 effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
-                                effective_group_to_raw_group_mapping = {}
-                                for buffer_idx in range(len(self.buffer)):
-                                    if self.buffer[buffer_idx][0] is not None:
-                                        effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
-                                            buffer_idx
-                                        )
+                                effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
                                 assert (
                                     len(effective_group_to_raw_group_mapping)
                                     == effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
                                 )
@@ -221,13 +273,16 @@ def loop(self) -> None:
                         if self.rank == 0:
                             print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
 
-                    if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
+                    if (episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1) and (
+                        episode != 0 or step >= self.n_behind
+                    ):
                         if self.pp_size > 1:
                             print(
                                 f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
                             )
                         else:
                             print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
+                        self.profiler.enter("sync_model")
                         torch.cuda.empty_cache()
                         state_dict = self.state_dict()
                         if self.pp_size > 1:
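To make the new sync condition concrete, a small standalone check with hypothetical counts (the behaviour described in the comment follows from the condition itself, not from any other part of the diff):

# Hypothetical counts, evaluating the new sync-model condition above.
n_behind, num_episodes, num_update_per_episode = 2, 1, 5
for episode in range(num_episodes):
    for step in range(num_update_per_episode):
        will_sync = (episode != num_episodes - 1 or step != num_update_per_episode - 1) and (
            episode != 0 or step >= n_behind
        )
        print(episode, step, will_sync)
# Steps 0 and 1 of episode 0 print False (sync skipped while the rollout pipeline fills),
# steps 2 and 3 print True, and the final step prints False as before.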
@@ -245,6 +300,12 @@ def loop(self) -> None:
                             )
                         del state_dict
                         torch.cuda.empty_cache()
+                        self.profiler.exit("sync_model")
+                        self.profiler.log(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
+
+    def __del__(self):
+        if hasattr(self, "profiler"):
+            self.profiler.close()
 
 
 @ray.remote
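The diff depends on the CustomProfiler added in coati/distributed/profiling_utils.py, which is not shown here. Judging only from the calls used above (a name plus a disabled flag in the constructor, enter/exit around timed scopes, log for free-form messages, close on teardown), a minimal object satisfying that interface could look roughly like the following sketch; the actual implementation in the PR may differ.

import time

class CustomProfiler:
    # Minimal sketch only; the real coati.distributed.profiling_utils.CustomProfiler is not shown in this diff.
    def __init__(self, name, disabled=True):
        self.name = name
        self.disabled = disabled
        self.file = None if disabled else open(f"{name}.prof.log", "w")
        self._starts = {}

    def log(self, msg):
        # append a timestamped message; no-op when profiling is disabled
        if not self.disabled:
            self.file.write(f"[{time.time():.3f}] {self.name}: {msg}\n")
            self.file.flush()

    def enter(self, scope):
        # mark the start of a named scope
        if not self.disabled:
            self._starts[scope] = time.time()
            self.log(f"enter {scope}")

    def exit(self, scope):
        # mark the end of a named scope and record its duration
        if not self.disabled and scope in self._starts:
            self.log(f"exit {scope} ({time.time() - self._starts[scope]:.3f}s)")
            del self._starts[scope]

    def close(self):
        if not self.disabled and self.file:
            self.log("closed")
            self.file.close()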