@@ -1073,15 +1073,13 @@ def _compute_eagle_loss(self, logits, labels, eagle_logits):
10731073 """Compute the total loss for EAGLE.
10741074
10751075 logits: [s, b, vocab // TP]
1076- labels: [b, s] or [b, s-1] for offline mode
1076+ labels: [b, s]
10771077 eagle_logits: [s, b, vocab // TP]
10781078 """
10791079 # Compute lm loss (classification loss) or KLDivergence
10801080 if self .eagle_self_logit_distillation :
10811081 mapping = self .eagle_module .d2t if hasattr (self .eagle_module , "d2t" ) else None
10821082 token_loss = self .kld (eagle_logits [:- 1 , :, :], logits [1 :, :, :], mapping )
1083- elif labels .shape [1 ] < eagle_logits .shape [0 ]:
1084- token_loss = self .compute_language_model_loss (labels [:, 1 :], eagle_logits [:- 2 , :, :])
10851083 else :
10861084 token_loss = self .compute_language_model_loss (labels [:, 1 :], eagle_logits [:- 1 , :, :])
10871085
@@ -1281,20 +1279,26 @@ def forward(
         # all eagle weights have been exercised for quantization calibration purpose.
         if labels is None:
             return logits_sbh.transpose(0, 1).contiguous()
+        elif labels.shape[1] == input_ids.shape[1] - 1:
+            # For offline training, labels may be 1 token shorter than input_ids.
+            # We will just pad a 0 to the labels to make the seq_len the same as
+            # input_ids. This will introduce a small error in training if logit_distillation
+            # is False, and testing accuracy is wrong for the last token.
+            right_token_pad = torch.zeros(
+                (labels.shape[0], 1),
+                dtype=labels.dtype,
+                device=labels.device,
+            )
+            labels = torch.cat((labels, right_token_pad), dim=-1)

         # If eagle_freeze_base_model is set to True,
         # the base model is frozen.
-        if self.eagle_offline:
-            loss = torch.zeros(input_ids.shape).to(input_ids.device)
-        else:
-            loss = self.compute_language_model_loss(labels, logits_sbh)
+        loss = self.compute_language_model_loss(labels, logits_sbh)
         loss = 0.0 * loss

         if self.eagle_config.parallel_draft_step > 1:
             for i in range(self.eagle_config.parallel_draft_step):
-                eagle_logits = eagle_logits_0[
-                    i * logits_sbh.shape[0] : (i + 1) * logits_sbh.shape[0]
-                ]
+                eagle_logits = eagle_logits_0[i * labels.shape[1] : (i + 1) * labels.shape[1]]
                 loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits)
                 loss_ = loss_[:, i:]
                 loss[:, i + 1 :] += 1.0 * loss_
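After the padding above, labels.shape[1] equals the full input sequence length, which is what the new per-draft-step slicing relies on: eagle_logits_0 is assumed to stack parallel_draft_step blocks of seq_len logits along dim 0, and each block i is cut out using labels.shape[1] as the block length. A hypothetical shape sketch with toy tensors (names and sizes are illustrative only):

    import torch

    b, s, v, parallel_draft_step = 2, 8, 16, 3
    labels = torch.randint(0, v, (b, s))                         # [b, s], already padded
    eagle_logits_0 = torch.randn(parallel_draft_step * s, b, v)  # stacked draft blocks

    for i in range(parallel_draft_step):
        # One [s, b, v] block of draft logits per parallel draft step.
        eagle_logits = eagle_logits_0[i * labels.shape[1] : (i + 1) * labels.shape[1]]
        assert eagle_logits.shape == (s, b, v)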
@@ -1307,7 +1311,7 @@ def forward(
             acc = []
             with torch.no_grad():
                 gathered_logits = gather_from_tensor_model_parallel_region(
-                    eagle_logits_0[:-2, :, :] if self.eagle_offline else eagle_logits_0[:-1, :, :]
+                    eagle_logits_0[:-1, :, :]
                 )
                 eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
                 if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
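When the draft vocabulary is smaller than the full vocabulary, the argmax above yields draft-vocab ids that must be mapped back before comparing against the base model's predictions. A hypothetical sketch, assuming d2t (referenced in _compute_eagle_loss above) stores per-draft-token offsets into the full vocabulary; the tensors and sizes here are placeholders, not the module's actual state:

    import torch

    draft_vocab, vocab = 8, 32
    d2t = torch.randint(0, vocab - draft_vocab, (draft_vocab,))  # placeholder offsets
    eagle_top1 = torch.tensor([[1, 5, 2], [0, 7, 3]])            # [b, s - 1] draft-vocab ids
    eagle_top1_full = eagle_top1 + d2t[eagle_top1]               # mapped to full-vocab ids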
@@ -1337,7 +1341,7 @@ def forward(
                 packed_seq_params=packed_seq_params,
                 **(extra_block_kwargs or {}),
             )
-            eagle_logits_1 = eagle_logits_2x[logits_sbh.shape[0] :, :, :]
+            eagle_logits_1 = eagle_logits_2x[-labels.shape[1] :, :, :]

             loss_1 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_1)
             # [b, s - 2]
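eagle_logits_2x is assumed to hold the first- and second-pass EAGLE logits stacked along the sequence dimension, so taking the trailing labels.shape[1] rows selects the second-pass block (the same pattern applies to the 3x and 4x tensors in the later hunks). A hypothetical shape sketch:

    import torch

    b, s, v = 2, 8, 16
    labels = torch.randint(0, v, (b, s))
    eagle_logits_2x = torch.randn(2 * s, b, v)                  # two stacked [s, b, v] blocks
    eagle_logits_1 = eagle_logits_2x[-labels.shape[1] :, :, :]  # trailing block = 2nd pass
    assert eagle_logits_1.shape == (s, b, v)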
@@ -1348,7 +1352,7 @@ def forward(
             acc = []
             with torch.no_grad():
                 gathered_logits = gather_from_tensor_model_parallel_region(
-                    eagle_logits_1[1:-2, :, :] if self.eagle_offline else eagle_logits_1[1:-1, :, :]
+                    eagle_logits_1[1:-1, :, :]
                 )
                 eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
                 if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
@@ -1379,7 +1383,7 @@ def forward(
                 **(extra_block_kwargs or {}),
             )

-            eagle_logits_2 = eagle_logits_3x[-logits_sbh.shape[0] :, :, :]
+            eagle_logits_2 = eagle_logits_3x[-labels.shape[1] :, :, :]

             loss_2 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_2)
             # [b, s - 3]
@@ -1390,7 +1394,7 @@ def forward(
             acc = []
             with torch.no_grad():
                 gathered_logits = gather_from_tensor_model_parallel_region(
-                    eagle_logits_2[2:-2, :, :] if self.eagle_offline else eagle_logits_2[2:-1, :, :]
+                    eagle_logits_2[2:-1, :, :]
                 )
                 eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
                 if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
@@ -1421,7 +1425,7 @@ def forward(
                 **(extra_block_kwargs or {}),
             )

-            eagle_logits_3 = eagle_logits_4x[-logits_sbh.shape[0] :, :, :]
+            eagle_logits_3 = eagle_logits_4x[-labels.shape[1] :, :, :]

             loss_3 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_3)
             # [b, s - 4]
@@ -1432,7 +1436,7 @@ def forward(
             acc = []
             with torch.no_grad():
                 gathered_logits = gather_from_tensor_model_parallel_region(
-                    eagle_logits_3[3:-2, :, :] if self.eagle_offline else eagle_logits_3[3:-1, :, :]
+                    eagle_logits_3[3:-1, :, :]
                 )
                 eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
                 if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: