@@ -117,26 +117,102 @@ def loop(self) -> None:
                         # receive data from producers
                         for r in range(self.num_producers):
                             print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
-                            self.buffer.extend(
-                                unbind_batch(
-                                    ray_broadcast_tensor_dict(
-                                        None, src=0, device=self.device, group_name=f"sync_data_{r}"
-                                    )
-                                )
+                            raw_batch = ray_broadcast_tensor_dict(
+                                None, src=0, device=self.device, group_name=f"sync_data_{r}"
                             )
-                        while len(self.buffer) >= self.dp_size * self.minibatch_size:
-                            batches = self.buffer[
-                                self.dp_rank * self.minibatch_size : (self.dp_rank + 1) * self.minibatch_size
+                            # calculate group reward and related metrics, then apply filtering. As only the filtered groups
+                            # will be used for training (an incomplete subset), the metrics must be computed here for logging
+                            # [batch_size, num_generations, ...] -> [batch_size * num_generations, ...]
+                            raw_batch_with_reward = self.calculate_reward(
+                                {k: v.view(-1, v.size(-1)) if k != "temperature" else v for k, v in raw_batch.items()}
+                            )
+                            raw_batch_with_reward = {
+                                k: v.view(-1, self.num_generations, v.size(-1)) if k != "temperature" else v
+                                for k, v in raw_batch_with_reward.items()
+                            }
+                            # reward, format_acc, ans_acc, response_len all have shape [batch_size, num_generations]
+                            reward = raw_batch_with_reward["reward"][:, :, 0]
+                            format_acc = raw_batch_with_reward["format_acc"][:, :, 0]
+                            ans_acc = raw_batch_with_reward["ans_acc"][:, :, 0]
+                            response_len = (
+                                raw_batch_with_reward["response_idx"][:, :, 1]
+                                - raw_batch_with_reward["response_idx"][:, :, 0]
+                                + 1
+                            ).type(torch.float32)
+                            effective_group_mask = None
+                            if self.filter_range is not None and self.grpo_config.get("dynamic_batching", True):
+                                # keep only groups whose mean answer accuracy lies inside filter_range
+                                group_ans_acc_mean = ans_acc.mean(dim=1)
+                                effective_group_mask = torch.logical_and(
+                                    group_ans_acc_mean > self.filter_range[0], group_ans_acc_mean < self.filter_range[1]
+                                )
+                            raw_batch_with_reward = unbind_batch(raw_batch_with_reward)  # List[Dict[str, torch.Tensor]]
+                            for group_idx, group_with_reward in enumerate(raw_batch_with_reward):
+                                self.buffer.append(
+                                    [
+                                        (
+                                            group_with_reward
+                                            if effective_group_mask is None or effective_group_mask[group_idx]
+                                            else None
+                                        ),
+                                        reward[group_idx],
+                                        format_acc[group_idx],
+                                        ans_acc[group_idx],
+                                        response_len[group_idx],
+                                    ]
+                                )
+                            if effective_group_mask is not None:
+                                print(
+                                    f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch_with_reward)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
+                                )
+                            # mapping the effective group to the raw group for indexing
+                            effective_group_to_raw_group_mapping = {}
+                            for buffer_idx in range(len(self.buffer)):
+                                if self.buffer[buffer_idx][0] is not None:
+                                    effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
+                                        buffer_idx
+                                    )
+                            print(
+                                f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"
+                            )
+
+                        while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
+                            # on each dp_rank, we use minibatch_size effective samples to form a batch
+                            batches = [
+                                self.buffer[effective_group_to_raw_group_mapping[i]]
+                                for i in range(
+                                    self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size
+                                )
                             ]
-                            batch = bind_batch(batches)
+                            # every dp_rank will receive a complete mini-batch, no need to sync within step() later
+                            # each mini-batch uses the first self.dp_size * self.minibatch_size effective samples
+                            raw_mini_batches = self.buffer[
+                                : effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
+                            ]  # include the last effective sample
+                            raw_mini_batches_metric_dict = {
+                                "raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
+                                "raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
+                                "raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
+                                "raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
+                            }
+                            batch = bind_batch([t[0] for t in batches])
                             batch = post_recv(batch)
-                            loss, excessive_prompts_idx = self.step(i, pbar, **batch)
-
-                            if excessive_prompts_idx is not None:
-                                excessive_prompts = [self.buffer[idx] for idx in excessive_prompts_idx]
-                                self.buffer = excessive_prompts + self.buffer[self.dp_size * self.minibatch_size :]
-                            else:
-                                self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
+                            loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
+                            self.buffer = self.buffer[
+                                effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
+                            ]
+                            # recalculate the effective group to raw group mapping
+                            effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
+                            effective_group_to_raw_group_mapping = {}
+                            for buffer_idx in range(len(self.buffer)):
+                                if self.buffer[buffer_idx][0] is not None:
+                                    effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
+                                        buffer_idx
+                                    )
+                            assert (
+                                len(effective_group_to_raw_group_mapping)
+                                == effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
+                            )
                             if loss is not None:
                                 pbar.set_postfix({"loss": loss})
                             i += 1
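
The heart of this change is the bookkeeping around filtered groups: a group whose mean answer accuracy falls outside filter_range stays in self.buffer with a None payload (so its metrics can still be logged), and effective_group_to_raw_group_mapping maps a dense "effective" index to the raw buffer position, letting each DP rank slice out minibatch_size usable groups and letting the consumed prefix be dropped afterwards. The standalone sketch below illustrates that bookkeeping only; the helper names (keep_group, build_mapping), the toy string payloads, and the 0.1/0.9 thresholds are illustrative assumptions, not the repo's API.

    def keep_group(ans_acc_per_generation, lo=0.1, hi=0.9):
        # a group is "effective" only if its mean answer accuracy lies strictly inside (lo, hi),
        # mirroring the torch.logical_and(mean > lo, mean < hi) check in the diff
        mean = sum(ans_acc_per_generation) / len(ans_acc_per_generation)
        return lo < mean < hi

    def build_mapping(buffer):
        # dense effective index -> raw buffer index; filtered groups stay in the
        # buffer with a None payload so their metrics can still be logged
        mapping = {}
        for raw_idx, entry in enumerate(buffer):
            if entry[0] is not None:
                mapping[len(mapping)] = raw_idx
        return mapping

    if __name__ == "__main__":
        dp_size, dp_rank, minibatch_size = 2, 0, 2
        groups = [("g0", [1.0, 0.0]), ("g1", [1.0, 1.0]), ("g2", [0.0, 0.5]),
                  ("g3", [0.0, 0.0]), ("g4", [0.5, 0.5]), ("g5", [1.0, 0.0])]
        # buffer entry: [payload-or-None, mean ans_acc]; metrics are kept even when the payload is filtered
        buffer = [[name if keep_group(accs) else None, sum(accs) / len(accs)] for name, accs in groups]

        mapping = build_mapping(buffer)
        while len(mapping) >= dp_size * minibatch_size:
            # this rank's slice of effective groups (what bind_batch would receive)
            mine = [buffer[mapping[i]][0]
                    for i in range(dp_rank * minibatch_size, (dp_rank + 1) * minibatch_size)]
            print(f"rank {dp_rank} trains on {mine}")
            # drop everything up to and including the last effective group consumed by any rank
            last_raw = mapping[dp_size * minibatch_size - 1]
            size_before = len(mapping)
            buffer = buffer[last_raw + 1:]
            mapping = build_mapping(buffer)
            # same invariant as the assert in the diff
            assert len(mapping) == size_before - dp_size * minibatch_size

The assert mirrors the invariant checked in the diff: every pass through the while loop must consume exactly dp_size * minibatch_size effective groups, regardless of how many filtered (None) entries sit between them.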