@@ -6,7 +6,7 @@
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
-from coati.distributed.utils import calc_action_log_probs
+from coati.distributed.utils import memory_efficient_logprob
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
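This commit replaces every call to `calc_action_log_probs` with `memory_efficient_logprob`, which avoids materializing a full `(batch, seq, vocab)` log-softmax over the whole sequence at once. A minimal sketch of the idea, assuming the helper chunks over the action tokens (the name `memory_efficient_logprob_sketch`, the chunk size, and the exact signature are illustrative assumptions, not the actual coati implementation):

```python
import torch

def memory_efficient_logprob_sketch(
    logits: torch.Tensor,     # (B, S, V), already divided by temperature
    input_ids: torch.Tensor,  # (B, S)
    num_action: int,          # number of generated (action) tokens at the end
    chunk_size: int = 512,    # illustrative; the real helper may tune or shard this
) -> torch.Tensor:
    """Log-probs of the last `num_action` tokens, shape (B, num_action)."""
    # The token at position t is predicted by the logits at position t - 1.
    action_logits = logits[:, -num_action - 1 : -1, :]
    action_ids = input_ids[:, -num_action:]
    out = []
    for start in range(0, num_action, chunk_size):
        chunk = action_logits[:, start : start + chunk_size, :]
        ids = action_ids[:, start : start + chunk_size]
        # log_softmax over the vocab for this chunk only, then gather the
        # log-prob of each realized token; the (B, chunk, V) intermediate is
        # freed before the next chunk is processed.
        log_probs = torch.log_softmax(chunk.float(), dim=-1)
        out.append(log_probs.gather(-1, ids.unsqueeze(-1)).squeeze(-1))
    return torch.cat(out, dim=1)
```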
@@ -280,12 +280,11 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                 )
 
             if self.booster.plugin.stage_manager.is_last_stage():
-                reference_model_logits = reference_model_outputs["outputs"]["logits"]
-                reference_action_log_probs = calc_action_log_probs(
-                    reference_model_logits / self.generate_config["temperature"],
+                reference_action_log_probs = memory_efficient_logprob(
+                    reference_model_outputs["outputs"]["logits"] / self.generate_config["temperature"],
                     input_ids_forward_micro_batch,
                     num_action,
-                    self.plugin.shard_config,
+                    shard_config=self.plugin.shard_config,
                 )
             else:
                 # Dummy reference logprobs for data iterator.
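The call sites also fold the intermediate `reference_model_logits` binding into the call and now pass `shard_config` by keyword. A self-contained usage sketch of the new pattern with dummy tensors (`shard_config=None` standing for the unsharded single-GPU case is an assumption, not verified against the real signature):

```python
import torch
from coati.distributed.utils import memory_efficient_logprob

B, S, V, num_action = 2, 16, 32000, 8
logits = torch.randn(B, S, V)
input_ids = torch.randint(0, V, (B, S))
temperature = 0.7

# Logits are divided by the sampling temperature first, so the computed
# log-probs match the distribution the tokens were actually sampled from.
action_log_probs = memory_efficient_logprob(
    logits / temperature,
    input_ids,
    num_action,
    shard_config=None,  # assumed to mean "no tensor sharding" here
)
assert action_log_probs.shape == (B, num_action)
```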
@@ -308,11 +307,11 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
 
             def _criterion(outputs, inputs):
                 action_logits = outputs.logits
-                action_log_probs = calc_action_log_probs(
+                action_log_probs = memory_efficient_logprob(
                     action_logits / self.generate_config["temperature"],
                     inputs["input_ids"],
                     num_action,
-                    self.plugin.shard_config,
+                    shard_config=self.plugin.shard_config,
                 )
                 if "reference_action_log_probs" in inputs:
                     per_token_kl = (
@@ -355,16 +354,15 @@ def _criterion(outputs, inputs):
                     mean_kl.append(kl)
                 mean_loss.append(all_reduce_mean(loss, self.plugin).data)
             else:
-
                 policy_model_logits = self.policy_model(
                     input_ids=input_ids_forward_micro_batch,
                     attention_mask=attention_mask_forward_micro_batch,
                 ).logits
-                action_log_probs = calc_action_log_probs(
+                action_log_probs = memory_efficient_logprob(
                     policy_model_logits / self.generate_config["temperature"],
                     input_ids_forward_micro_batch,
                     num_action,
-                    self.plugin.shard_config,
+                    shard_config=self.plugin.shard_config,
                 )
 
                 if self.policy_loss_fn.beta > 0:
@@ -373,11 +371,11 @@ def _criterion(outputs, inputs):
                         input_ids=input_ids_forward_micro_batch,
                         attention_mask=attention_mask_forward_micro_batch,
                     ).logits
-                    reference_action_log_probs = calc_action_log_probs(
+                    reference_action_log_probs = memory_efficient_logprob(
                         reference_model_logits / self.generate_config["temperature"],
                         input_ids_forward_micro_batch,
                         num_action,
-                        self.plugin.shard_config,
+                        shard_config=self.plugin.shard_config,
                     )
                     per_token_kl = (
                         torch.exp(reference_action_log_probs - action_log_probs)
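The last hunk is truncated inside the per-token KL computation. The `torch.exp(reference_action_log_probs - action_log_probs)` term is the start of what is, in GRPO-style trainers, usually the k3 estimator `exp(delta) - delta - 1` with `delta = log pi_ref - log pi_theta`, which is non-negative and unbiased for `KL(pi_theta || pi_ref)`. A sketch under the assumption that the truncated expression continues this way:

```python
import torch

def per_token_kl_k3(ref_logp: torch.Tensor, logp: torch.Tensor) -> torch.Tensor:
    # k3 estimator of KL(pi_theta || pi_ref), evaluated per action token:
    # exp(delta) - delta - 1 with delta = ref_logp - logp; always >= 0.
    delta = ref_logp - logp
    return torch.exp(delta) - delta - 1
```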