@@ -280,13 +280,21 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                             )
 
                         if self.booster.plugin.stage_manager.is_last_stage():
-                            reference_model_logits = reference_model_outputs["outputs"]["logits"]
-                            reference_action_log_probs = calc_action_log_probs(
-                                reference_model_logits / self.generate_config["temperature"],
-                                input_ids_forward_micro_batch,
-                                num_action,
-                                self.plugin.shard_config,
+                            reference_action_log_probs = torch.zeros(
+                                (input_ids_forward_micro_batch.size(0), num_action),
+                                device=input_ids_forward_micro_batch.device,
                             )
+                            for i in range(reference_action_log_probs.size(0)):
+                                # the activation for log_softmax is too large if vocab size and sequence length are large
+                                # e.g., when using a 152064 vocab size with a 32K sequence length and a micro batch size of 4 (for pp=4, for example),
+                                # this activation alone takes 152064*32000*4*4/1024/1024/1024 = 72.5GB
+                                reference_action_log_probs[i, :] += calc_action_log_probs(
+                                    reference_model_outputs["outputs"]["logits"][i : i + 1]
+                                    / self.generate_config["temperature"],
+                                    input_ids_forward_micro_batch[i : i + 1],
+                                    num_action,
+                                    self.plugin.shard_config,
+                                )[0]
                         else:
                             # Dummy reference logprobs for data iterator.
                             reference_action_log_probs = None
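For reference, the figure quoted in the new comments checks out: the full-batch log_softmax materializes a `[micro_batch, seq_len, vocab]` fp32 tensor, and looping per sample shrinks that peak by the micro batch size. A quick back-of-envelope sketch (the variable names below are illustrative, not identifiers from the repo):

```python
# Back-of-envelope check of the memory figure quoted in the diff comments.
# All names here are illustrative assumptions, not identifiers from the repo.
vocab_size = 152064   # e.g. a Qwen2-style vocabulary
seq_len = 32_000      # ~32K tokens, as in the comment
micro_batch = 4       # micro batch size (e.g. with pp=4)
fp32_bytes = 4        # log_softmax output in fp32

full = vocab_size * seq_len * micro_batch * fp32_bytes / 1024**3
per_sample = full / micro_batch
print(f"full-batch log_softmax activation: {full:.1f} GiB")       # -> 72.5 GiB
print(f"per-sample loop peak:              {per_sample:.1f} GiB")  # -> ~18.1 GiB
```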
@@ -308,12 +316,19 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
 
                         def _criterion(outputs, inputs):
                             action_logits = outputs.logits
-                            action_log_probs = calc_action_log_probs(
-                                action_logits / self.generate_config["temperature"],
-                                inputs["input_ids"],
-                                num_action,
-                                self.plugin.shard_config,
+                            action_log_probs = torch.zeros(
+                                (inputs["input_ids"].size(0), num_action), device=action_logits.device
                             )
+                            for i in range(action_log_probs.size(0)):
+                                # the activation for log_softmax is too large if vocab size and sequence length are large
+                                # e.g., when using a 152064 vocab size with a 32K sequence length and a micro batch size of 4 (for pp=4, for example),
+                                # this activation alone takes 152064*32000*4*4/1024/1024/1024 = 72.5GB
+                                action_log_probs[i, :] += calc_action_log_probs(
+                                    action_logits[i : i + 1] / self.generate_config["temperature"],
+                                    inputs["input_ids"][i : i + 1],
+                                    num_action,
+                                    self.plugin.shard_config,
+                                )[0]
                             if "reference_action_log_probs" in inputs:
                                 per_token_kl = (
                                     torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
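Both hunks apply the same trick, so a self-contained sketch may help reviewers. The hypothetical `per_sample_action_log_probs` below mimics what the loop around `calc_action_log_probs` achieves: log_softmax only ever sees one sample's `[1, seq, vocab]` slice. The helper name and the exact next-token/action-tail semantics are assumptions for illustration; the real implementation lives in `calc_action_log_probs`.

```python
import torch
import torch.nn.functional as F

def per_sample_action_log_probs(logits: torch.Tensor, input_ids: torch.Tensor, num_action: int) -> torch.Tensor:
    """Hypothetical stand-in for the looped calc_action_log_probs calls.

    Processing one sample at a time means log_softmax only materializes a
    [1, seq-1, vocab] activation instead of [batch, seq-1, vocab].
    """
    out = torch.zeros(input_ids.size(0), num_action, device=logits.device)
    for i in range(input_ids.size(0)):
        # next-token prediction: logits at position t score the token at t+1
        log_probs = F.log_softmax(logits[i : i + 1, :-1, :], dim=-1)  # [1, seq-1, vocab]
        labels = input_ids[i : i + 1, 1:].unsqueeze(-1)               # [1, seq-1, 1]
        token_log_probs = log_probs.gather(-1, labels).squeeze(-1)    # [1, seq-1]
        # keep only the trailing response ("action") tokens
        out[i, :] = token_log_probs[0, -num_action:]
    return out

# toy usage: batch of 4, seq 16, vocab 32, last 5 tokens are the response
logits = torch.randn(4, 16, 32)
input_ids = torch.randint(0, 32, (4, 16))
print(per_sample_action_log_probs(logits, input_ids, num_action=5).shape)  # torch.Size([4, 5])
```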