@@ -117,26 +117,102 @@ def loop(self) -> None:
                         # receive data from producers
                         for r in range(self.num_producers):
                             print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
-                            self.buffer.extend(
-                                unbind_batch(
-                                    ray_broadcast_tensor_dict(
-                                        None, src=0, device=self.device, group_name=f"sync_data_{r}"
-                                    )
-                                )
+                            raw_batch = ray_broadcast_tensor_dict(
+                                None, src=0, device=self.device, group_name=f"sync_data_{r}"
                             )
-                        while len(self.buffer) >= self.dp_size * self.minibatch_size:
-                            batches = self.buffer[
-                                self.dp_rank * self.minibatch_size : (self.dp_rank + 1) * self.minibatch_size
+                            # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
+                            # we need to calculate the metrics before filtering here for logging
+                            # [batch_size, num_generations, ...] -> [batch_size * num_generations, ...]
+                            raw_batch_with_reward = self.calculate_reward(
+                                {k: v.view(-1, v.size(-1)) if k != "temperature" else v for k, v in raw_batch.items()}
+                            )
+                            raw_batch_with_reward = {
+                                k: v.view(-1, self.num_generations, v.size(-1)) if k != "temperature" else v
+                                for k, v in raw_batch_with_reward.items()
+                            }
+                            # [batch_size, num_generations] -> [batch_size]
+                            reward = raw_batch_with_reward["reward"][:, :, 0]
+                            format_acc = raw_batch_with_reward["format_acc"][:, :, 0]
+                            ans_acc = raw_batch_with_reward["ans_acc"][:, :, 0]
+                            response_len = (
+                                raw_batch_with_reward["response_idx"][:, :, 1]
+                                - raw_batch_with_reward["response_idx"][:, :, 0]
+                                + 1
+                            ).type(torch.float32)
+                            effective_group_mask = None
+                            if self.filter_range is not None and self.grpo_config.get("dynamic_batching", True):
+                                # filter the group based on the reward and accuracy
+                                group_ans_acc_mean = ans_acc.mean(dim=1)
+                                effective_group_mask = torch.logical_and(
+                                    group_ans_acc_mean > self.filter_range[0], group_ans_acc_mean < self.filter_range[1]
+                                )
+                            raw_batch_with_reward = unbind_batch(raw_batch_with_reward)  # List[Dict[str, torch.Tensor]]
+                            for group_idx, group_with_reward in enumerate(raw_batch_with_reward):
+                                self.buffer.append(
+                                    [
+                                        (
+                                            group_with_reward
+                                            if effective_group_mask is None or effective_group_mask[group_idx]
+                                            else None
+                                        ),
+                                        reward[group_idx],
+                                        format_acc[group_idx],
+                                        ans_acc[group_idx],
+                                        response_len[group_idx],
+                                    ]
+                                )
+                            if effective_group_mask is not None:
+                                print(
+                                    f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch_with_reward)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
+                                )
+                        # mapping the effective group to the raw group for indexing
+                        effective_group_to_raw_group_mapping = {}
+                        for buffer_idx in range(len(self.buffer)):
+                            if self.buffer[buffer_idx][0] is not None:
+                                effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
+                                    buffer_idx
+                                )
+                        print(
+                            f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"
+                        )
+
+                        while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
+                            # on each dp_rank, we use minibatch_size effective samples to form a batch
+                            batches = [
+                                self.buffer[effective_group_to_raw_group_mapping[i]]
+                                for i in range(
+                                    self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size
+                                )
                             ]
-                            batch = bind_batch(batches)
+                            # every dp_rank will receive a complete mini-batch, no need to sync within step() later
+                            # each mini-batch use the first self.dp_size * minibatch_size effective samples
+                            raw_mini_batches = self.buffer[
+                                : effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
+                            ]  # include the last effective sample
+                            raw_mini_batches_metric_dict = {
+                                "raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
+                                "raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
+                                "raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
+                                "raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
+                            }
+                            batch = bind_batch([t[0] for t in batches])
                             batch = post_recv(batch)
-                            loss, excessive_prompts_idx = self.step(i, pbar, **batch)
-
-                            if excessive_prompts_idx is not None:
-                                excessive_prompts = [self.buffer[idx] for idx in excessive_prompts_idx]
-                                self.buffer = excessive_prompts + self.buffer[self.dp_size * self.minibatch_size :]
-                            else:
-                                self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
+                            loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
+                            self.buffer = self.buffer[
+                                effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
+                            ]
+                            # recalculate the effective group to raw group mapping
+                            effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
+                            effective_group_to_raw_group_mapping = {}
+                            for buffer_idx in range(len(self.buffer)):
+                                if self.buffer[buffer_idx][0] is not None:
+                                    effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
+                                        buffer_idx
+                                    )
+                            assert (
+                                len(effective_group_to_raw_group_mapping)
+                                == effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
+                            )
                             if loss is not None:
                                 pbar.set_postfix({"loss": loss})
                             i += 1
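
The buffer bookkeeping above is the subtle part of this hunk: groups whose mean `ans_acc` falls outside `filter_range` stay in `self.buffer` as `None` placeholders (so their metrics can still be logged), and a training step only fires once `dp_size * minibatch_size` effective (non-`None`) groups have accumulated. Below is a minimal standalone sketch of that bookkeeping, written as an illustration rather than code from this PR; the helper names and the five-element buffer entry layout are assumptions that mirror the diff.

```python
# Illustrative sketch only (not part of this PR): dynamic-batching bookkeeping over a
# buffer of entries shaped [group_or_None, reward, format_acc, ans_acc, response_len].

def build_effective_mapping(buffer):
    """Map the i-th effective (non-None) group to its raw buffer index."""
    mapping = {}
    for raw_idx, entry in enumerate(buffer):
        if entry[0] is not None:
            mapping[len(mapping)] = raw_idx
    return mapping


def try_form_minibatch(buffer, dp_rank, dp_size, minibatch_size):
    """Return (batch, remaining_buffer); batch is None until enough effective groups exist."""
    mapping = build_effective_mapping(buffer)
    if len(mapping) < dp_size * minibatch_size:
        return None, buffer  # keep accumulating filtered groups
    # each dp rank takes its own contiguous slice of effective groups
    start = dp_rank * minibatch_size
    batch = [buffer[mapping[i]] for i in range(start, start + minibatch_size)]
    # drop every raw group up to and including the last effective one consumed by any rank
    last_raw_idx = mapping[dp_size * minibatch_size - 1]
    return batch, buffer[last_raw_idx + 1:]
```

This is also why the diff recomputes `effective_group_to_raw_group_mapping` after trimming the buffer and asserts that exactly `dp_size * minibatch_size` effective groups were consumed.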