@@ -293,13 +293,21 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
             )
             if self.booster.plugin.stage_manager.is_last_stage():
-                reference_model_logits = reference_model_outputs["outputs"]["logits"]
-                reference_action_log_probs = calc_action_log_probs(
-                    reference_model_logits / self.generate_config["temperature"],
-                    input_ids_forward_micro_batch,
-                    num_action,
-                    self.plugin.shard_config,
+                reference_action_log_probs = torch.zeros(
+                    (input_ids_forward_micro_batch.size(0), num_action),
+                    device=input_ids_forward_micro_batch.device,
                 )
+                for i in range(reference_action_log_probs.size(0)):
+                    # The log_softmax activation is too large when the vocab size and sequence length are large,
+                    # e.g., with a 152064 vocab size, a 32K sequence length and a micro batch size of 4 (for pp=4, for example),
+                    # this activation alone takes 152064*32000*4*4/1024/1024/1024 = 72.5GB.
+                    reference_action_log_probs[i, :] += calc_action_log_probs(
+                        reference_model_outputs["outputs"]["logits"][i : i + 1]
+                        / self.generate_config["temperature"],
+                        input_ids_forward_micro_batch[i : i + 1],
+                        num_action,
+                        self.plugin.shard_config,
+                    )[0]
             else:
                 # Dummy reference logprobs for data iterator.
                 reference_action_log_probs = None
@@ -321,12 +329,19 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
 
             def _criterion(outputs, inputs):
                 action_logits = outputs.logits
-                action_log_probs = calc_action_log_probs(
-                    action_logits / self.generate_config["temperature"],
-                    inputs["input_ids"],
-                    num_action,
-                    self.plugin.shard_config,
+                action_log_probs = torch.zeros(
+                    (inputs["input_ids"].size(0), num_action), device=action_logits.device
                 )
+                for i in range(action_log_probs.size(0)):
+                    # The log_softmax activation is too large when the vocab size and sequence length are large,
+                    # e.g., with a 152064 vocab size, a 32K sequence length and a micro batch size of 4 (for pp=4, for example),
+                    # this activation alone takes 152064*32000*4*4/1024/1024/1024 = 72.5GB.
+                    action_log_probs[i, :] += calc_action_log_probs(
+                        action_logits[i : i + 1] / self.generate_config["temperature"],
+                        inputs["input_ids"][i : i + 1],
+                        num_action,
+                        self.plugin.shard_config,
+                    )[0]
                 if "reference_action_log_probs" in inputs:
                     per_token_kl = (
                         torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
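
Both hunks apply the same fix: instead of one batched calc_action_log_probs call over the whole micro batch, the log probs are computed one sample at a time, so the full (batch, seq_len, vocab) log_softmax activation is never materialized. With the numbers in the comment (152064 vocab, 32000 tokens, fp32, micro batch 4), the batched activation is 152064*32000*4*4 bytes ≈ 72.5 GB, while each per-sample slice needs roughly a quarter of that. Below is a minimal sketch of the idea with illustrative shapes and a hypothetical helper name; it is not the repository's calc_action_log_probs.

```python
import torch

def chunked_action_log_probs(logits: torch.Tensor, input_ids: torch.Tensor, num_action: int) -> torch.Tensor:
    """Per-token log probs of the last `num_action` tokens, computed one sample at a time.

    logits: (B, S, V), input_ids: (B, S). Looping over the batch keeps the
    log_softmax activation at (1, S, V) instead of (B, S, V).
    """
    batch_size = logits.size(0)
    out = torch.zeros((batch_size, num_action), device=logits.device)
    for i in range(batch_size):
        # log_softmax over a single sample only
        log_probs = torch.log_softmax(logits[i : i + 1], dim=-1)
        # the token at position t is predicted by the logits at position t - 1
        targets = input_ids[i : i + 1, -num_action:].unsqueeze(-1)                # (1, num_action, 1)
        token_log_probs = log_probs[:, -num_action - 1 : -1].gather(-1, targets)  # (1, num_action, 1)
        out[i, :] = token_log_probs.squeeze(-1)[0]
    return out

# Example: a 4-sample micro batch, 128-token sequences, 32000-token vocab, 64 action tokens.
logits = torch.randn(4, 128, 32000)
input_ids = torch.randint(0, 32000, (4, 128))
print(chunked_action_log_probs(logits, input_ids, num_action=64).shape)  # torch.Size([4, 64])
```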