@@ -118,48 +118,49 @@ def loop(self) -> None:
                         for r in range(self.num_producers):
                             print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
                             raw_batch = ray_broadcast_tensor_dict(None, src=0, device=self.device, group_name=f"sync_data_{r}")
-                            recv_effective_count = 0
                             # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
                             # we need to calculate the metrics before filtering here for logging
-                            raw_batch_with_reward = unbind_batch(self.calculate_reward(raw_batch))
-                            for group_with_reward in raw_batch_with_reward:
-                                group_reward_mean = group_with_reward["reward"].mean().cpu().item()
-                                group_format_acc_mean = group_with_reward["format_acc"].mean().cpu().item()
-                                group_ans_acc_mean = group_with_reward["ans_acc"].mean().cpu().item()
-                                group_response_len = (
-                                    (
-                                        group_with_reward["response_idx"][:, 1]
-                                        - group_with_reward["response_idx"][:, 0]
-                                        + 1
-                                    )
-                                    .type(torch.float32)
-                                    .mean()
-                                    .cpu()
-                                    .item()
+                            # [batch_size, num_generations, ...] -> [batch_size * num_generations, ...]
+                            raw_batch_with_reward = self.calculate_reward({k: v.view(-1, v.size(-1)) if k != "temperature" else v for k, v in raw_batch.items()})
+                            raw_batch_with_reward = {k: v.view(-1, self.num_generations, v.size(-1)) if k != "temperature" else v for k, v in raw_batch_with_reward.items()}
+                            # [batch_size, num_generations] -> [batch_size]
+                            group_reward_mean = raw_batch_with_reward["reward"][:, :, 0].mean(dim=-1)
+                            group_format_acc_mean = raw_batch_with_reward["format_acc"][:, :, 0].mean(dim=-1)
+                            group_ans_acc_mean = raw_batch_with_reward["ans_acc"][:, :, 0].mean(dim=-1)
+                            group_response_len = (
+                                (raw_batch_with_reward["response_idx"][:, :, 1] - raw_batch_with_reward["response_idx"][:, :, 0] + 1)
+                                .type(torch.float32)
+                                .mean(dim=-1)
+                            )
+                            effective_group_mask = None
+                            if self.filter_range is not None and self.grpo_config.get("dynamic_batching", True):
+                                # filter the group based on the reward and accuracy
+                                effective_group_mask = torch.logical_and(
+                                    group_ans_acc_mean > self.filter_range[0], group_ans_acc_mean < self.filter_range[1]
                                 )
-                                if self.grpo_config.get("dynamic_batching", True):
-                                    filtered_group = self.prompt_level_filtering(group_with_reward)
-                                    recv_effective_count += 1 if filtered_group is not None else 0
+                            raw_batch_with_reward = unbind_batch(raw_batch_with_reward)  # List[Dict[str, torch.Tensor]]
+                            for group_idx, group_with_reward in enumerate(raw_batch_with_reward):
                                 self.buffer.append(
                                     [
-                                        filtered_group,
-                                        group_reward_mean,
-                                        group_format_acc_mean,
-                                        group_ans_acc_mean,
-                                        group_response_len,
+                                        group_with_reward if effective_group_mask is None or effective_group_mask[group_idx] else None,
+                                        group_reward_mean[group_idx],
+                                        group_format_acc_mean[group_idx],
+                                        group_ans_acc_mean[group_idx],
+                                        group_response_len[group_idx],
                                     ]
                                 )
-                            if self.filter_range is not None:
+                            if effective_group_mask is not None:
                                 print(
-                                    f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {recv_effective_count}"
+                                    f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
                                 )
-                            # mapping the effective group to the raw group for indexing
-                            effective_group_to_raw_group_mapping = {}
-                            for buffer_idx in range(len(self.buffer)):
-                                if self.buffer[buffer_idx][0] is not None:
-                                    effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
-                                        buffer_idx
-                                    )
+                        # mapping the effective group to the raw group for indexing
+                        effective_group_to_raw_group_mapping = {}
+                        for buffer_idx in range(len(self.buffer)):
+                            if self.buffer[buffer_idx][0] is not None:
+                                effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
+                                    buffer_idx
+                                )
+                        pbar.set_postfix({"Collect Effective Prompt": f"{len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"})

                         while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
                             # on each dp_rank, we use minibatch_size effective samples to form a batch
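For readers following the new dynamic-batching path, here is a minimal standalone sketch of what the added lines do: compute one mean answer accuracy per prompt group, build `effective_group_mask`, store filtered-out groups as `None` in the buffer, and map contiguous effective indices back to raw buffer positions. The shapes and the `filter_range` semantics follow the diff; the random accuracy values and the simplified buffer entries are assumptions for illustration only, not the repository's reward code.

```python
import torch

batch_size, num_generations = 4, 8
filter_range = (0.1, 0.9)  # keep groups whose mean answer accuracy lies strictly inside this range

# stand-in for per-sample answer accuracy, already reshaped to
# [batch_size, num_generations, 1] as in the diff; real values come from the reward calculation
ans_acc = (torch.rand(batch_size, num_generations, 1) > 0.5).float()

# [batch_size, num_generations] -> [batch_size]: one mean accuracy per prompt group
group_ans_acc_mean = ans_acc[:, :, 0].mean(dim=-1)

# a group is "effective" only if its accuracy is neither too low nor too high
effective_group_mask = torch.logical_and(
    group_ans_acc_mean > filter_range[0], group_ans_acc_mean < filter_range[1]
)

# the buffer keeps every group (so metrics can still be logged), but groups that
# fail the filter are stored as None; here an int stands in for the group dict
buffer = [
    (group_idx if effective_group_mask[group_idx] else None, group_ans_acc_mean[group_idx])
    for group_idx in range(batch_size)
]

# map contiguous "effective" indices to raw buffer positions, mirroring the diff
effective_group_to_raw_group_mapping = {}
for buffer_idx in range(len(buffer)):
    if buffer[buffer_idx][0] is not None:
        effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = buffer_idx

print(f"Filter recv data: {batch_size} -> {torch.sum(effective_group_mask).cpu().item()} effective groups")
print("effective -> raw:", effective_group_to_raw_group_mapping)
```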