@@ -280,13 +280,21 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                             )

                         if self.booster.plugin.stage_manager.is_last_stage():
-                            reference_model_logits = reference_model_outputs["outputs"]["logits"]
-                            reference_action_log_probs = calc_action_log_probs(
-                                reference_model_logits / self.generate_config["temperature"],
-                                input_ids_forward_micro_batch,
-                                num_action,
-                                self.plugin.shard_config,
+                            reference_action_log_probs = torch.zeros(
+                                (input_ids_forward_micro_batch.size(0), num_action),
+                                device=input_ids_forward_micro_batch.device,
                             )
+                            for i in range(reference_action_log_probs.size(0)):
+                                # activation for log_softmax is too large if vocab size and sequence length are large
+                                # e.g., when using 152064 vocab size with 32K sequence length and a micro batch size of 4 (for pp=4 for example),
+                                # this activation alone takes 152064*32000*4*4/1024/1024/1024=72.5GB
+                                reference_action_log_probs[i, :] += calc_action_log_probs(
+                                    reference_model_outputs["outputs"]["logits"][i : i + 1]
+                                    / self.generate_config["temperature"],
+                                    input_ids_forward_micro_batch[i : i + 1],
+                                    num_action,
+                                    self.plugin.shard_config,
+                                )[0]
                         else:
                             # Dummy reference logprobs for data iterator.
                             reference_action_log_probs = None
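The loop added above trades a little extra compute for a much smaller log_softmax activation: the softmax is materialized for one sample at a time instead of the whole micro batch. Below is a minimal standalone sketch of the same idea, assuming (batch, seq_len, vocab) logits and next-token labels; the helper name is hypothetical and this is not the repository's calc_action_log_probs, which also takes the shard config.

import torch
import torch.nn.functional as F

def per_sample_action_log_probs(logits: torch.Tensor, input_ids: torch.Tensor, num_action: int) -> torch.Tensor:
    # logits: (B, S, V) next-token logits; input_ids: (B, S); returns (B, num_action).
    # Looping over the batch keeps the log_softmax activation at (1, S, V) instead of (B, S, V).
    out = torch.zeros((input_ids.size(0), num_action), device=logits.device, dtype=logits.dtype)
    for i in range(input_ids.size(0)):
        log_probs = F.log_softmax(logits[i : i + 1, :-1], dim=-1)      # (1, S-1, V)
        labels = input_ids[i : i + 1, 1:].unsqueeze(-1)                # (1, S-1, 1)
        token_log_probs = log_probs.gather(-1, labels).squeeze(-1)     # (1, S-1)
        out[i, :] = token_log_probs[0, -num_action:]                   # keep only the generated actions
    return out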
@@ -308,12 +316,19 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:

                     def _criterion(outputs, inputs):
                         action_logits = outputs.logits
-                        action_log_probs = calc_action_log_probs(
-                            action_logits / self.generate_config["temperature"],
-                            inputs["input_ids"],
-                            num_action,
-                            self.plugin.shard_config,
+                        action_log_probs = torch.zeros(
+                            (inputs["input_ids"].size(0), num_action), device=action_logits.device
                         )
+                        for i in range(action_log_probs.size(0)):
+                            # activation for log_softmax is too large if vocab size and sequence length are large
+                            # e.g., when using 152064 vocab size with 32K sequence length and a micro batch size of 4 (for pp=4 for example),
+                            # this activation alone takes 152064*32000*4*4/1024/1024/1024=72.5GB
+                            action_log_probs[i, :] += calc_action_log_probs(
+                                action_logits[i : i + 1] / self.generate_config["temperature"],
+                                inputs["input_ids"][i : i + 1],
+                                num_action,
+                                self.plugin.shard_config,
+                            )[0]
                         if "reference_action_log_probs" in inputs:
                             per_token_kl = (
                                 torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
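The figure quoted in the added comments is easy to reproduce. The numbers below come straight from the comment (float32, so 4 bytes per element); the per-sample variant divides the peak activation by the micro batch size.

vocab_size, seq_len, micro_batch_size, bytes_per_elem = 152064, 32000, 4, 4
full_batch_gb = vocab_size * seq_len * micro_batch_size * bytes_per_elem / 1024**3
per_sample_gb = full_batch_gb / micro_batch_size
print(f"full micro batch: {full_batch_gb:.1f} GB, one sample at a time: {per_sample_gb:.1f} GB")
# full micro batch: 72.5 GB, one sample at a time: 18.1 GB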