@@ -6,7 +6,7 @@
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
-from coati.distributed.utils import calc_action_log_probs
+from coati.distributed.utils import memory_efficient_logprob
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
@@ -280,12 +280,11 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                         )
 
                         if self.booster.plugin.stage_manager.is_last_stage():
-                            reference_model_logits = reference_model_outputs["outputs"]["logits"]
-                            reference_action_log_probs = calc_action_log_probs(
-                                reference_model_logits / self.generate_config["temperature"],
+                            reference_action_log_probs = memory_efficient_logprob(
+                                reference_model_outputs["outputs"]["logits"] / self.generate_config["temperature"],
                                 input_ids_forward_micro_batch,
                                 num_action,
-                                self.plugin.shard_config,
+                                shard_config=self.plugin.shard_config,
                             )
                         else:
                             # Dummy reference logprobs for data iterator.
@@ -308,11 +307,11 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
 
                     def _criterion(outputs, inputs):
                         action_logits = outputs.logits
-                        action_log_probs = calc_action_log_probs(
+                        action_log_probs = memory_efficient_logprob(
                             action_logits / self.generate_config["temperature"],
                             inputs["input_ids"],
                             num_action,
-                            self.plugin.shard_config,
+                            shard_config=self.plugin.shard_config,
                         )
                         if "reference_action_log_probs" in inputs:
                             per_token_kl = (
@@ -355,16 +354,15 @@ def _criterion(outputs, inputs):
                        mean_kl.append(kl)
                    mean_loss.append(all_reduce_mean(loss, self.plugin).data)
                else:
-
                    policy_model_logits = self.policy_model(
                        input_ids=input_ids_forward_micro_batch,
                        attention_mask=attention_mask_forward_micro_batch,
                    ).logits
-                   action_log_probs = calc_action_log_probs(
+                   action_log_probs = memory_efficient_logprob(
                        policy_model_logits / self.generate_config["temperature"],
                        input_ids_forward_micro_batch,
                        num_action,
-                       self.plugin.shard_config,
+                       shard_config=self.plugin.shard_config,
                    )
 
                    if self.policy_loss_fn.beta > 0:
@@ -373,11 +371,11 @@ def _criterion(outputs, inputs):
                            input_ids=input_ids_forward_micro_batch,
                            attention_mask=attention_mask_forward_micro_batch,
                        ).logits
-                       reference_action_log_probs = calc_action_log_probs(
+                       reference_action_log_probs = memory_efficient_logprob(
                            reference_model_logits / self.generate_config["temperature"],
                            input_ids_forward_micro_batch,
                            num_action,
-                           self.plugin.shard_config,
+                           shard_config=self.plugin.shard_config,
                        )
                        per_token_kl = (
                            torch.exp(reference_action_log_probs - action_log_probs)
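
Context for the change above: the diff swaps calc_action_log_probs for memory_efficient_logprob at every call site and now passes shard_config by keyword. The new helper's body is not part of this diff, so the following is only a minimal sketch of the general technique such a helper typically uses: computing per-token log-probs of the response ("action") tokens in chunks along the sequence dimension, so the peak intermediate is a [batch, chunk, vocab] log-softmax rather than one over the full [batch, seq, vocab] logits. The name chunked_action_log_probs and the chunk_size knob are hypothetical, and the tensor-parallel path implied by shard_config= is omitted here.

    import torch

    def chunked_action_log_probs(
        logits: torch.Tensor,      # [B, S, V], already divided by temperature
        input_ids: torch.Tensor,   # [B, S] token ids of the same sequences
        num_actions: int,          # number of trailing response tokens
        chunk_size: int = 1024,    # hypothetical knob: memory vs. speed
    ) -> torch.Tensor:
        # Shift: logits at position t predict the token at position t + 1.
        shifted_logits = logits[:, :-1, :]
        targets = input_ids[:, 1:]
        # Keep only the response part of the sequence.
        shifted_logits = shifted_logits[:, -num_actions:, :]
        targets = targets[:, -num_actions:]
        chunks = []
        # log_softmax over a [B, chunk, V] slice is the peak intermediate,
        # never the full [B, S, V] tensor.
        for start in range(0, shifted_logits.size(1), chunk_size):
            sl = shifted_logits[:, start : start + chunk_size, :]
            tg = targets[:, start : start + chunk_size]
            log_probs = torch.log_softmax(sl.float(), dim=-1)
            chunks.append(log_probs.gather(-1, tg.unsqueeze(-1)).squeeze(-1))
        return torch.cat(chunks, dim=1)  # [B, num_actions]

A call analogous to the diff's call sites would be chunked_action_log_probs(logits / temperature, input_ids_forward_micro_batch, num_action); the real helper additionally accepts shard_config=self.plugin.shard_config, presumably so it can handle logits whose vocabulary dimension is sharded across tensor-parallel ranks, which this sketch does not attempt.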