@@ -1073,15 +1073,13 @@ def _compute_eagle_loss(self, logits, labels, eagle_logits):
         """Compute the total loss for EAGLE.

         logits: [s, b, vocab // TP]
-        labels: [b, s] or [b, s-1] for offline mode
+        labels: [b, s]
         eagle_logits: [s, b, vocab // TP]
         """
         # Compute lm loss (classification loss) or KLDivergence
         if self.eagle_self_logit_distillation:
             mapping = self.eagle_module.d2t if hasattr(self.eagle_module, "d2t") else None
             token_loss = self.kld(eagle_logits[:-1, :, :], logits[1:, :, :], mapping)
-        elif labels.shape[1] < eagle_logits.shape[0]:
-            token_loss = self.compute_language_model_loss(labels[:, 1:], eagle_logits[:-2, :, :])
         else:
             token_loss = self.compute_language_model_loss(labels[:, 1:], eagle_logits[:-1, :, :])

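With the offline-only `elif` branch gone, `_compute_eagle_loss` always sees `[b, s]` labels and pairs them with the draft logits via a one-token shift. A minimal sketch of that alignment with made-up shapes; Megatron's `compute_language_model_loss` additionally handles tensor parallelism and returns per-token losses, so this is only an illustration of the indexing:

```python
import torch

s, b, v = 8, 2, 32                      # hypothetical seq-len, batch, vocab
labels = torch.randint(0, v, (b, s))    # [b, s]
eagle_logits = torch.randn(s, b, v)     # [s, b, v]

# The draft head at position t predicts token t+1, so the first label
# and the last logit have no partner and are dropped.
shifted_labels = labels[:, 1:]          # [b, s-1]
shifted_logits = eagle_logits[:-1]      # [s-1, b, v]

loss = torch.nn.functional.cross_entropy(
    shifted_logits.transpose(0, 1).reshape(-1, v),
    shifted_labels.reshape(-1),
)
```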
@@ -1281,20 +1279,26 @@ def forward(
         # all eagle weights have been exercised for quantization calibration purpose.
         if labels is None:
             return logits_sbh.transpose(0, 1).contiguous()
+        elif labels.shape[1] == input_ids.shape[1] - 1:
+            # For offline training, labels may be 1 token shorter than input_ids.
+            # We will just pad a 0 to the labels to make the seq_len the same as
+            # input_ids. This will introduce a small error in training if logit_distillation
+            # is False, and testing accuracy is wrong for the last token.
+            right_token_pad = torch.zeros(
+                (labels.shape[0], 1),
+                dtype=labels.dtype,
+                device=labels.device,
+            )
+            labels = torch.cat((labels, right_token_pad), dim=-1)

         # If eagle_freeze_base_model is set to True,
         # the base model is frozen.
-        if self.eagle_offline:
-            loss = torch.zeros(input_ids.shape).to(input_ids.device)
-        else:
-            loss = self.compute_language_model_loss(labels, logits_sbh)
+        loss = self.compute_language_model_loss(labels, logits_sbh)
         loss = 0.0 * loss

         if self.eagle_config.parallel_draft_step > 1:
             for i in range(self.eagle_config.parallel_draft_step):
-                eagle_logits = eagle_logits_0[
-                    i * logits_sbh.shape[0] : (i + 1) * logits_sbh.shape[0]
-                ]
+                eagle_logits = eagle_logits_0[i * labels.shape[1] : (i + 1) * labels.shape[1]]
                 loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits)
                 loss_ = loss_[:, i:]
                 loss[:, i + 1 :] += 1.0 * loss_
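Rather than special-casing `self.eagle_offline` throughout `forward`, the patch normalizes the label length once on entry. A standalone sketch of that padding step, under the same assumption the diff states (offline labels arrive one token shorter than `input_ids`):

```python
import torch

b, s = 2, 10                                 # hypothetical batch and seq-len
input_ids = torch.randint(0, 100, (b, s))
labels = torch.randint(0, 100, (b, s - 1))   # offline data: one token short

if labels.shape[1] == input_ids.shape[1] - 1:
    # Right-pad a single 0 so labels and input_ids share a seq_len.
    # The padded position is not a real target, which is why the diff
    # comment warns about the last-token loss and accuracy.
    pad = torch.zeros((labels.shape[0], 1), dtype=labels.dtype, device=labels.device)
    labels = torch.cat((labels, pad), dim=-1)

assert labels.shape == input_ids.shape
```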
@@ -1307,7 +1311,7 @@ def forward(
             acc = []
             with torch.no_grad():
                 gathered_logits = gather_from_tensor_model_parallel_region(
-                    eagle_logits_0[:-2, :, :] if self.eagle_offline else eagle_logits_0[:-1, :, :]
+                    eagle_logits_0[:-1, :, :]
                 )
                 eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
                 if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
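The accuracy probe likewise loses its offline slice. Conceptually it measures top-1 agreement between the draft logits and the shifted labels; the hedged sketch below omits the Megatron-specific `gather_from_tensor_model_parallel_region` and the `d2t` draft-vocab remapping handled in the branch that follows:

```python
import torch

s, b, v = 8, 2, 32                      # hypothetical
eagle_logits_0 = torch.randn(s, b, v)
labels = torch.randint(0, v, (b, s))

with torch.no_grad():
    # Drop the final position (it would predict past the sequence end),
    # take the draft's top-1 token, and compare against shifted labels.
    eagle_top1 = eagle_logits_0[:-1].transpose(0, 1).argmax(dim=-1)  # [b, s-1]
    top1_acc = (eagle_top1 == labels[:, 1:]).float().mean()
```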
@@ -1337,7 +1341,7 @@ def forward(
             packed_seq_params=packed_seq_params,
             **(extra_block_kwargs or {}),
         )
-        eagle_logits_1 = eagle_logits_2x[logits_sbh.shape[0] :, :, :]
+        eagle_logits_1 = eagle_logits_2x[-labels.shape[1] :, :, :]

         loss_1 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_1)
         # [b, s - 2]
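The same re-keying recurs in the later hunks for `eagle_logits_3x` and `eagle_logits_4x`: the trailing slice is taken by label length instead of `logits_sbh.shape[0]`. After the padding above the two lengths coincide, so the result should be unchanged; a toy illustration with made-up sizes:

```python
import torch

s, b, v = 8, 2, 32                           # hypothetical
eagle_logits_2x = torch.randn(2 * s, b, v)   # two passes concatenated on dim 0
labels = torch.randint(0, v, (b, s))         # already padded to [b, s]

# Take the trailing labels.shape[1] rows: the second-pass logits.
eagle_logits_1 = eagle_logits_2x[-labels.shape[1]:]
assert eagle_logits_1.shape == (s, b, v)
```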
@@ -1348,7 +1352,7 @@ def forward(
             acc = []
             with torch.no_grad():
                 gathered_logits = gather_from_tensor_model_parallel_region(
-                    eagle_logits_1[1:-2, :, :] if self.eagle_offline else eagle_logits_1[1:-1, :, :]
+                    eagle_logits_1[1:-1, :, :]
                 )
                 eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
                 if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
@@ -1379,7 +1383,7 @@ def forward(
             **(extra_block_kwargs or {}),
         )

-        eagle_logits_2 = eagle_logits_3x[-logits_sbh.shape[0] :, :, :]
+        eagle_logits_2 = eagle_logits_3x[-labels.shape[1] :, :, :]

         loss_2 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_2)
         # [b, s - 3]
@@ -1390,7 +1394,7 @@ def forward(
             acc = []
             with torch.no_grad():
                 gathered_logits = gather_from_tensor_model_parallel_region(
-                    eagle_logits_2[2:-2, :, :] if self.eagle_offline else eagle_logits_2[2:-1, :, :]
+                    eagle_logits_2[2:-1, :, :]
                 )
                 eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
                 if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
@@ -1421,7 +1425,7 @@ def forward(
             **(extra_block_kwargs or {}),
         )

-        eagle_logits_3 = eagle_logits_4x[-logits_sbh.shape[0] :, :, :]
+        eagle_logits_3 = eagle_logits_4x[-labels.shape[1] :, :, :]

         loss_3 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_3)
         # [b, s - 4]
@@ -1432,7 +1436,7 @@ def forward(
             acc = []
             with torch.no_grad():
                 gathered_logits = gather_from_tensor_model_parallel_region(
-                    eagle_logits_3[3:-2, :, :] if self.eagle_offline else eagle_logits_3[3:-1, :, :]
+                    eagle_logits_3[3:-1, :, :]
                 )
                 eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
                 if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: