@@ -1073,13 +1073,15 @@ def _compute_eagle_loss(self, logits, labels, eagle_logits):
         """Compute the total loss for EAGLE.

         logits: [s, b, vocab // TP]
-        labels: [b, s]
+        labels: [b, s] or [b, s-1] for offline mode
         eagle_logits: [s, b, vocab // TP]
         """
         # Compute lm loss (classification loss) or KLDivergence
         if self.eagle_self_logit_distillation:
             mapping = self.eagle_module.d2t if hasattr(self.eagle_module, "d2t") else None
             token_loss = self.kld(eagle_logits[:-1, :, :], logits[1:, :, :], mapping)
+        elif labels.shape[1] < eagle_logits.shape[0]:
+            token_loss = self.compute_language_model_loss(labels[:, 1:], eagle_logits[:-2, :, :])
         else:
             token_loss = self.compute_language_model_loss(labels[:, 1:], eagle_logits[:-1, :, :])

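For intuition on the new `elif` branch, here is a hedged, self-contained shape check; the tensor sizes are made up, and only the comparison mirrors the code above:

```python
# Standalone sanity sketch (not part of the diff): offline labels arrive one token
# short, so the new branch trims one extra logit position. Dummy shapes assumed.
import torch

s, b, vocab = 8, 2, 16
eagle_logits = torch.randn(s, b, vocab)               # [s, b, vocab // TP]
labels_offline = torch.randint(0, vocab, (b, s - 1))  # [b, s - 1] in offline mode

# labels.shape[1] (= s - 1) < eagle_logits.shape[0] (= s) selects the new elif branch;
# both loss operands then cover the same s - 2 aligned positions.
assert labels_offline.shape[1] < eagle_logits.shape[0]
assert labels_offline[:, 1:].shape[1] == eagle_logits[:-2, :, :].shape[0] == s - 2
```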
@@ -1279,26 +1281,20 @@ def forward(
         # all eagle weights have been exercised for quantization calibration purpose.
         if labels is None:
             return logits_sbh.transpose(0, 1).contiguous()
-        elif labels.shape[1] == input_ids.shape[1] - 1:
-            # For offline training, labels may be 1 token shorter than input_ids.
-            # We will just pad a 0 to the labels to make the seq_len the same as
-            # input_ids. This will introduce a small error in training if logit_distillation
-            # is False, and testing accuracy is wrong for the last token.
-            right_token_pad = torch.zeros(
-                (labels.shape[0], 1),
-                dtype=labels.dtype,
-                device=labels.device,
-            )
-            labels = torch.cat((labels, right_token_pad), dim=-1)

         # If eagle_freeze_base_model is set to True,
         # the base model is frozen.
-        loss = self.compute_language_model_loss(labels, logits_sbh)
+        if self.eagle_offline:
+            loss = torch.zeros(input_ids.shape).to(input_ids.device)
+        else:
+            loss = self.compute_language_model_loss(labels, logits_sbh)
         loss = 0.0 * loss

         if self.eagle_config.parallel_draft_step > 1:
             for i in range(self.eagle_config.parallel_draft_step):
-                eagle_logits = eagle_logits_0[i * labels.shape[1] : (i + 1) * labels.shape[1]]
+                eagle_logits = eagle_logits_0[
+                    i * logits_sbh.shape[0] : (i + 1) * logits_sbh.shape[0]
+                ]
                 loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits)
                 loss_ = loss_[:, i:]
                 loss[:, i + 1 :] += 1.0 * loss_
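Two coordinated changes appear in this hunk: offline runs, which have no usable base-model loss, get a zero placeholder shaped like `input_ids`, and the per-draft slices of `eagle_logits_0` are taken by `logits_sbh.shape[0]` rather than `labels.shape[1]`, which is one token short offline. A small illustrative sketch with assumed dummy shapes:

```python
# Illustrative only: placeholder loss and draft-step slicing with assumed shapes.
import torch

b, s, vocab, parallel_draft_step = 2, 8, 16, 2
input_ids = torch.randint(0, vocab, (b, s))
logits_sbh = torch.randn(s, b, vocab)
# Assumed layout: eagle_logits_0 stacks one [s, b, vocab] block per draft step.
eagle_logits_0 = torch.randn(parallel_draft_step * s, b, vocab)

# Offline placeholder: keeps the later `loss[:, i + 1 :] += ...` updates well-defined.
loss = torch.zeros(input_ids.shape).to(input_ids.device)

for i in range(parallel_draft_step):
    # Slicing by the logits length (= s) is valid even when labels are [b, s - 1].
    eagle_logits = eagle_logits_0[i * logits_sbh.shape[0] : (i + 1) * logits_sbh.shape[0]]
    assert eagle_logits.shape == (s, b, vocab)
```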
@@ -1311,7 +1307,7 @@ def forward(
             acc = []
             with torch.no_grad():
                 gathered_logits = gather_from_tensor_model_parallel_region(
-                    eagle_logits_0[:-1, :, :]
+                    eagle_logits_0[:-2, :, :] if self.eagle_offline else eagle_logits_0[:-1, :, :]
                 )
                 eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
                 if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
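The accuracy probe mirrors the loss change: offline, the final position has no label to score against, so one extra trailing step is dropped. The same `[k:-2]` vs `[k:-1]` pattern repeats below for `eagle_logits_1`, `eagle_logits_2`, and `eagle_logits_3`. A hedged sketch of the slice lengths (dummy sizes):

```python
# Sketch of the slice lengths only; the real code gathers across tensor-model-parallel ranks.
import torch

s, b, vocab = 8, 2, 16
eagle_logits_0 = torch.randn(s, b, vocab)

online = eagle_logits_0[:-1, :, :]   # s - 1 scored positions when labels are [b, s]
offline = eagle_logits_0[:-2, :, :]  # one fewer when offline labels are [b, s - 1]
assert online.shape[0] == s - 1 and offline.shape[0] == s - 2
```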
@@ -1341,7 +1337,7 @@ def forward(
             packed_seq_params=packed_seq_params,
             **(extra_block_kwargs or {}),
         )
-        eagle_logits_1 = eagle_logits_2x[-labels.shape[1] :, :, :]
+        eagle_logits_1 = eagle_logits_2x[logits_sbh.shape[0] :, :, :]

         loss_1 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_1)
         # [b, s - 2]
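For the second pass, the EAGLE module runs on a sequence twice the base length, so the draft logits are the trailing `logits_sbh.shape[0]` positions of `eagle_logits_2x`. Indexing by the logits length instead of `labels.shape[1]` stays correct with the shorter offline labels; the `3x`/`4x` hunks below follow the same idea. A sketch under that assumed buffer layout:

```python
# Assumed layout sketch: the doubled buffer ends with the second-pass draft logits.
import torch

s, b, vocab = 8, 2, 16
logits_sbh = torch.randn(s, b, vocab)
eagle_logits_2x = torch.randn(2 * s, b, vocab)

eagle_logits_1 = eagle_logits_2x[logits_sbh.shape[0] :, :, :]
assert eagle_logits_1.shape[0] == s  # unaffected by labels.shape[1]
```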
@@ -1352,7 +1348,7 @@ def forward(
             acc = []
             with torch.no_grad():
                 gathered_logits = gather_from_tensor_model_parallel_region(
-                    eagle_logits_1[1:-1, :, :]
+                    eagle_logits_1[1:-2, :, :] if self.eagle_offline else eagle_logits_1[1:-1, :, :]
                 )
                 eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
                 if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
@@ -1383,7 +1379,7 @@ def forward(
             **(extra_block_kwargs or {}),
         )

-        eagle_logits_2 = eagle_logits_3x[-labels.shape[1] :, :, :]
+        eagle_logits_2 = eagle_logits_3x[-logits_sbh.shape[0] :, :, :]

         loss_2 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_2)
         # [b, s - 3]
@@ -1394,7 +1390,7 @@ def forward(
             acc = []
             with torch.no_grad():
                 gathered_logits = gather_from_tensor_model_parallel_region(
-                    eagle_logits_2[2:-1, :, :]
+                    eagle_logits_2[2:-2, :, :] if self.eagle_offline else eagle_logits_2[2:-1, :, :]
                 )
                 eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
                 if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
@@ -1425,7 +1421,7 @@ def forward(
             **(extra_block_kwargs or {}),
         )

-        eagle_logits_3 = eagle_logits_4x[-labels.shape[1] :, :, :]
+        eagle_logits_3 = eagle_logits_4x[-logits_sbh.shape[0] :, :, :]

         loss_3 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_3)
         # [b, s - 4]
@@ -1436,7 +1432,7 @@ def forward(
             acc = []
             with torch.no_grad():
                 gathered_logits = gather_from_tensor_model_parallel_region(
-                    eagle_logits_3[3:-1, :, :]
+                    eagle_logits_3[3:-2, :, :] if self.eagle_offline else eagle_logits_3[3:-1, :, :]
                 )
                 eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
                 if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: