@@ -117,32 +117,44 @@ def loop(self) -> None:
                         # receive data from producers
                         for r in range(self.num_producers):
                             print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
-                            raw_batch = ray_broadcast_tensor_dict(None, src=0, device=self.device, group_name=f"sync_data_{r}")
+                            raw_batch = ray_broadcast_tensor_dict(
+                                None, src=0, device=self.device, group_name=f"sync_data_{r}"
+                            )
                             # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
                             # we need to calculate the metrics before filtering here for logging
                             # [batch_size, num_generations, ...] -> [batch_size * num_generations, ...]
-                            raw_batch_with_reward = self.calculate_reward({k: v.view(-1, v.size(-1)) if k != 'temperature' else v for k, v in raw_batch.items()})
-                            raw_batch_with_reward = {k: v.view(-1, self.num_generations, v.size(-1)) if k != 'temperature' else v for k, v in raw_batch_with_reward.items()}
+                            raw_batch_with_reward = self.calculate_reward(
+                                {k: v.view(-1, v.size(-1)) if k != "temperature" else v for k, v in raw_batch.items()}
+                            )
+                            raw_batch_with_reward = {
+                                k: v.view(-1, self.num_generations, v.size(-1)) if k != "temperature" else v
+                                for k, v in raw_batch_with_reward.items()
+                            }
                             # [batch_size, num_generations] -> [batch_size]
-                            reward = raw_batch_with_reward["reward"][:,:, 0]
-                            format_acc = raw_batch_with_reward["format_acc"][:,:, 0]
-                            ans_acc = raw_batch_with_reward["ans_acc"][:,:, 0]
+                            reward = raw_batch_with_reward["reward"][:, :, 0]
+                            format_acc = raw_batch_with_reward["format_acc"][:, :, 0]
+                            ans_acc = raw_batch_with_reward["ans_acc"][:, :, 0]
                             response_len = (
-                                (raw_batch_with_reward["response_idx"][:, :, 1] - raw_batch_with_reward["response_idx"][:, :, 0] + 1)
-                                .type(torch.float32)
-                            )
+                                raw_batch_with_reward["response_idx"][:, :, 1]
+                                - raw_batch_with_reward["response_idx"][:, :, 0]
+                                + 1
+                            ).type(torch.float32)
                             effective_group_mask = None
                             if self.filter_range is not None and self.grpo_config.get("dynamic_batching", True):
                                 # filter the group based on the reward and accuracy
                                 group_ans_acc_mean = ans_acc.mean(dim=1)
                                 effective_group_mask = torch.logical_and(
                                     group_ans_acc_mean > self.filter_range[0], group_ans_acc_mean < self.filter_range[1]
                                 )
-                            raw_batch_with_reward = unbind_batch(raw_batch_with_reward) # List[Dict[str, torch.Tensor]]
+                            raw_batch_with_reward = unbind_batch(raw_batch_with_reward)  # List[Dict[str, torch.Tensor]]
                             for group_idx, group_with_reward in enumerate(raw_batch_with_reward):
                                 self.buffer.append(
                                     [
-                                        group_with_reward if effective_group_mask is None or effective_group_mask[group_idx] else None,
+                                        (
+                                            group_with_reward
+                                            if effective_group_mask is None or effective_group_mask[group_idx]
+                                            else None
+                                        ),
                                         reward[group_idx],
                                         format_acc[group_idx],
                                         ans_acc[group_idx],
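The dynamic-batching filter in this hunk can be illustrated in isolation. The sketch below is not the source file: the toy `ans_acc` tensor and the (0.01, 0.99) `filter_range` are assumptions, chosen to show how groups whose mean answer accuracy is all-correct or all-wrong are masked out (such groups carry no group-relative training signal).

# Standalone sketch of the group filter above; tensor values and the
# (0.01, 0.99) filter_range are illustrative assumptions, not the repo's defaults.
import torch

filter_range = (0.01, 0.99)

# per-sample answer accuracy, shape [num_groups, num_generations]
ans_acc = torch.tensor(
    [
        [1.0, 1.0, 1.0, 1.0],  # all correct -> filtered out
        [0.0, 0.0, 0.0, 0.0],  # all wrong   -> filtered out
        [1.0, 0.0, 1.0, 0.0],  # mixed       -> kept
    ]
)

group_ans_acc_mean = ans_acc.mean(dim=1)
effective_group_mask = torch.logical_and(
    group_ans_acc_mean > filter_range[0], group_ans_acc_mean < filter_range[1]
)
print(effective_group_mask)  # tensor([False, False,  True])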
@@ -160,7 +172,9 @@ def loop(self) -> None:
                                     effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
                                         buffer_idx
                                     )
-                            print(f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}")
+                            print(
+                                f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"
+                            )

                             while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
                                 # on each dp_rank, we use minibatch_size effective samples to form a batch
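For the second hunk, here is a minimal sketch of how the effective-to-raw-group mapping can be rebuilt from the buffer, assuming (as in the first hunk) that filtered-out groups are stored with None in slot 0 of each buffer entry; the example buffer contents are made up.

# Hypothetical buffer entries shaped like the ones appended above:
# [group_data_or_None, reward, format_acc, ans_acc]; the values are illustrative.
buffer = [
    ["group_a", 0.50, 1.0, 0.50],
    [None, 1.00, 1.0, 1.00],  # filtered out (all correct)
    ["group_c", 0.25, 1.0, 0.25],
    [None, 0.00, 0.0, 0.00],  # filtered out (all wrong)
]

effective_group_to_raw_group_mapping = {}
for buffer_idx in range(len(buffer)):
    if buffer[buffer_idx][0] is not None:
        effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = buffer_idx

print(effective_group_to_raw_group_mapping)  # {0: 0, 1: 2}
# Training only proceeds once len(effective_group_to_raw_group_mapping) reaches
# dp_size * minibatch_size, matching the while-loop condition in the hunk above.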