66import ray .util .collective as cc
77import torch
88import torch .distributed as dist
9+ from coati .distributed .profiling_utils import CustomProfiler
910from tqdm import tqdm
1011from transformers import AutoModelForCausalLM
1112
@@ -36,6 +37,8 @@ def __init__(
3637 minibatch_size : int = 1 ,
3738 save_interval : int = 100 ,
3839 save_dir : str = "./model" ,
40+ enable_profiling : bool = False ,
41+ n_behind : int = 0 ,
3942 ):
4043 self .num_producers = num_producers
4144 self .num_episodes = num_episodes
@@ -49,6 +52,7 @@ def __init__(
4952 self .minibatch_size = minibatch_size
5053 self .save_interval = save_interval
5154 self .save_dir = save_dir
55+ self .enable_profiling = enable_profiling
5256 assert batch_size % minibatch_size == 0 , "batch_size should be divisible by microbatch_size"
5357 self .num_microbatches = batch_size // minibatch_size
5458
@@ -57,6 +61,7 @@ def __init__(
5761
5862 self .device = get_current_device ()
5963 self .lr_scheduler = None
64+ self .n_behind = n_behind
6065
6166 def setup (self ) -> None :
6267 launch (self .rank , self .world_size , self .master_addr , self .master_port , local_rank = 0 )
@@ -94,13 +99,49 @@ def setup(self) -> None:
9499
95100 self .buffer = []
96101 self .recv_cnt = 0
102+ self .profiler = CustomProfiler (f"C{ self .rank } " , disabled = not self .enable_profiling )
97103
def state_dict(self) -> Dict[str, torch.Tensor]:
    """Return the tensors to publish when the model is synced.

    Abstract hook: this base implementation always raises, so a concrete
    consumer has to override it. ``loop()`` calls it at the start of its
    "sync model" phase and deletes the returned dict once that phase is done.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError
100106
def step(self, step_idx: int, **kwargs) -> Optional[float]:
    """Run one training step on a prepared mini-batch.

    Abstract hook: this base implementation always raises, so a concrete
    consumer has to override it. ``loop()`` invokes it as
    ``self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)`` and, when
    the return value is not ``None``, shows it as ``"loss"`` on the progress bar.

    NOTE(review): ``loop()`` passes ``pbar`` as a second positional argument,
    which this base signature does not accept; overriding implementations
    presumably declare it explicitly -- confirm against the subclasses.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError
103109
def prepare_mini_batch(
    self, effective_group_to_raw_group_mapping: Dict[int, int]
) -> "tuple[Dict[str, torch.Tensor], Dict[str, list]]":
    """
    Build this data-parallel rank's training mini-batch from the receive buffer.

    Each ``self.buffer`` entry is indexed positionally here: slot 0 is the
    group's batch payload and slots 1-4 are its reward, format accuracy,
    answer accuracy and response length.

    Args:
        effective_group_to_raw_group_mapping: maps the i-th effective group
            to its index in ``self.buffer``. It must contain at least
            ``self.dp_size * self.minibatch_size`` entries (keys ``0..N-1``);
            otherwise a ``KeyError`` is raised by the lookups below.

    Returns:
        A 2-tuple ``(batch, raw_mini_batches_metric_dict)``:
        ``batch`` is the result of ``post_recv(bind_batch(...))`` over the
        ``self.minibatch_size`` effective groups assigned to this rank, and
        ``raw_mini_batches_metric_dict`` holds per-group metric lists for every
        raw buffer entry up to and including the last effective group consumed
        by this update. The buffer itself is not modified.
    """
    groups_per_rank = self.minibatch_size
    groups_per_update = self.dp_size * groups_per_rank
    first_group = self.dp_rank * groups_per_rank

    # on each dp_rank, we use minibatch_size effective samples to form a batch
    batches = [
        self.buffer[effective_group_to_raw_group_mapping[i]]
        for i in range(first_group, first_group + groups_per_rank)
    ]

    # every dp_rank will receive a complete mini-batch, no need to sync within step() later
    # each mini-batch use the first self.dp_size * minibatch_size effective samples
    raw_end = effective_group_to_raw_group_mapping[groups_per_update - 1] + 1  # include the last effective sample
    raw_mini_batches = self.buffer[:raw_end]
    raw_mini_batches_metric_dict = {
        "raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
        "raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
        "raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
        "raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
    }

    batch = post_recv(bind_batch([t[0] for t in batches]))
    return batch, raw_mini_batches_metric_dict
133+
def calculate_effective_group_to_raw_group_mapping(self, step: int) -> Dict[int, int]:
    """
    Map consecutive effective-group indices to their positions in ``self.buffer``.

    A buffer entry counts as effective when its first slot is not ``None``.
    When ``self.n_behind`` is non-zero, the entry's last slot (the step value
    stored with it) must additionally be ``<= step - self.n_behind``.

    Args:
        step: the current step index used for the ``n_behind`` comparison.

    Returns:
        A dict ``{0: idx_0, 1: idx_1, ...}`` listing, in buffer order, the
        buffer indices of all eligible groups. Empty when none qualify.
    """
    # NOTE(review): the stored stamp is compared against the current `step`
    # only; if `step` restarts for every episode, groups left over from a
    # previous episode may stay ineligible for a long time -- confirm intent.
    latest_allowed_step = None if self.n_behind == 0 else step - self.n_behind

    mapping = {}
    for buffer_idx, group in enumerate(self.buffer):
        if group[0] is None:
            # filtered-out group: kept in the buffer but never trained on
            continue
        if latest_allowed_step is None or group[-1] <= latest_allowed_step:
            mapping[len(mapping)] = buffer_idx
    return mapping
144+
104145 def loop (self ) -> None :
105146 print (
106147 f"Consumer{ self .rank } num_update: { self .num_update_per_episode } , num_recv: { self .num_recv_per_update } , nmb: { self .num_microbatches } "
@@ -112,14 +153,53 @@ def loop(self) -> None:
112153 disable = self .rank != 0 ,
113154 ) as pbar :
114155 for step in pbar :
156+ torch .cuda .reset_peak_memory_stats ()
115157 i = 0
158+
159+ self .profiler .enter (f"rollout_episode_{ episode } _step_{ step } " )
116160 for _ in range (self .num_recv_per_update ):
161+ if self .n_behind > 0 :
162+ # after sync model, do not wait for more data to arrive as rollout takes time, use buffered data
163+ effective_group_to_raw_group_mapping = self .calculate_effective_group_to_raw_group_mapping (
164+ step = step
165+ )
166+ while len (effective_group_to_raw_group_mapping ) >= self .dp_size * self .minibatch_size :
167+ self .profiler .log (
168+ f"Still have { len (effective_group_to_raw_group_mapping )} effective groups, greater than { self .dp_size * self .minibatch_size } , start training"
169+ )
170+ batch , raw_mini_batches_metric_dict = self .prepare_mini_batch (
171+ effective_group_to_raw_group_mapping
172+ )
173+ self .profiler .enter ("step" )
174+ loss = self .step (i , pbar , ** batch , ** raw_mini_batches_metric_dict )
175+ self .profiler .exit ("step" )
176+ self .buffer = self .buffer [
177+ effective_group_to_raw_group_mapping [self .dp_size * self .minibatch_size - 1 ] + 1 :
178+ ]
179+ # recalculate the effective group to raw group mapping
180+ effective_group_to_raw_group_mapping_size_before = len (
181+ effective_group_to_raw_group_mapping
182+ )
183+ effective_group_to_raw_group_mapping = (
184+ self .calculate_effective_group_to_raw_group_mapping (step = step )
185+ )
186+ assert (
187+ len (effective_group_to_raw_group_mapping )
188+ == effective_group_to_raw_group_mapping_size_before
189+ - self .dp_size * self .minibatch_size
190+ )
191+ if loss is not None :
192+ pbar .set_postfix ({"loss" : loss })
193+ i += 1
194+
117195 # receive data from producers
118196 for r in range (self .num_producers ):
119197 print (f"[T{ dist .get_rank ()} ] Recv data episode { episode } step { step } from { r } " )
198+ self .profiler .enter (f"recv_broadcast_data_P{ r } " )
120199 raw_batch = ray_broadcast_tensor_dict (
121200 None , src = 0 , device = self .device , group_name = f"sync_data_{ r } "
122201 )
202+ self .profiler .exit (f"recv_broadcast_data_P{ r } " )
123203 # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
124204 # we need to calculate the metrics before filtering here for logging
125205 # [batch_size, num_generations, ...] -> [batch_size * num_generations, ...]
@@ -153,63 +233,52 @@ def loop(self) -> None:
153233 format_acc [group_idx ],
154234 ans_acc [group_idx ],
155235 response_len [group_idx ],
236+ step ,
156237 ]
157238 )
158239 if effective_group_mask is not None :
159240 print (
160241 f"[T{ dist .get_rank ()} ] Filter recv data: { len (raw_batch )} -> { torch .sum (effective_group_mask ).cpu ().item ()} effective groups"
161242 )
162243 # mapping the effective group to the raw group for indexing
163- effective_group_to_raw_group_mapping = {}
164- for buffer_idx in range (len (self .buffer )):
165- if self .buffer [buffer_idx ][0 ] is not None :
166- effective_group_to_raw_group_mapping [len (effective_group_to_raw_group_mapping )] = (
167- buffer_idx
168- )
244+ effective_group_to_raw_group_mapping = self .calculate_effective_group_to_raw_group_mapping (
245+ step = step
246+ )
169247 print (
170248 f"[T{ dist .get_rank ()} ] Collect Effective Prompt: { len (effective_group_to_raw_group_mapping )} /{ self .dp_size * self .minibatch_size } "
171249 )
172250
173- while len (effective_group_to_raw_group_mapping ) >= self .dp_size * self .minibatch_size :
174- # on each dp_rank, we use minibatch_size effective samples to form a batch
175- batches = [
176- self .buffer [effective_group_to_raw_group_mapping [i ]]
177- for i in range (
178- self .dp_rank * self .minibatch_size , (self .dp_rank + 1 ) * self .minibatch_size
251+ if self .n_behind == 0 :
252+ # If n_behind is 0, we start training after receiving data from producers.
253+ while len (effective_group_to_raw_group_mapping ) >= self .dp_size * self .minibatch_size :
254+ self .profiler .log (
255+ f"Collect { len (effective_group_to_raw_group_mapping )} effective groups, greater than { self .dp_size * self .minibatch_size } , start training"
179256 )
180- ]
181- # every dp_rank will receive a complete mini-batch, no need to sync within step() later
182- # each mini-batch use the first self.dp_size * minibatch_size effective samples
183- raw_mini_batches = self .buffer [
184- : effective_group_to_raw_group_mapping [self .dp_size * self .minibatch_size - 1 ] + 1
185- ] # include the last effective sample
186- raw_mini_batches_metric_dict = {
187- "raw_train_mini_batch_reward" : [t [1 ] for t in raw_mini_batches ],
188- "raw_train_mini_batch_format_acc" : [t [2 ] for t in raw_mini_batches ],
189- "raw_train_mini_batch_ans_acc" : [t [3 ] for t in raw_mini_batches ],
190- "raw_train_mini_batch_response_len" : [t [4 ] for t in raw_mini_batches ],
191- }
192- batch = bind_batch ([t [0 ] for t in batches ])
193- batch = post_recv (batch )
194- loss = self .step (i , pbar , ** batch , ** raw_mini_batches_metric_dict )
195- self .buffer = self .buffer [
196- effective_group_to_raw_group_mapping [self .dp_size * self .minibatch_size - 1 ] + 1 :
197- ]
198- # recalculate the effective group to raw group mapping
199- effective_group_to_raw_group_mapping_size_before = len (effective_group_to_raw_group_mapping )
200- effective_group_to_raw_group_mapping = {}
201- for buffer_idx in range (len (self .buffer )):
202- if self .buffer [buffer_idx ][0 ] is not None :
203- effective_group_to_raw_group_mapping [len (effective_group_to_raw_group_mapping )] = (
204- buffer_idx
205- )
206- assert (
207- len (effective_group_to_raw_group_mapping )
208- == effective_group_to_raw_group_mapping_size_before - self .dp_size * self .minibatch_size
209- )
210- if loss is not None :
211- pbar .set_postfix ({"loss" : loss })
212- i += 1
257+ batch , raw_mini_batches_metric_dict = self .prepare_mini_batch (
258+ effective_group_to_raw_group_mapping
259+ )
260+ self .profiler .enter ("step" )
261+ loss = self .step (i , pbar , ** batch , ** raw_mini_batches_metric_dict )
262+ self .profiler .exit ("step" )
263+ self .buffer = self .buffer [
264+ effective_group_to_raw_group_mapping [self .dp_size * self .minibatch_size - 1 ] + 1 :
265+ ]
266+ # recalculate the effective group to raw group mapping
267+ effective_group_to_raw_group_mapping_size_before = len (
268+ effective_group_to_raw_group_mapping
269+ )
270+ effective_group_to_raw_group_mapping = (
271+ self .calculate_effective_group_to_raw_group_mapping (step = step )
272+ )
273+ assert (
274+ len (effective_group_to_raw_group_mapping )
275+ == effective_group_to_raw_group_mapping_size_before
276+ - self .dp_size * self .minibatch_size
277+ )
278+ if loss is not None :
279+ pbar .set_postfix ({"loss" : loss })
280+ i += 1
281+
213282 if self .lr_scheduler is not None :
214283 self .lr_scheduler .step ()
215284 if (step + 1 ) % self .save_interval == 0 or (step + 1 ) == self .num_update_per_episode :
@@ -220,13 +289,16 @@ def loop(self) -> None:
220289 if self .rank == 0 :
221290 print (f"Saved model checkpoint at step { step + 1 } in folder { save_path } " )
222291
223- if episode != self .num_episodes - 1 or step != self .num_update_per_episode - 1 :
292+ if (episode != self .num_episodes - 1 or step != self .num_update_per_episode - 1 ) and (
293+ episode != 0 or step >= self .n_behind
294+ ):
224295 if self .pp_size > 1 :
225296 print (
226297 f"[T{ dist .get_rank ()} ] Sync model PP stage { self .pp_rank } episode { episode } step { step } "
227298 )
228299 else :
229300 print (f"[T{ dist .get_rank ()} ] Sync model episode { episode } step { step } " )
301+ self .profiler .enter ("sync_model" )
230302 torch .cuda .empty_cache ()
231303 state_dict = self .state_dict ()
232304 if self .pp_size > 1 :
@@ -244,6 +316,13 @@ def loop(self) -> None:
244316 )
245317 del state_dict
246318 torch .cuda .empty_cache ()
319+ self .profiler .exit ("sync_model" )
320+ self .profiler .log (f"Peak memory usage: { torch .cuda .max_memory_allocated () / 1024 ** 2 :.2f} MB" )
321+ self .profiler .exit (f"rollout_episode_{ episode } _step_{ step } " )
322+
def __del__(self):
    """Close the profiler on finalisation, if one was ever attached.

    The attribute check covers instances that are destroyed before
    ``self.profiler`` has been assigned.
    """
    if not hasattr(self, "profiler"):
        return
    self.profiler.close()
247326
248327
249328@ray .remote
0 commit comments