import ray.util.collective as cc
import torch
import torch.distributed as dist
+from coati.distributed.profiling_utils import CustomProfiler
from tqdm import tqdm
from transformers import AutoModelForCausalLM
@@ -36,6 +37,8 @@ def __init__(
        minibatch_size: int = 1,
        save_interval: int = 100,
        save_dir: str = "./model",
+        enable_profiling: bool = False,
+        n_behind: int = 0,
    ):
        self.num_producers = num_producers
        self.num_episodes = num_episodes
@@ -49,6 +52,7 @@ def __init__(
        self.minibatch_size = minibatch_size
        self.save_interval = save_interval
        self.save_dir = save_dir
+        self.enable_profiling = enable_profiling
        assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size"
        self.num_microbatches = batch_size // minibatch_size
@@ -57,6 +61,7 @@ def __init__(

        self.device = get_current_device()
        self.lr_scheduler = None
+        self.n_behind = n_behind

    def setup(self) -> None:
        launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
@@ -94,13 +99,49 @@ def setup(self) -> None:

        self.buffer = []
        self.recv_cnt = 0
+        self.profiler = CustomProfiler(f"C{self.rank}", disabled=not self.enable_profiling)

    def state_dict(self) -> Dict[str, torch.Tensor]:
        raise NotImplementedError

    def step(self, step_idx: int, **kwargs) -> Optional[float]:
        raise NotImplementedError

+    def prepare_mini_batch(self, effective_group_to_raw_group_mapping: Dict[int, int]) -> Dict[str, torch.Tensor]:
+        """
+        Prepare a mini-batch from the effective group to raw group mapping.
+        This method is used to create a mini-batch for training.
+        """
+        batches = [
+            self.buffer[effective_group_to_raw_group_mapping[i]]
+            for i in range(self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size)
+        ]
+        # every dp_rank will receive a complete mini-batch, no need to sync within step() later
+        # each mini-batch use the first self.dp_size * minibatch_size effective samples
+        raw_mini_batches = self.buffer[
+            : effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
+        ]  # include the last effective sample
+        raw_mini_batches_metric_dict = {
+            "raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
+            "raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
+            "raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
+            "raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
+        }
+        batch = bind_batch([t[0] for t in batches])
+        batch = post_recv(batch)
+        return batch, raw_mini_batches_metric_dict
+
+    def calculate_effective_group_to_raw_group_mapping(self, step):
+        effective_group_to_raw_group_mapping = {}
+        for buffer_idx in range(len(self.buffer)):
+            if self.buffer[buffer_idx][0] is not None:
+                if self.n_behind == 0:
+                    effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = buffer_idx
+                else:
+                    if self.buffer[buffer_idx][-1] <= step - self.n_behind:
+                        effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = buffer_idx
+        return effective_group_to_raw_group_mapping
+
    def loop(self) -> None:
        print(
            f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
@@ -112,14 +153,53 @@ def loop(self) -> None:
                disable=self.rank != 0,
            ) as pbar:
                for step in pbar:
+                    torch.cuda.reset_peak_memory_stats()
                    i = 0
+
+                    self.profiler.enter(f"rollout_episode_{episode}_step_{step}")
                    for _ in range(self.num_recv_per_update):
+                        if self.n_behind > 0:
+                            # after sync model, do not wait for more data to arrive as rollout takes time, use buffered data
+                            effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping(
+                                step=step
+                            )
+                            while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
+                                self.profiler.log(
+                                    f"Still have {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.minibatch_size}, start training"
+                                )
+                                batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
+                                    effective_group_to_raw_group_mapping
+                                )
+                                self.profiler.enter("step")
+                                loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
+                                self.profiler.exit("step")
+                                self.buffer = self.buffer[
+                                    effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
+                                ]
+                                # recalculate the effective group to raw group mapping
+                                effective_group_to_raw_group_mapping_size_before = len(
+                                    effective_group_to_raw_group_mapping
+                                )
+                                effective_group_to_raw_group_mapping = (
+                                    self.calculate_effective_group_to_raw_group_mapping(step=step)
+                                )
+                                assert (
+                                    len(effective_group_to_raw_group_mapping)
+                                    == effective_group_to_raw_group_mapping_size_before
+                                    - self.dp_size * self.minibatch_size
+                                )
+                                if loss is not None:
+                                    pbar.set_postfix({"loss": loss})
+                                i += 1
+
                        # receive data from producers
                        for r in range(self.num_producers):
                            print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
+                            self.profiler.enter(f"recv_broadcast_data_P{r}")
                            raw_batch = ray_broadcast_tensor_dict(
                                None, src=0, device=self.device, group_name=f"sync_data_{r}"
                            )
+                            self.profiler.exit(f"recv_broadcast_data_P{r}")
                            # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
                            # we need to calculate the metrics before filtering here for logging
                            # [batch_size, num_generations, ...] -> [batch_size * num_generations, ...]
@@ -154,63 +234,52 @@ def loop(self) -> None:
                                        format_acc[group_idx],
                                        ans_acc[group_idx],
                                        response_len[group_idx],
+                                        step,
                                    ]
                                )
                            if effective_group_mask is not None:
                                print(
                                    f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
                                )
                            # mapping the effective group to the raw group for indexing
-                            effective_group_to_raw_group_mapping = {}
-                            for buffer_idx in range(len(self.buffer)):
-                                if self.buffer[buffer_idx][0] is not None:
-                                    effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
-                                        buffer_idx
-                                    )
+                            effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping(
+                                step=step
+                            )
                            print(
                                f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"
                            )

-                        while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
-                            # on each dp_rank, we use minibatch_size effective samples to form a batch
-                            batches = [
-                                self.buffer[effective_group_to_raw_group_mapping[i]]
-                                for i in range(
-                                    self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size
+                        if self.n_behind == 0:
+                            # If n_behind is 0, we start training after receiving data from producers.
+                            while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
+                                self.profiler.log(
+                                    f"Collect {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.minibatch_size}, start training"
                                )
-                            ]
-                            # every dp_rank will receive a complete mini-batch, no need to sync within step() later
-                            # each mini-batch use the first self.dp_size * minibatch_size effective samples
-                            raw_mini_batches = self.buffer[
-                                : effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
-                            ]  # include the last effective sample
-                            raw_mini_batches_metric_dict = {
-                                "raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
-                                "raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
-                                "raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
-                                "raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
-                            }
-                            batch = bind_batch([t[0] for t in batches])
-                            batch = post_recv(batch)
-                            loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
-                            self.buffer = self.buffer[
-                                effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
-                            ]
-                            # recalculate the effective group to raw group mapping
-                            effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
-                            effective_group_to_raw_group_mapping = {}
-                            for buffer_idx in range(len(self.buffer)):
-                                if self.buffer[buffer_idx][0] is not None:
-                                    effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
-                                        buffer_idx
-                                    )
-                            assert (
-                                len(effective_group_to_raw_group_mapping)
-                                == effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
-                            )
-                            if loss is not None:
-                                pbar.set_postfix({"loss": loss})
-                            i += 1
+                                batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
+                                    effective_group_to_raw_group_mapping
+                                )
+                                self.profiler.enter("step")
+                                loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
+                                self.profiler.exit("step")
+                                self.buffer = self.buffer[
+                                    effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
+                                ]
+                                # recalculate the effective group to raw group mapping
+                                effective_group_to_raw_group_mapping_size_before = len(
+                                    effective_group_to_raw_group_mapping
+                                )
+                                effective_group_to_raw_group_mapping = (
+                                    self.calculate_effective_group_to_raw_group_mapping(step=step)
+                                )
+                                assert (
+                                    len(effective_group_to_raw_group_mapping)
+                                    == effective_group_to_raw_group_mapping_size_before
+                                    - self.dp_size * self.minibatch_size
+                                )
+                                if loss is not None:
+                                    pbar.set_postfix({"loss": loss})
+                                i += 1
+
                    if self.lr_scheduler is not None:
                        self.lr_scheduler.step()
                    if (step + 1) % self.save_interval == 0 or (step + 1) == self.num_update_per_episode:
@@ -221,13 +290,16 @@ def loop(self) -> None:
                        if self.rank == 0:
                            print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")

-                    if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
+                    if (episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1) and (
+                        episode != 0 or step >= self.n_behind
+                    ):
                        if self.pp_size > 1:
                            print(
                                f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
                            )
                        else:
                            print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
+                        self.profiler.enter("sync_model")
                        torch.cuda.empty_cache()
                        state_dict = self.state_dict()
                        if self.pp_size > 1:
@@ -245,6 +317,13 @@ def loop(self) -> None:
                                )
                        del state_dict
                        torch.cuda.empty_cache()
+                        self.profiler.exit("sync_model")
+                    self.profiler.log(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
+                    self.profiler.exit(f"rollout_episode_{episode}_step_{step}")
+
+    def __del__(self):
+        if hasattr(self, "profiler"):
+            self.profiler.close()


@ray.remote
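The core of this change, beyond the profiler hooks, is the staleness gate that n_behind adds to the consumer's buffer: when n_behind > 0, only groups produced at least n_behind steps before the current training step are eligible for a mini-batch, so the trainer can keep consuming buffered rollouts while newer ones are still in flight. Below is a minimal, standalone sketch of that gating logic (a hypothetical helper written for illustration, not code from the PR); it assumes, as the diff does, that a filtered-out group stores None in its first slot and that each buffer entry carries its producing step in its last slot.

from typing import Dict, List, Optional, Tuple

# Toy buffer entry: (group_data, produced_step). group_data is None when the
# whole group was filtered out; the real buffer also carries reward/acc/length metrics.
Entry = Tuple[Optional[str], int]


def effective_mapping(buffer: List[Entry], step: int, n_behind: int) -> Dict[int, int]:
    """Map effective (trainable) group index -> raw buffer index, mirroring the
    staleness filter of calculate_effective_group_to_raw_group_mapping on the toy buffer."""
    mapping: Dict[int, int] = {}
    for buffer_idx, (group, produced_step) in enumerate(buffer):
        if group is None:
            continue  # filtered-out group, never trainable
        # n_behind == 0: train on anything available; otherwise require the group
        # to be at least n_behind steps older than the current step.
        if n_behind == 0 or produced_step <= step - n_behind:
            mapping[len(mapping)] = buffer_idx
    return mapping


if __name__ == "__main__":
    buffer = [("g0", 0), (None, 0), ("g2", 1), ("g3", 2)]
    # Without the gate, every non-empty group is eligible.
    assert effective_mapping(buffer, step=2, n_behind=0) == {0: 0, 1: 2, 2: 3}
    # With n_behind=1 at step 2, only groups produced at step <= 1 are eligible.
    assert effective_mapping(buffer, step=2, n_behind=1) == {0: 0, 1: 2}

Once enough eligible groups accumulate (dp_size * minibatch_size of them), each dp_rank takes its own minibatch_size slice and the consumed prefix of the buffer is dropped, which is what prepare_mini_batch and the buffer trimming in the diff implement.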