6
6
import wandb
7
7
from coati .distributed .consumer import BaseConsumer
8
8
from coati .distributed .loss import PolicyLoss
9
- from coati .distributed .utils import calc_action_log_probs
9
+ from coati .distributed .utils import memory_efficient_logprob
10
10
from coati .trainer .utils import all_reduce_mean , all_reduce_sum
11
11
from transformers import AutoModelForCausalLM , AutoTokenizer
12
12
@@ -293,21 +293,12 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
293
293
)
294
294
295
295
if self .booster .plugin .stage_manager .is_last_stage ():
296
- reference_action_log_probs = torch .zeros (
297
- (input_ids_forward_micro_batch .size (0 ), num_action ),
298
- device = input_ids_forward_micro_batch .device ,
296
+ reference_action_log_probs = memory_efficient_logprob (
297
+ reference_model_outputs ["outputs" ]["logits" ],
298
+ input_ids_forward_micro_batch ,
299
+ num_action ,
300
+ shard_config = self .plugin .shard_config ,
299
301
)
300
- for i in range (reference_action_log_probs .size (0 )):
301
- # activation for log_softmax is too large if vocab size and sequence length are large
302
- # e.g., when using 152064 vocab size with 32K seqence length and a micro batch size of 4 (for pp=4 for example),
303
- # this activation sorely takes 152064*32000*4*4/1024/1024/1024=72.5GB
304
- reference_action_log_probs [i , :] += calc_action_log_probs (
305
- reference_model_outputs ["outputs" ]["logits" ][i : i + 1 ]
306
- / self .generate_config ["temperature" ],
307
- input_ids_forward_micro_batch [i : i + 1 ],
308
- num_action ,
309
- self .plugin .shard_config ,
310
- )[0 ]
311
302
else :
312
303
# Dummy reference logprobs for data iterator.
313
304
reference_action_log_probs = None
@@ -329,19 +320,12 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
329
320
330
321
def _criterion (outputs , inputs ):
331
322
action_logits = outputs .logits
332
- action_log_probs = torch .zeros (
333
- (inputs ["input_ids" ].size (0 ), num_action ), device = action_logits .device
323
+ action_log_probs = memory_efficient_logprob (
324
+ action_logits ,
325
+ inputs ["input_ids" ],
326
+ num_action ,
327
+ shard_config = self .plugin .shard_config ,
334
328
)
335
- for i in range (action_log_probs .size (0 )):
336
- # activation for log_softmax is too large if vocab size and sequence length are large
337
- # e.g., when using 152064 vocab size with 32K seqence length and a micro batch size of 4 (for pp=4 for example),
338
- # this activation sorely takes 152064*32000*4*4/1024/1024/1024=72.5GB
339
- action_log_probs [i , :] += calc_action_log_probs (
340
- action_logits [i : i + 1 ] / self .generate_config ["temperature" ],
341
- inputs ["input_ids" ][i : i + 1 ],
342
- num_action ,
343
- self .plugin .shard_config ,
344
- )[0 ]
345
329
if "reference_action_log_probs" in inputs :
346
330
per_token_kl = (
347
331
torch .exp (inputs ["reference_action_log_probs" ] - action_log_probs )
@@ -383,16 +367,15 @@ def _criterion(outputs, inputs):
383
367
mean_kl .append (kl )
384
368
mean_loss .append (all_reduce_mean (loss , self .plugin ).data )
385
369
else :
386
-
387
370
policy_model_logits = self .policy_model (
388
371
input_ids = input_ids_forward_micro_batch ,
389
372
attention_mask = attention_mask_forward_micro_batch ,
390
373
).logits
391
- action_log_probs = calc_action_log_probs (
374
+ action_log_probs = memory_efficient_logprob (
392
375
policy_model_logits / self .generate_config ["temperature" ],
393
376
input_ids_forward_micro_batch ,
394
377
num_action ,
395
- self .plugin .shard_config ,
378
+ shard_config = self .plugin .shard_config ,
396
379
)
397
380
398
381
if self .policy_loss_fn .beta > 0 :
@@ -401,7 +384,7 @@ def _criterion(outputs, inputs):
401
384
input_ids = input_ids_forward_micro_batch ,
402
385
attention_mask = attention_mask_forward_micro_batch ,
403
386
).logits
404
- reference_action_log_probs = calc_action_log_probs (
387
+ reference_action_log_probs = memory_efficient_logprob (
405
388
reference_model_logits / self .generate_config ["temperature" ],
406
389
input_ids_forward_micro_batch ,
407
390
num_action ,
0 commit comments