@@ -117,36 +117,47 @@ def loop(self) -> None:
117
117
# receive data from producers
118
118
for r in range (self .num_producers ):
119
119
print (f"[T{ dist .get_rank ()} ] Recv data episode { episode } step { step } from { r } " )
120
- raw_batch = ray_broadcast_tensor_dict (None , src = 0 , device = self .device , group_name = f"sync_data_{ r } " )
120
+ raw_batch = ray_broadcast_tensor_dict (
121
+ None , src = 0 , device = self .device , group_name = f"sync_data_{ r } "
122
+ )
121
123
# calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
122
124
# we need to calculate the metrics before filtering here for logging
123
125
# [batch_size, num_generations, ...] -> [batch_size * num_generations, ...]
124
- raw_batch_with_reward = self .calculate_reward ({k :v .view (- 1 , v .size (- 1 )) if k != 'temperature' else v for k , v in raw_batch .items ()})
125
- raw_batch_with_reward = {k : v .view (- 1 , self .num_generations , v .size (- 1 )) if k != 'temperature' else v for k , v in raw_batch_with_reward .items ()}
126
- # [batch_size, num_generations] -> [batch_size]
127
- group_reward_mean = raw_batch_with_reward ["reward" ][:,:,0 ].mean (dim = - 1 )
128
- group_format_acc_mean = raw_batch_with_reward ["format_acc" ][:,:,0 ].mean (dim = - 1 )
129
- group_ans_acc_mean = raw_batch_with_reward ["ans_acc" ][:,:,0 ].mean (dim = - 1 )
130
- group_response_len = (
131
- (raw_batch_with_reward ["response_idx" ][:, :, 1 ] - raw_batch_with_reward ["response_idx" ][:, :, 0 ] + 1 )
132
- .type (torch .float32 )
133
- .mean (dim = - 1 )
126
+ raw_batch_with_reward = self .calculate_reward (
127
+ {k : v .view (- 1 , v .size (- 1 )) if k != "temperature" else v for k , v in raw_batch .items ()}
134
128
)
129
+ raw_batch_with_reward = {
130
+ k : v .view (- 1 , self .num_generations , v .size (- 1 )) if k != "temperature" else v
131
+ for k , v in raw_batch_with_reward .items ()
132
+ }
133
+ # [batch_size, num_generations] -> [batch_size]
134
+ reward = raw_batch_with_reward ["reward" ][:, :, 0 ]
135
+ format_acc = raw_batch_with_reward ["format_acc" ][:, :, 0 ]
136
+ ans_acc = raw_batch_with_reward ["ans_acc" ][:, :, 0 ]
137
+ response_len = (
138
+ raw_batch_with_reward ["response_idx" ][:, :, 1 ]
139
+ - raw_batch_with_reward ["response_idx" ][:, :, 0 ]
140
+ + 1
141
+ ).type (torch .float32 )
135
142
effective_group_mask = None
136
143
if self .filter_range is not None and self .grpo_config .get ("dynamic_batching" , True ):
137
144
# filter the group based on the reward and accuracy
138
145
effective_group_mask = torch .logical_and (
139
146
group_ans_acc_mean > self .filter_range [0 ], group_ans_acc_mean < self .filter_range [1 ]
140
147
)
141
- raw_batch_with_reward = unbind_batch (raw_batch_with_reward ) # List[Dict[str, torch.Tensor]]
148
+ raw_batch_with_reward = unbind_batch (raw_batch_with_reward ) # List[Dict[str, torch.Tensor]]
142
149
for group_idx , group_with_reward in enumerate (raw_batch_with_reward ):
143
150
self .buffer .append (
144
151
[
145
- group_with_reward if effective_group_mask is None or effective_group_mask [group_idx ] else None ,
146
- group_reward_mean [group_idx ],
147
- group_format_acc_mean [group_idx ],
148
- group_ans_acc_mean [group_idx ],
149
- group_response_len [group_idx ],
152
+ (
153
+ group_with_reward
154
+ if effective_group_mask is None or effective_group_mask [group_idx ]
155
+ else None
156
+ ),
157
+ reward [group_idx ],
158
+ format_acc [group_idx ],
159
+ ans_acc [group_idx ],
160
+ response_len [group_idx ],
150
161
]
151
162
)
152
163
if effective_group_mask is not None :
@@ -160,7 +171,9 @@ def loop(self) -> None:
160
171
effective_group_to_raw_group_mapping [len (effective_group_to_raw_group_mapping )] = (
161
172
buffer_idx
162
173
)
163
- pbar .set_postfix ({"Collect Effective Prompt" : f"{ len (effective_group_to_raw_group_mapping )} /{ self .dp_size * self .minibatch_size } " })
174
+ print (
175
+ f"[T{ dist .get_rank ()} ] Collect Effective Prompt: { len (effective_group_to_raw_group_mapping )} /{ self .dp_size * self .minibatch_size } "
176
+ )
164
177
165
178
while len (effective_group_to_raw_group_mapping ) >= self .dp_size * self .minibatch_size :
166
179
# on each dp_rank, we use minibatch_size effective samples to form a batch
0 commit comments