6
6
import wandb
7
7
from coati .distributed .consumer import BaseConsumer
8
8
from coati .distributed .loss import PolicyLoss
9
- from coati .distributed .utils import calc_action_log_probs
9
+ from coati .distributed .utils import memory_efficient_logprob
10
10
from coati .trainer .utils import all_reduce_mean , all_reduce_sum
11
11
from transformers import AutoModelForCausalLM , AutoTokenizer
12
12
@@ -280,21 +280,12 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
280
280
)
281
281
282
282
if self .booster .plugin .stage_manager .is_last_stage ():
283
- reference_action_log_probs = torch .zeros (
284
- (input_ids_forward_micro_batch .size (0 ), num_action ),
285
- device = input_ids_forward_micro_batch .device ,
283
+ reference_action_log_probs = memory_efficient_logprob (
284
+ reference_model_outputs ["outputs" ]["logits" ],
285
+ input_ids_forward_micro_batch ,
286
+ num_action ,
287
+ shard_config = self .plugin .shard_config ,
286
288
)
287
- for i in range (reference_action_log_probs .size (0 )):
288
- # activation for log_softmax is too large if vocab size and sequence length are large
289
- # e.g., when using 152064 vocab size with 32K sequence length and a micro batch size of 4 (for pp=4 for example),
290
- # this activation alone takes 152064*32000*4*4/1024/1024/1024=72.5GB
291
- reference_action_log_probs [i , :] += calc_action_log_probs (
292
- reference_model_outputs ["outputs" ]["logits" ][i : i + 1 ]
293
- / self .generate_config ["temperature" ],
294
- input_ids_forward_micro_batch [i : i + 1 ],
295
- num_action ,
296
- self .plugin .shard_config ,
297
- )[0 ]
298
289
else :
299
290
# Dummy reference logprobs for data iterator.
300
291
reference_action_log_probs = None
@@ -316,19 +307,12 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
316
307
317
308
def _criterion (outputs , inputs ):
318
309
action_logits = outputs .logits
319
- action_log_probs = torch .zeros (
320
- (inputs ["input_ids" ].size (0 ), num_action ), device = action_logits .device
310
+ action_log_probs = memory_efficient_logprob (
311
+ action_logits ,
312
+ inputs ["input_ids" ],
313
+ num_action ,
314
+ shard_config = self .plugin .shard_config ,
321
315
)
322
- for i in range (action_log_probs .size (0 )):
323
- # activation for log_softmax is too large if vocab size and sequence length are large
324
- # e.g., when using 152064 vocab size with 32K sequence length and a micro batch size of 4 (for pp=4 for example),
325
- # this activation alone takes 152064*32000*4*4/1024/1024/1024=72.5GB
326
- action_log_probs [i , :] += calc_action_log_probs (
327
- action_logits [i : i + 1 ] / self .generate_config ["temperature" ],
328
- inputs ["input_ids" ][i : i + 1 ],
329
- num_action ,
330
- self .plugin .shard_config ,
331
- )[0 ]
332
316
if "reference_action_log_probs" in inputs :
333
317
per_token_kl = (
334
318
torch .exp (inputs ["reference_action_log_probs" ] - action_log_probs )
@@ -370,16 +354,15 @@ def _criterion(outputs, inputs):
370
354
mean_kl .append (kl )
371
355
mean_loss .append (all_reduce_mean (loss , self .plugin ).data )
372
356
else :
373
-
374
357
policy_model_logits = self .policy_model (
375
358
input_ids = input_ids_forward_micro_batch ,
376
359
attention_mask = attention_mask_forward_micro_batch ,
377
360
).logits
378
- action_log_probs = calc_action_log_probs (
361
+ action_log_probs = memory_efficient_logprob (
379
362
policy_model_logits / self .generate_config ["temperature" ],
380
363
input_ids_forward_micro_batch ,
381
364
num_action ,
382
- self .plugin .shard_config ,
365
+ shard_config = self .plugin .shard_config ,
383
366
)
384
367
385
368
if self .policy_loss_fn .beta > 0 :
@@ -388,7 +371,7 @@ def _criterion(outputs, inputs):
388
371
input_ids = input_ids_forward_micro_batch ,
389
372
attention_mask = attention_mask_forward_micro_batch ,
390
373
).logits
391
- reference_action_log_probs = calc_action_log_probs (
374
+ reference_action_log_probs = memory_efficient_logprob (
392
375
reference_model_logits / self .generate_config ["temperature" ],
393
376
input_ids_forward_micro_batch ,
394
377
num_action ,
0 commit comments