@@ -117,32 +117,44 @@ def loop(self) -> None:
                         # receive data from producers
                         for r in range(self.num_producers):
                             print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
-                            raw_batch = ray_broadcast_tensor_dict(None, src=0, device=self.device, group_name=f"sync_data_{r}")
+                            raw_batch = ray_broadcast_tensor_dict(
+                                None, src=0, device=self.device, group_name=f"sync_data_{r}"
+                            )
                             # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
                             # we need to calculate the metrics before filtering here for logging
                             # [batch_size, num_generations, ...] -> [batch_size * num_generations, ...]
-                            raw_batch_with_reward = self.calculate_reward({k:v.view(-1, v.size(-1)) if k != 'temperature' else v for k, v in raw_batch.items()})
-                            raw_batch_with_reward = {k: v.view(-1, self.num_generations, v.size(-1)) if k != 'temperature' else v for k, v in raw_batch_with_reward.items()}
+                            raw_batch_with_reward = self.calculate_reward(
+                                {k: v.view(-1, v.size(-1)) if k != "temperature" else v for k, v in raw_batch.items()}
+                            )
+                            raw_batch_with_reward = {
+                                k: v.view(-1, self.num_generations, v.size(-1)) if k != "temperature" else v
+                                for k, v in raw_batch_with_reward.items()
+                            }
                             # [batch_size, num_generations] -> [batch_size]
-                            reward = raw_batch_with_reward["reward"][:,:, 0]
-                            format_acc = raw_batch_with_reward["format_acc"][:,:, 0]
-                            ans_acc = raw_batch_with_reward["ans_acc"][:,:, 0]
+                            reward = raw_batch_with_reward["reward"][:, :, 0]
+                            format_acc = raw_batch_with_reward["format_acc"][:, :, 0]
+                            ans_acc = raw_batch_with_reward["ans_acc"][:, :, 0]
                             response_len = (
-                                (raw_batch_with_reward["response_idx"][:, :, 1] - raw_batch_with_reward["response_idx"][:, :, 0] + 1)
-                                .type(torch.float32)
-                            )
+                                raw_batch_with_reward["response_idx"][:, :, 1]
+                                - raw_batch_with_reward["response_idx"][:, :, 0]
+                                + 1
+                            ).type(torch.float32)
                             effective_group_mask = None
                             if self.filter_range is not None and self.grpo_config.get("dynamic_batching", True):
                                 # filter the group based on the reward and accuracy
                                 group_ans_acc_mean = ans_acc.mean(dim=1)
                                 effective_group_mask = torch.logical_and(
                                     group_ans_acc_mean > self.filter_range[0], group_ans_acc_mean < self.filter_range[1]
                                 )
-                            raw_batch_with_reward = unbind_batch(raw_batch_with_reward) # List[Dict[str, torch.Tensor]]
+                            raw_batch_with_reward = unbind_batch(raw_batch_with_reward)  # List[Dict[str, torch.Tensor]]
                             for group_idx, group_with_reward in enumerate(raw_batch_with_reward):
                                 self.buffer.append(
                                     [
-                                        group_with_reward if effective_group_mask is None or effective_group_mask[group_idx] else None,
+                                        (
+                                            group_with_reward
+                                            if effective_group_mask is None or effective_group_mask[group_idx]
+                                            else None
+                                        ),
                                         reward[group_idx],
                                         format_acc[group_idx],
                                         ans_acc[group_idx],
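The hunk above reformats the dynamic-batching filter. Below is a minimal, self-contained sketch of that filtering step: per-prompt groups of `num_generations` responses are kept only when their mean answer accuracy falls strictly inside `filter_range`, typically to drop groups that are entirely correct or entirely wrong and thus carry little signal under group-relative advantages. The shapes and the `(0.1, 0.9)` bounds are illustrative assumptions, not values from this commit.

```python
import torch

# Illustrative shapes and bounds (assumed, not taken from this commit).
batch_size, num_generations = 8, 4
filter_range = (0.1, 0.9)

# Per-response answer accuracy, shaped [batch_size, num_generations].
ans_acc = torch.randint(0, 2, (batch_size, num_generations)).float()

# Mean accuracy of each prompt group, shaped [batch_size].
group_ans_acc_mean = ans_acc.mean(dim=1)

# Keep only groups whose mean accuracy lies strictly inside the filter range,
# mirroring the torch.logical_and(...) mask in the hunk above.
effective_group_mask = torch.logical_and(
    group_ans_acc_mean > filter_range[0], group_ans_acc_mean < filter_range[1]
)

print(f"effective groups: {int(effective_group_mask.sum())}/{batch_size}")
```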
@@ -160,7 +172,9 @@ def loop(self) -> None:
                                     effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
                                         buffer_idx
                                     )
-                            print(f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}")
+                            print(
+                                f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"
+                            )
 
                         while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
                             # on each dp_rank, we use minibatch_size effective samples to form a batch
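For context on the second hunk: filtered-out groups stay in `self.buffer` with a `None` payload, and `effective_group_to_raw_group_mapping` assigns consecutive indices only to the effective (non-`None`) entries so they can be looked up by their raw buffer position; training proceeds once at least `dp_size * minibatch_size` effective groups have accumulated. The helper name and toy buffer below are hypothetical, used only to illustrate the bookkeeping.

```python
from typing import Any, Dict, List, Optional

def build_effective_mapping(buffer: List[List[Optional[Any]]]) -> Dict[int, int]:
    """Map dense effective-group indices to raw buffer indices; slot 0 is None for filtered groups."""
    mapping: Dict[int, int] = {}
    for buffer_idx in range(len(buffer)):
        if buffer[buffer_idx][0] is not None:
            mapping[len(mapping)] = buffer_idx
    return mapping

# Toy buffer: entries 0 and 2 were filtered out, entries 1 and 3 are effective.
buffer = [[None, 0.0], [{"prompt": 1}, 1.0], [None, 0.5], [{"prompt": 3}, 0.8]]
mapping = build_effective_mapping(buffer)
print(mapping)  # {0: 1, 1: 3}

# A training step can start once enough effective prompts are available
# (dp_size and minibatch_size are assumed values here).
dp_size, minibatch_size = 2, 1
print(len(mapping) >= dp_size * minibatch_size)  # True
```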