@@ -68,6 +68,7 @@ def __init__(
         self.accum_response_length = torch.zeros(1, device=self.device)
         self.accum_count = 0
         self.generate_config = generate_config
+        self.training_config = training_config

         # Reference model is initialized from policy model.
         self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
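The stored training_config is read again in step() below, where train_microbatch_size caps how many sequences go through each forward pass. A minimal sketch of that slicing pattern with illustrative names (iter_micro_batches is not part of this PR):

def iter_micro_batches(batch: dict, micro_batch_size: int):
    # batch maps field names (e.g. "input_ids", "attention_mask") to tensors
    # sharing the same leading batch dimension; yield contiguous slices so each
    # forward pass sees at most micro_batch_size sequences.
    total = batch["input_ids"].size(0)
    for start in range(0, total, micro_batch_size):
        yield {k: v[start : start + micro_batch_size] for k, v in batch.items()}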
@@ -131,40 +132,16 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
         num_action = action_mask.shape[1]
         old_action_log_probs = data["action_log_probs"]
         response_length = torch.sum(action_mask, dim=1).to(torch.float32)
+        forward_batch_size = self.training_config.get("train_microbatch_size", data["input_ids"].size(0))

         need_update = (step_idx + 1) % self.num_microbatches == 0
-        ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
+        # Gradient must be synchronized if zero2 is enabled. https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500
+        ctx = (
+            nullcontext()
+            if need_update or self.booster.plugin.zero_stage == 2
+            else self.booster.no_sync(self.policy_model, self.optimizer)
+        )
         with ctx:
-            policy_model_logits = self.policy_model(
-                input_ids=data["input_ids"],
-                attention_mask=data["attention_mask"],
-            )["logits"]
-            action_log_probs = calc_action_log_probs(
-                policy_model_logits / self.generate_config["temperature"],
-                data["input_ids"],
-                num_action,
-                self.plugin.shard_config,
-            )
-
-            with torch.no_grad():
-                reference_model_logits = self.reference_model(
-                    input_ids=data["input_ids"],
-                    attention_mask=data["attention_mask"],
-                )["logits"]
-            reference_action_log_probs = calc_action_log_probs(
-                reference_model_logits / self.generate_config["temperature"],
-                data["input_ids"],
-                num_action,
-                self.plugin.shard_config,
-            )
-
-            per_token_kl = (
-                torch.exp(reference_action_log_probs - action_log_probs)
-                - (reference_action_log_probs - action_log_probs)
-                - 1
-            )
-            kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1)
-
             reward_group = self.reward_model(
                 data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
             )
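The rewritten ctx above changes when gradients are synchronized: with plain gradient accumulation the all-reduce can be skipped until the micro-step that triggers the optimizer update, but ZeRO stage 2 shards gradients and must reduce them on every backward pass, hence the extra zero_stage == 2 check. A standalone sketch of the selection logic, assuming a no_sync_ctx callable that stands in for self.booster.no_sync:

from contextlib import nullcontext

def grad_sync_context(step_idx: int, num_microbatches: int, zero_stage: int, no_sync_ctx):
    # Sync on the micro-step that performs the optimizer update, or always
    # under ZeRO-2 (gradient shards must be reduced on every backward);
    # otherwise skip the all-reduce for this accumulation step.
    need_update = (step_idx + 1) % num_microbatches == 0
    return nullcontext() if need_update or zero_stage == 2 else no_sync_ctx()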
@@ -177,6 +154,11 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:

             group_reward = reward.view(-1, self.num_generations)
             reward_mean = group_reward.mean(dim=1)
+            # [batch_size x num_generations]
+            reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
+            reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
+            # [batch_size x num_generations]
+            advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
             # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
             loss_mask = (
                 None
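The five added lines above compute GRPO's group-relative advantages: rewards are viewed as [batch_size, num_generations], normalized by the per-prompt mean and standard deviation, and broadcast back so each sampled response carries one scalar advantage. The same arithmetic in isolation (a sketch; the function name and the flat, prompt-major reward layout are assumptions):

import torch

def group_relative_advantages(reward: torch.Tensor, num_generations: int, eps: float = 1e-4) -> torch.Tensor:
    # reward: [batch_size * num_generations]; consecutive entries share a prompt.
    group = reward.view(-1, num_generations)                     # [batch_size, num_generations]
    mean = group.mean(dim=1).repeat_interleave(num_generations)  # back to [batch_size * num_generations]
    std = group.std(dim=1).repeat_interleave(num_generations)
    return ((reward - mean) / (std + eps)).unsqueeze(-1)         # [batch_size * num_generations, 1]

# e.g. two prompts with three samples each:
# group_relative_advantages(torch.tensor([1.0, 0.0, 1.0, 0.0, 0.0, 1.0]), num_generations=3)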
@@ -185,35 +167,82 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
                     reward_mean > self.filter_range[0], reward_mean < self.filter_range[1]
                 ).repeat_interleave(self.num_generations, dim=0)
             )
+            mean_kl, mean_loss = [], []
+            for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
+                input_ids_forward_micro_batch = data["input_ids"][
+                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                ]
+                attention_mask_forward_micro_batch = data["attention_mask"][
+                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                ]
+                action_mask_forward_micro_batch = action_mask[
+                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                ]
+                loss_mask_forward_micro_batch = (
+                    loss_mask[forward_micro_batch_start : forward_micro_batch_start + forward_batch_size]
+                    if loss_mask is not None
+                    else None
+                )
+                advantages_forward_micro_batch = advantages[
+                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                ]
+                policy_model_logits = self.policy_model(
+                    input_ids=input_ids_forward_micro_batch,
+                    attention_mask=attention_mask_forward_micro_batch,
+                ).logits
+                action_log_probs = calc_action_log_probs(
+                    policy_model_logits / self.generate_config["temperature"],
+                    input_ids_forward_micro_batch,
+                    num_action,
+                    self.plugin.shard_config,
+                )

-            # [batch_size x num_generations]
-            reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
-            reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
-            # [batch_size x num_generations]
-            advantages = (reward - reward_mean) / (reward_std + 1e-4)
-
-            loss, skip_update, _ = self.policy_loss_fn(
-                action_log_probs,
-                old_action_log_probs,
-                advantages.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1),
-                per_token_kl,
-                action_mask,
-                loss_mask=loss_mask,
-            )
+                with torch.no_grad():
+                    reference_model_logits = self.reference_model(
+                        input_ids=input_ids_forward_micro_batch,
+                        attention_mask=attention_mask_forward_micro_batch,
+                    ).logits
+                reference_action_log_probs = calc_action_log_probs(
+                    reference_model_logits / self.generate_config["temperature"],
+                    input_ids_forward_micro_batch,
+                    num_action,
+                    self.plugin.shard_config,
+                )
+
+                per_token_kl = (
+                    torch.exp(reference_action_log_probs - action_log_probs)
+                    - (reference_action_log_probs - action_log_probs)
+                    - 1
+                )
+                kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
+                    action_mask_forward_micro_batch, dim=-1
+                )
+
+                loss, skip_update, _ = self.policy_loss_fn(
+                    action_log_probs,
+                    old_action_log_probs,
+                    advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
+                    per_token_kl,
+                    action_mask_forward_micro_batch,
+                    loss_mask=loss_mask_forward_micro_batch,
+                )
+
+                if not skip_update:
+                    self.booster.backward(loss, self.optimizer)
+                loss = all_reduce_mean(loss, self.plugin)
+                kl = all_reduce_mean(kl.mean(), self.plugin)
+                # Calculate accumulate value.
+                mean_kl.append(kl.data)
+                mean_loss.append(loss.data)

-            if not skip_update:
-                self.booster.backward(loss, self.optimizer)
-            loss = all_reduce_mean(loss, self.plugin)
             reward = all_reduce_mean(reward.mean(), self.plugin)
-            kl = all_reduce_mean(kl.mean(), self.plugin)
             format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
             acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
             advantages = all_reduce_mean(advantages.mean(), self.plugin)
             response_length = all_reduce_mean(response_length.mean(), self.plugin)
-            # Calculate accumulate value.
-            self.accum_loss.add_(loss.data)
+            self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
+            self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
             self.accum_reward.add_(reward.data)
-            self.accum_kl.add_(kl.data)
             self.accum_format_reward.add_(format_reward.data)
             self.accum_acc_reward.add_(acc_reward.data)
             self.accum_advantages.add_(advantages.data)
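The per_token_kl term added inside the micro-batch loop is the non-negative estimator exp(d) - d - 1 with d = log pi_ref - log pi_theta (the "k3" estimator commonly used in GRPO), reduced to one value per sequence by a mask-weighted mean over response tokens. A minimal sketch of that computation on its own (tensor names are illustrative):

import torch

def masked_sequence_kl(
    action_log_probs: torch.Tensor,            # [B, num_actions], log-probs under the current policy
    reference_action_log_probs: torch.Tensor,  # [B, num_actions], log-probs under the frozen reference
    action_mask: torch.Tensor,                 # [B, num_actions], 1 for response tokens, 0 for padding
) -> torch.Tensor:
    # Token-level penalty: exp(d) - d - 1 with d = log p_ref - log p_theta.
    delta = reference_action_log_probs - action_log_probs
    per_token_kl = torch.exp(delta) - delta - 1
    # Mask-weighted mean over each sequence's response tokens -> shape [B].
    return torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1)

Because exp(d) - d - 1 >= 0 for every d, the penalty cannot turn negative even when individual token log-ratios do, which keeps the accumulated mean_kl statistic stable across micro-batches.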