 import ray.util.collective as cc
 import torch
 import torch.distributed as dist
+from coati.distributed.profiling_utils import CustomProfiler
 from tqdm import tqdm
 from transformers import AutoModelForCausalLM
 
@@ -36,6 +37,8 @@ def __init__(
         minibatch_size: int = 1,
         save_interval: int = 100,
         save_dir: str = "./model",
+        enable_profiling: bool = False,
+        n_behind: int = 0,
     ):
         self.num_producers = num_producers
         self.num_episodes = num_episodes
@@ -49,6 +52,7 @@ def __init__(
         self.minibatch_size = minibatch_size
         self.save_interval = save_interval
         self.save_dir = save_dir
+        self.enable_profiling = enable_profiling
         assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size"
         self.num_microbatches = batch_size // minibatch_size
@@ -57,6 +61,7 @@ def __init__(
 
         self.device = get_current_device()
         self.lr_scheduler = None
+        self.n_behind = n_behind
 
     def setup(self) -> None:
         launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
@@ -94,13 +99,49 @@ def setup(self) -> None:
 
         self.buffer = []
         self.recv_cnt = 0
+        self.profiler = CustomProfiler(f"C{self.rank}", disabled=not self.enable_profiling)
 
     def state_dict(self) -> Dict[str, torch.Tensor]:
         raise NotImplementedError
 
     def step(self, step_idx: int, **kwargs) -> Optional[float]:
         raise NotImplementedError
 
+    def prepare_mini_batch(self, effective_group_to_raw_group_mapping: Dict[int, int]) -> Dict[str, torch.Tensor]:
+        """
+        Prepare a mini-batch from the effective group to raw group mapping.
+        This method is used to create a mini-batch for training.
+        """
+        batches = [
+            self.buffer[effective_group_to_raw_group_mapping[i]]
+            for i in range(self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size)
+        ]
+        # every dp_rank will receive a complete mini-batch, no need to sync within step() later
+        # each mini-batch use the first self.dp_size * minibatch_size effective samples
+        raw_mini_batches = self.buffer[
+            : effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
+        ]  # include the last effective sample
+        raw_mini_batches_metric_dict = {
+            "raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
+            "raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
+            "raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
+            "raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
+        }
+        batch = bind_batch([t[0] for t in batches])
+        batch = post_recv(batch)
+        return batch, raw_mini_batches_metric_dict
+
+    def calculate_effective_group_to_raw_group_mapping(self, step):
+        effective_group_to_raw_group_mapping = {}
+        for buffer_idx in range(len(self.buffer)):
+            if self.buffer[buffer_idx][0] is not None:
+                if self.n_behind == 0:
+                    effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = buffer_idx
+                else:
+                    if self.buffer[buffer_idx][-1] <= step - self.n_behind:
+                        effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = buffer_idx
+        return effective_group_to_raw_group_mapping
+
     def loop(self) -> None:
         print(
             f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
@@ -112,14 +153,53 @@ def loop(self) -> None:
                 disable=self.rank != 0,
             ) as pbar:
                 for step in pbar:
+                    torch.cuda.reset_peak_memory_stats()
                     i = 0
+
+                    self.profiler.enter(f"rollout_episode_{episode}_step_{step}")
                     for _ in range(self.num_recv_per_update):
+                        if self.n_behind > 0:
+                            # after sync model, do not wait for more data to arrive as rollout takes time, use buffered data
+                            effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping(
+                                step=step
+                            )
+                            while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
+                                self.profiler.log(
+                                    f"Still have {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.minibatch_size}, start training"
+                                )
+                                batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
+                                    effective_group_to_raw_group_mapping
+                                )
+                                self.profiler.enter("step")
+                                loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
+                                self.profiler.exit("step")
+                                self.buffer = self.buffer[
+                                    effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
+                                ]
+                                # recalculate the effective group to raw group mapping
+                                effective_group_to_raw_group_mapping_size_before = len(
+                                    effective_group_to_raw_group_mapping
+                                )
+                                effective_group_to_raw_group_mapping = (
+                                    self.calculate_effective_group_to_raw_group_mapping(step=step)
+                                )
+                                assert (
+                                    len(effective_group_to_raw_group_mapping)
+                                    == effective_group_to_raw_group_mapping_size_before
+                                    - self.dp_size * self.minibatch_size
+                                )
+                                if loss is not None:
+                                    pbar.set_postfix({"loss": loss})
+                                i += 1
+
                     # receive data from producers
                     for r in range(self.num_producers):
                         print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
+                        self.profiler.enter(f"recv_broadcast_data_P{r}")
                         raw_batch = ray_broadcast_tensor_dict(
                             None, src=0, device=self.device, group_name=f"sync_data_{r}"
                         )
+                        self.profiler.exit(f"recv_broadcast_data_P{r}")
                         # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
                         # we need to calculate the metrics before filtering here for logging
                         # [batch_size, num_generations, ...] -> [batch_size * num_generations, ...]
@@ -153,63 +233,52 @@ def loop(self) -> None:
                                     format_acc[group_idx],
                                     ans_acc[group_idx],
                                     response_len[group_idx],
+                                    step,
                                 ]
                             )
                         if effective_group_mask is not None:
                             print(
                                 f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
                             )
                         # mapping the effective group to the raw group for indexing
-                        effective_group_to_raw_group_mapping = {}
-                        for buffer_idx in range(len(self.buffer)):
-                            if self.buffer[buffer_idx][0] is not None:
-                                effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
-                                    buffer_idx
-                                )
+                        effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping(
+                            step=step
+                        )
                         print(
                             f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"
                         )
 
-                        while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
-                            # on each dp_rank, we use minibatch_size effective samples to form a batch
-                            batches = [
-                                self.buffer[effective_group_to_raw_group_mapping[i]]
-                                for i in range(
-                                    self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size
+                        if self.n_behind == 0:
+                            # If n_behind is 0, we start training after receiving data from producers.
+                            while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
+                                self.profiler.log(
+                                    f"Collect {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.minibatch_size}, start training"
                                 )
-                            ]
-                            # every dp_rank will receive a complete mini-batch, no need to sync within step() later
-                            # each mini-batch use the first self.dp_size * minibatch_size effective samples
-                            raw_mini_batches = self.buffer[
-                                : effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
-                            ]  # include the last effective sample
-                            raw_mini_batches_metric_dict = {
-                                "raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
-                                "raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
-                                "raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
-                                "raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
-                            }
-                            batch = bind_batch([t[0] for t in batches])
-                            batch = post_recv(batch)
-                            loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
-                            self.buffer = self.buffer[
-                                effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
-                            ]
-                            # recalculate the effective group to raw group mapping
-                            effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
-                            effective_group_to_raw_group_mapping = {}
-                            for buffer_idx in range(len(self.buffer)):
-                                if self.buffer[buffer_idx][0] is not None:
-                                    effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
-                                        buffer_idx
-                                    )
-                            assert (
-                                len(effective_group_to_raw_group_mapping)
-                                == effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
-                            )
-                            if loss is not None:
-                                pbar.set_postfix({"loss": loss})
-                            i += 1
+                                batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
+                                    effective_group_to_raw_group_mapping
+                                )
+                                self.profiler.enter("step")
+                                loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
+                                self.profiler.exit("step")
+                                self.buffer = self.buffer[
+                                    effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
+                                ]
+                                # recalculate the effective group to raw group mapping
+                                effective_group_to_raw_group_mapping_size_before = len(
+                                    effective_group_to_raw_group_mapping
+                                )
+                                effective_group_to_raw_group_mapping = (
+                                    self.calculate_effective_group_to_raw_group_mapping(step=step)
+                                )
+                                assert (
+                                    len(effective_group_to_raw_group_mapping)
+                                    == effective_group_to_raw_group_mapping_size_before
+                                    - self.dp_size * self.minibatch_size
+                                )
+                                if loss is not None:
+                                    pbar.set_postfix({"loss": loss})
+                                i += 1
+
                 if self.lr_scheduler is not None:
                     self.lr_scheduler.step()
                 if (step + 1) % self.save_interval == 0 or (step + 1) == self.num_update_per_episode:
@@ -220,13 +289,16 @@ def loop(self) -> None:
                         if self.rank == 0:
                             print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
 
-                    if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
+                    if (episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1) and (
+                        episode != 0 or step >= self.n_behind
+                    ):
                         if self.pp_size > 1:
                             print(
                                 f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
                             )
                         else:
                             print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
+                        self.profiler.enter("sync_model")
                         torch.cuda.empty_cache()
                         state_dict = self.state_dict()
                         if self.pp_size > 1:
@@ -244,6 +316,13 @@ def loop(self) -> None:
                                 )
                         del state_dict
                         torch.cuda.empty_cache()
+                        self.profiler.exit("sync_model")
+                    self.profiler.log(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
+                    self.profiler.exit(f"rollout_episode_{episode}_step_{step}")
+
+    def __del__(self):
+        if hasattr(self, "profiler"):
+            self.profiler.close()
 
 
 @ray.remote
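The new n_behind option lets the consumer train on rollout groups that are at least n_behind steps old, so training can run on buffered data while fresh rollouts are still being generated. The snippet below is a minimal standalone sketch (not part of this commit) of the staleness check performed by calculate_effective_group_to_raw_group_mapping; it assumes, as the diff shows, that a buffer entry starts with the possibly filtered-out group (None when filtered) and ends with the step that produced it. The name select_effective_groups and the sample buffer are hypothetical and purely illustrative.

# Illustrative sketch of the n_behind selection rule; not the library API.
from typing import Any, Dict, List, Optional, Tuple


def select_effective_groups(
    buffer: List[Tuple[Optional[Any], int]],
    current_step: int,
    n_behind: int = 0,
) -> Dict[int, int]:
    """Map consecutive effective indices to raw buffer indices.

    With n_behind == 0 every non-filtered group is eligible; with n_behind > 0
    only groups produced at least n_behind steps before current_step are used.
    """
    mapping: Dict[int, int] = {}
    for raw_idx, entry in enumerate(buffer):
        group, produced_at_step = entry[0], entry[-1]
        if group is None:
            continue  # group was filtered out; it only pads the raw buffer
        if n_behind == 0 or produced_at_step <= current_step - n_behind:
            mapping[len(mapping)] = raw_idx
    return mapping


if __name__ == "__main__":
    # Hypothetical buffer of (group, producing step) pairs. At step 3 with
    # n_behind=2, only groups produced at steps 0 and 1 qualify -> {0: 0, 1: 4}.
    buf = [("g0", 0), (None, 1), ("g2", 2), ("g3", 3), ("g1", 1)]
    print(select_effective_groups(buf, current_step=3, n_behind=2))

Under this rule no group is old enough to train on during the first n_behind steps of episode 0, which is consistent with the new sync-model gate in the diff (episode != 0 or step >= self.n_behind) that skips weight synchronization for those initial steps.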