from typing import Tuple

import ray.util.collective as cc
import torch
import torch.distributed as dist
from coati.distributed.profiling_utils import CustomProfiler
from tqdm import tqdm
from transformers import AutoModelForCausalLM
1112
@@ -36,6 +37,8 @@ def __init__(
3637 minibatch_size : int = 1 ,
3738 save_interval : int = 100 ,
3839 save_dir : str = "./model" ,
40+ enable_profiling : bool = False ,
41+ n_behind : int = 0 ,
3942 ):
4043 self .num_producers = num_producers
4144 self .num_episodes = num_episodes
@@ -49,6 +52,7 @@ def __init__(
4952 self .minibatch_size = minibatch_size
5053 self .save_interval = save_interval
5154 self .save_dir = save_dir
55+ self .enable_profiling = enable_profiling
5256 assert batch_size % minibatch_size == 0 , "batch_size should be divisible by microbatch_size"
5357 self .num_microbatches = batch_size // minibatch_size
5458
@@ -57,6 +61,7 @@ def __init__(
5761
5862 self .device = get_current_device ()
5963 self .lr_scheduler = None
64+ self .n_behind = n_behind
6065
6166 def setup (self ) -> None :
6267 launch (self .rank , self .world_size , self .master_addr , self .master_port , local_rank = 0 )
@@ -94,13 +99,49 @@ def setup(self) -> None:
9499
95100 self .buffer = []
96101 self .recv_cnt = 0
102+ self .profiler = CustomProfiler (f"C{ self .rank } " , disabled = not self .enable_profiling )
97103
def state_dict(self) -> Dict[str, torch.Tensor]:
    """Return the model weights as a name -> tensor mapping.

    Abstract hook: this base implementation always raises, so concrete
    consumers must override it. ``loop`` calls it when syncing the model.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
100106
def step(self, step_idx: int, **kwargs) -> Optional[float]:
    """Run one training step on a prepared mini-batch.

    Abstract hook: this base implementation always raises, so concrete
    consumers must override it. ``loop`` forwards the mini-batch tensors and
    the raw mini-batch metric lists as keyword arguments, and shows a
    non-``None`` return value as the loss on the progress bar.

    NOTE(review): ``loop`` also passes the progress bar as a second positional
    argument, which this base signature does not accept; overriding
    implementations presumably declare it -- confirm.

    Args:
        step_idx: Index of the mini-batch step within the current update.
        **kwargs: Mini-batch fields and raw metric lists.

    Returns:
        The loss for this step, or ``None`` when no loss is reported.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
103109
def prepare_mini_batch(
    self, effective_group_to_raw_group_mapping: Dict[int, int]
) -> Tuple[Dict[str, torch.Tensor], Dict[str, list]]:
    """Build this rank's training mini-batch and the matching raw metrics.

    Args:
        effective_group_to_raw_group_mapping: Maps the i-th effective group
            to its index in ``self.buffer``. It must contain at least
            ``self.dp_size * self.minibatch_size`` entries (callers check this
            before calling); otherwise the lookups below raise ``KeyError``.

    Returns:
        A ``(batch, raw_mini_batches_metric_dict)`` tuple. ``batch`` is the
        result of ``post_recv(bind_batch(...))`` over this rank's
        ``minibatch_size`` effective groups. The metric dict holds, for every
        raw buffer entry up to and including the last effective group used by
        any rank, the reward, format accuracy, answer accuracy and response
        length stored at positions 1-4 of the buffer entry.
    """
    # on each dp_rank, we use minibatch_size effective samples to form a batch
    rank_start = self.dp_rank * self.minibatch_size
    rank_end = (self.dp_rank + 1) * self.minibatch_size
    batches = [self.buffer[effective_group_to_raw_group_mapping[i]] for i in range(rank_start, rank_end)]

    # every dp_rank will receive a complete mini-batch, no need to sync within step() later
    # each mini-batch use the first self.dp_size * minibatch_size effective samples
    last_raw_idx = effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1]
    raw_mini_batches = self.buffer[: last_raw_idx + 1]  # include the last effective sample
    raw_mini_batches_metric_dict = {
        "raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
        "raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
        "raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
        "raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
    }

    batch = bind_batch([t[0] for t in batches])
    batch = post_recv(batch)
    return batch, raw_mini_batches_metric_dict
133+
def calculate_effective_group_to_raw_group_mapping(self, step: int) -> Dict[int, int]:
    """Map consecutive effective-group indices to positions in ``self.buffer``.

    A buffer entry is skipped when its first element is ``None``. When
    ``self.n_behind`` is 0 every remaining entry counts as effective;
    otherwise an entry counts only if the step stored in its last element is
    at most ``step - self.n_behind``.

    Args:
        step: Current training step, compared against each entry's stored step.

    Returns:
        A dict whose keys are ``0..k-1`` in buffer order and whose values are
        the corresponding indices into ``self.buffer``.
    """
    effective_group_to_raw_group_mapping = {}
    for buffer_idx, entry in enumerate(self.buffer):
        if entry[0] is None:
            continue
        if self.n_behind == 0 or entry[-1] <= step - self.n_behind:
            effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = buffer_idx
    return effective_group_to_raw_group_mapping
144+
104145 def loop (self ) -> None :
105146 print (
106147 f"Consumer{ self .rank } num_update: { self .num_update_per_episode } , num_recv: { self .num_recv_per_update } , nmb: { self .num_microbatches } "
@@ -112,14 +153,53 @@ def loop(self) -> None:
112153 disable = self .rank != 0 ,
113154 ) as pbar :
114155 for step in pbar :
156+ torch .cuda .reset_peak_memory_stats ()
115157 i = 0
158+
159+ self .profiler .enter (f"rollout_episode_{ episode } _step_{ step } " )
116160 for _ in range (self .num_recv_per_update ):
161+ if self .n_behind > 0 :
162+ # after sync model, do not wait for more data to arrive as rollout takes time, use buffered data
163+ effective_group_to_raw_group_mapping = self .calculate_effective_group_to_raw_group_mapping (
164+ step = step
165+ )
166+ while len (effective_group_to_raw_group_mapping ) >= self .dp_size * self .minibatch_size :
167+ self .profiler .log (
168+ f"Still have { len (effective_group_to_raw_group_mapping )} effective groups, greater than { self .dp_size * self .minibatch_size } , start training"
169+ )
170+ batch , raw_mini_batches_metric_dict = self .prepare_mini_batch (
171+ effective_group_to_raw_group_mapping
172+ )
173+ self .profiler .enter ("step" )
174+ loss = self .step (i , pbar , ** batch , ** raw_mini_batches_metric_dict )
175+ self .profiler .exit ("step" )
176+ self .buffer = self .buffer [
177+ effective_group_to_raw_group_mapping [self .dp_size * self .minibatch_size - 1 ] + 1 :
178+ ]
179+ # recalculate the effective group to raw group mapping
180+ effective_group_to_raw_group_mapping_size_before = len (
181+ effective_group_to_raw_group_mapping
182+ )
183+ effective_group_to_raw_group_mapping = (
184+ self .calculate_effective_group_to_raw_group_mapping (step = step )
185+ )
186+ assert (
187+ len (effective_group_to_raw_group_mapping )
188+ == effective_group_to_raw_group_mapping_size_before
189+ - self .dp_size * self .minibatch_size
190+ )
191+ if loss is not None :
192+ pbar .set_postfix ({"loss" : loss })
193+ i += 1
194+
117195 # receive data from producers
118196 for r in range (self .num_producers ):
119197 print (f"[T{ dist .get_rank ()} ] Recv data episode { episode } step { step } from { r } " )
198+ self .profiler .enter (f"recv_broadcast_data_P{ r } " )
120199 raw_batch = ray_broadcast_tensor_dict (
121200 None , src = 0 , device = self .device , group_name = f"sync_data_{ r } "
122201 )
202+ self .profiler .exit (f"recv_broadcast_data_P{ r } " )
123203 # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
124204 # we need to calculate the metrics before filtering here for logging
125205 # [batch_size, num_generations, ...] -> [batch_size * num_generations, ...]
@@ -154,63 +234,52 @@ def loop(self) -> None:
154234 format_acc [group_idx ],
155235 ans_acc [group_idx ],
156236 response_len [group_idx ],
237+ step ,
157238 ]
158239 )
159240 if effective_group_mask is not None :
160241 print (
161242 f"[T{ dist .get_rank ()} ] Filter recv data: { len (raw_batch )} -> { torch .sum (effective_group_mask ).cpu ().item ()} effective groups"
162243 )
163244 # mapping the effective group to the raw group for indexing
164- effective_group_to_raw_group_mapping = {}
165- for buffer_idx in range (len (self .buffer )):
166- if self .buffer [buffer_idx ][0 ] is not None :
167- effective_group_to_raw_group_mapping [len (effective_group_to_raw_group_mapping )] = (
168- buffer_idx
169- )
245+ effective_group_to_raw_group_mapping = self .calculate_effective_group_to_raw_group_mapping (
246+ step = step
247+ )
170248 print (
171249 f"[T{ dist .get_rank ()} ] Collect Effective Prompt: { len (effective_group_to_raw_group_mapping )} /{ self .dp_size * self .minibatch_size } "
172250 )
173251
174- while len (effective_group_to_raw_group_mapping ) >= self .dp_size * self .minibatch_size :
175- # on each dp_rank, we use minibatch_size effective samples to form a batch
176- batches = [
177- self .buffer [effective_group_to_raw_group_mapping [i ]]
178- for i in range (
179- self .dp_rank * self .minibatch_size , (self .dp_rank + 1 ) * self .minibatch_size
252+ if self .n_behind == 0 :
253+ # If n_behind is 0, we start training after receiving data from producers.
254+ while len (effective_group_to_raw_group_mapping ) >= self .dp_size * self .minibatch_size :
255+ self .profiler .log (
256+ f"Collect { len (effective_group_to_raw_group_mapping )} effective groups, greater than { self .dp_size * self .minibatch_size } , start training"
180257 )
181- ]
182- # every dp_rank will receive a complete mini-batch, no need to sync within step() later
183- # each mini-batch use the first self.dp_size * minibatch_size effective samples
184- raw_mini_batches = self .buffer [
185- : effective_group_to_raw_group_mapping [self .dp_size * self .minibatch_size - 1 ] + 1
186- ] # include the last effective sample
187- raw_mini_batches_metric_dict = {
188- "raw_train_mini_batch_reward" : [t [1 ] for t in raw_mini_batches ],
189- "raw_train_mini_batch_format_acc" : [t [2 ] for t in raw_mini_batches ],
190- "raw_train_mini_batch_ans_acc" : [t [3 ] for t in raw_mini_batches ],
191- "raw_train_mini_batch_response_len" : [t [4 ] for t in raw_mini_batches ],
192- }
193- batch = bind_batch ([t [0 ] for t in batches ])
194- batch = post_recv (batch )
195- loss = self .step (i , pbar , ** batch , ** raw_mini_batches_metric_dict )
196- self .buffer = self .buffer [
197- effective_group_to_raw_group_mapping [self .dp_size * self .minibatch_size - 1 ] + 1 :
198- ]
199- # recalculate the effective group to raw group mapping
200- effective_group_to_raw_group_mapping_size_before = len (effective_group_to_raw_group_mapping )
201- effective_group_to_raw_group_mapping = {}
202- for buffer_idx in range (len (self .buffer )):
203- if self .buffer [buffer_idx ][0 ] is not None :
204- effective_group_to_raw_group_mapping [len (effective_group_to_raw_group_mapping )] = (
205- buffer_idx
206- )
207- assert (
208- len (effective_group_to_raw_group_mapping )
209- == effective_group_to_raw_group_mapping_size_before - self .dp_size * self .minibatch_size
210- )
211- if loss is not None :
212- pbar .set_postfix ({"loss" : loss })
213- i += 1
258+ batch , raw_mini_batches_metric_dict = self .prepare_mini_batch (
259+ effective_group_to_raw_group_mapping
260+ )
261+ self .profiler .enter ("step" )
262+ loss = self .step (i , pbar , ** batch , ** raw_mini_batches_metric_dict )
263+ self .profiler .exit ("step" )
264+ self .buffer = self .buffer [
265+ effective_group_to_raw_group_mapping [self .dp_size * self .minibatch_size - 1 ] + 1 :
266+ ]
267+ # recalculate the effective group to raw group mapping
268+ effective_group_to_raw_group_mapping_size_before = len (
269+ effective_group_to_raw_group_mapping
270+ )
271+ effective_group_to_raw_group_mapping = (
272+ self .calculate_effective_group_to_raw_group_mapping (step = step )
273+ )
274+ assert (
275+ len (effective_group_to_raw_group_mapping )
276+ == effective_group_to_raw_group_mapping_size_before
277+ - self .dp_size * self .minibatch_size
278+ )
279+ if loss is not None :
280+ pbar .set_postfix ({"loss" : loss })
281+ i += 1
282+
214283 if self .lr_scheduler is not None :
215284 self .lr_scheduler .step ()
216285 if (step + 1 ) % self .save_interval == 0 or (step + 1 ) == self .num_update_per_episode :
@@ -221,13 +290,16 @@ def loop(self) -> None:
221290 if self .rank == 0 :
222291 print (f"Saved model checkpoint at step { step + 1 } in folder { save_path } " )
223292
224- if episode != self .num_episodes - 1 or step != self .num_update_per_episode - 1 :
293+ if (episode != self .num_episodes - 1 or step != self .num_update_per_episode - 1 ) and (
294+ episode != 0 or step >= self .n_behind
295+ ):
225296 if self .pp_size > 1 :
226297 print (
227298 f"[T{ dist .get_rank ()} ] Sync model PP stage { self .pp_rank } episode { episode } step { step } "
228299 )
229300 else :
230301 print (f"[T{ dist .get_rank ()} ] Sync model episode { episode } step { step } " )
302+ self .profiler .enter ("sync_model" )
231303 torch .cuda .empty_cache ()
232304 state_dict = self .state_dict ()
233305 if self .pp_size > 1 :
@@ -245,6 +317,13 @@ def loop(self) -> None:
245317 )
246318 del state_dict
247319 torch .cuda .empty_cache ()
320+ self .profiler .exit ("sync_model" )
321+ self .profiler .log (f"Peak memory usage: { torch .cuda .max_memory_allocated () / 1024 ** 2 :.2f} MB" )
322+ self .profiler .exit (f"rollout_episode_{ episode } _step_{ step } " )
323+
def __del__(self):
    """Close the profiler on teardown, if one was ever attached.

    ``profiler`` is assigned in ``setup``, so an instance destroyed before
    ``setup`` ran has no such attribute and nothing is closed.
    """
    if not hasattr(self, "profiler"):
        return
    self.profiler.close()
248327
249328
250329@ray .remote
0 commit comments