@@ -118,48 +118,49 @@ def loop(self) -> None:
                 for r in range(self.num_producers):
                     print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
                     raw_batch = ray_broadcast_tensor_dict(None, src=0, device=self.device, group_name=f"sync_data_{r}")
-                    recv_effective_count = 0
                     # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
                     # we need to calculate the metrics before filtering here for logging
-                    raw_batch_with_reward = unbind_batch(self.calculate_reward(raw_batch))
-                    for group_with_reward in raw_batch_with_reward:
-                        group_reward_mean = group_with_reward["reward"].mean().cpu().item()
-                        group_format_acc_mean = group_with_reward["format_acc"].mean().cpu().item()
-                        group_ans_acc_mean = group_with_reward["ans_acc"].mean().cpu().item()
-                        group_response_len = (
-                            (
-                                group_with_reward["response_idx"][:, 1]
-                                - group_with_reward["response_idx"][:, 0]
-                                + 1
-                            )
-                            .type(torch.float32)
-                            .mean()
-                            .cpu()
-                            .item()
+                    # [batch_size, num_generations, ...] -> [batch_size * num_generations, ...]
+                    raw_batch_with_reward = self.calculate_reward({k: v.view(-1, v.size(-1)) if k != 'temperature' else v for k, v in raw_batch.items()})
+                    raw_batch_with_reward = {k: v.view(-1, self.num_generations, v.size(-1)) if k != 'temperature' else v for k, v in raw_batch_with_reward.items()}
+                    # [batch_size, num_generations] -> [batch_size]
+                    group_reward_mean = raw_batch_with_reward["reward"][:, :, 0].mean(dim=-1)
+                    group_format_acc_mean = raw_batch_with_reward["format_acc"][:, :, 0].mean(dim=-1)
+                    group_ans_acc_mean = raw_batch_with_reward["ans_acc"][:, :, 0].mean(dim=-1)
+                    group_response_len = (
+                        (raw_batch_with_reward["response_idx"][:, :, 1] - raw_batch_with_reward["response_idx"][:, :, 0] + 1)
+                        .type(torch.float32)
+                        .mean(dim=-1)
+                    )
+                    effective_group_mask = None
+                    if self.filter_range is not None and self.grpo_config.get("dynamic_batching", True):
+                        # filter the group based on the reward and accuracy
+                        effective_group_mask = torch.logical_and(
+                            group_ans_acc_mean > self.filter_range[0], group_ans_acc_mean < self.filter_range[1]
                         )
-                        if self.grpo_config.get("dynamic_batching", True):
-                            filtered_group = self.prompt_level_filtering(group_with_reward)
-                            recv_effective_count += 1 if filtered_group is not None else 0
+                    raw_batch_with_reward = unbind_batch(raw_batch_with_reward)  # List[Dict[str, torch.Tensor]]
+                    for group_idx, group_with_reward in enumerate(raw_batch_with_reward):
                         self.buffer.append(
                             [
-                                filtered_group,
-                                group_reward_mean,
-                                group_format_acc_mean,
-                                group_ans_acc_mean,
-                                group_response_len,
+                                group_with_reward if effective_group_mask is None or effective_group_mask[group_idx] else None,
+                                group_reward_mean[group_idx],
+                                group_format_acc_mean[group_idx],
+                                group_ans_acc_mean[group_idx],
+                                group_response_len[group_idx],
                             ]
                         )
-                    if self.filter_range is not None:
+                    if effective_group_mask is not None:
                         print(
-                            f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {recv_effective_count}"
+                            f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
                         )
-                # mapping the effective group to the raw group for indexing
-                effective_group_to_raw_group_mapping = {}
-                for buffer_idx in range(len(self.buffer)):
-                    if self.buffer[buffer_idx][0] is not None:
-                        effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
-                            buffer_idx
-                        )
+                    # mapping the effective group to the raw group for indexing
+                    effective_group_to_raw_group_mapping = {}
+                    for buffer_idx in range(len(self.buffer)):
+                        if self.buffer[buffer_idx][0] is not None:
+                            effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
+                                buffer_idx
+                            )
+                    pbar.set_postfix({"Collect Effective Prompt": f"{len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"})

                 while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
                     # on each dp_rank, we use minibatch_size effective samples to form a batch
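Below is a minimal, self-contained sketch of the group-level filtering that the added lines implement, assuming (as in the new code) that per-sample answer accuracies arrive flattened as [batch_size * num_generations] and that a group is kept only when its mean accuracy falls strictly inside filter_range. The helper name build_effective_group_mask and the toy tensors are illustrative, not part of the repository.

import torch

def build_effective_group_mask(ans_acc, num_generations, filter_range):
    """ans_acc: flat per-sample accuracies, shape [batch_size * num_generations].
    filter_range: (low, high); a group is effective when low < mean accuracy < high.
    Returns a bool mask of shape [batch_size]."""
    # [batch_size * num_generations] -> [batch_size, num_generations]
    grouped = ans_acc.view(-1, num_generations)
    # [batch_size, num_generations] -> [batch_size]
    group_ans_acc_mean = grouped.mean(dim=-1)
    # keep only groups whose mean accuracy lies strictly inside the range
    return torch.logical_and(
        group_ans_acc_mean > filter_range[0], group_ans_acc_mean < filter_range[1]
    )

# 3 prompts x 4 generations: all-wrong, mixed, all-correct
acc = torch.tensor([0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1], dtype=torch.float32)
mask = build_effective_group_mask(acc, num_generations=4, filter_range=(0.01, 0.99))
print(mask)  # tensor([False,  True, False]) -- only the mixed group is effective

The usual motivation for this kind of dynamic-batching filter is that a group whose samples are all correct or all incorrect yields a (near-)zero GRPO advantage, so keeping it in the buffer adds no learning signal.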