 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
-from coati.distributed.utils import calc_action_log_probs
+from coati.distributed.utils import memory_efficient_logprob
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
@@ -280,21 +280,12 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                             )
 
                         if self.booster.plugin.stage_manager.is_last_stage():
-                            reference_action_log_probs = torch.zeros(
-                                (input_ids_forward_micro_batch.size(0), num_action),
-                                device=input_ids_forward_micro_batch.device,
+                            reference_action_log_probs = memory_efficient_logprob(
+                                reference_model_outputs["outputs"]["logits"],
+                                input_ids_forward_micro_batch,
+                                num_action,
+                                shard_config=self.plugin.shard_config,
                             )
-                            for i in range(reference_action_log_probs.size(0)):
-                                # activation for log_softmax is too large if vocab size and sequence length are large
-                                # e.g., when using 152064 vocab size with 32K seqence length and a micro batch size of 4 (for pp=4 for example),
-                                # this activation sorely takes 152064*32000*4*4/1024/1024/1024=72.5GB
-                                reference_action_log_probs[i, :] += calc_action_log_probs(
-                                    reference_model_outputs["outputs"]["logits"][i : i + 1]
-                                    / self.generate_config["temperature"],
-                                    input_ids_forward_micro_batch[i : i + 1],
-                                    num_action,
-                                    self.plugin.shard_config,
-                                )[0]
                         else:
                             # Dummy reference logprobs for data iterator.
                             reference_action_log_probs = None
@@ -316,19 +307,12 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
 
                     def _criterion(outputs, inputs):
                         action_logits = outputs.logits
-                        action_log_probs = torch.zeros(
-                            (inputs["input_ids"].size(0), num_action), device=action_logits.device
+                        action_log_probs = memory_efficient_logprob(
+                            action_logits,
+                            inputs["input_ids"],
+                            num_action,
+                            shard_config=self.plugin.shard_config,
                         )
-                        for i in range(action_log_probs.size(0)):
-                            # activation for log_softmax is too large if vocab size and sequence length are large
-                            # e.g., when using 152064 vocab size with 32K seqence length and a micro batch size of 4 (for pp=4 for example),
-                            # this activation sorely takes 152064*32000*4*4/1024/1024/1024=72.5GB
-                            action_log_probs[i, :] += calc_action_log_probs(
-                                action_logits[i : i + 1] / self.generate_config["temperature"],
-                                inputs["input_ids"][i : i + 1],
-                                num_action,
-                                self.plugin.shard_config,
-                            )[0]
                         if "reference_action_log_probs" in inputs:
                             per_token_kl = (
                                 torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
@@ -370,16 +354,15 @@ def _criterion(outputs, inputs):
                             mean_kl.append(kl)
                         mean_loss.append(all_reduce_mean(loss, self.plugin).data)
                else:
-
                    policy_model_logits = self.policy_model(
                        input_ids=input_ids_forward_micro_batch,
                        attention_mask=attention_mask_forward_micro_batch,
                    ).logits
-                    action_log_probs = calc_action_log_probs(
+                    action_log_probs = memory_efficient_logprob(
                        policy_model_logits / self.generate_config["temperature"],
                        input_ids_forward_micro_batch,
                        num_action,
-                        self.plugin.shard_config,
+                        shard_config=self.plugin.shard_config,
                    )
 
                    if self.policy_loss_fn.beta > 0:
@@ -388,7 +371,7 @@ def _criterion(outputs, inputs):
                                input_ids=input_ids_forward_micro_batch,
                                attention_mask=attention_mask_forward_micro_batch,
                            ).logits
-                            reference_action_log_probs = calc_action_log_probs(
+                            reference_action_log_probs = memory_efficient_logprob(
                                reference_model_logits / self.generate_config["temperature"],
                                input_ids_forward_micro_batch,
                                num_action,
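
Note on the change: the per-sample loops removed above existed because taking log_softmax over the full (batch, seq_len, vocab) logits in one shot can require tens of gigabytes of activation memory (the removed comment works this out as 152064 * 32000 * 4 * 4 bytes / 1024^3 ≈ 72.5 GB for a micro batch of 4). The new memory_efficient_logprob helper folds that row-by-row strategy into a single call, so each call site collapses to one invocation. Its implementation is not part of this diff; the sketch below is only an illustration of the idea, using a hypothetical memory_efficient_logprob_sketch that drops the tensor-parallel shard_config handling and assumes the last num_action positions of input_ids are the generated action tokens.

import torch


def memory_efficient_logprob_sketch(
    logits: torch.Tensor,      # (B, S, V) logits, optionally already divided by temperature
    input_ids: torch.Tensor,   # (B, S) token ids of the same sequences
    num_action: int,           # number of generated (action) tokens at the end of each sequence
) -> torch.Tensor:
    """Per-token log-probs of the action tokens, computed one sequence at a time."""
    batch_size = input_ids.size(0)
    action_log_probs = torch.zeros(
        (batch_size, num_action), device=logits.device, dtype=torch.float32
    )
    for i in range(batch_size):
        # Logits at position t predict the token at position t + 1, so the last
        # `num_action` targets are predicted by the slice [-num_action - 1 : -1].
        row_logits = logits[i, -num_action - 1 : -1, :].float()
        targets = input_ids[i, -num_action:]
        # log_softmax is materialized for a single sequence only, so the peak
        # activation is num_action * vocab floats rather than B * S * vocab
        # for the whole micro batch at once.
        row_log_probs = torch.log_softmax(row_logits, dim=-1)
        action_log_probs[i] = row_log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    return action_log_probs

Under those assumptions, the non-pipeline call site above would correspond to memory_efficient_logprob_sketch(policy_model_logits / self.generate_config["temperature"], input_ids_forward_micro_batch, num_action); the real helper additionally accepts shard_config=, presumably so it can handle vocab-sharded logits when tensor parallelism is enabled.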