@@ -1073,13 +1073,15 @@ def _compute_eagle_loss(self, logits, labels, eagle_logits):
         """Compute the total loss for EAGLE.

         logits: [s, b, vocab // TP]
-        labels: [b, s]
+        labels: [b, s] or [b, s-1] for offline mode
         eagle_logits: [s, b, vocab // TP]
         """
         # Compute lm loss (classification loss) or KLDivergence
         if self.eagle_self_logit_distillation:
             mapping = self.eagle_module.d2t if hasattr(self.eagle_module, "d2t") else None
             token_loss = self.kld(eagle_logits[:-1, :, :], logits[1:, :, :], mapping)
+        elif labels.shape[1] < eagle_logits.shape[0]:
+            token_loss = self.compute_language_model_loss(labels[:, 1:], eagle_logits[:-2, :, :])
         else:
             token_loss = self.compute_language_model_loss(labels[:, 1:], eagle_logits[:-1, :, :])
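To make the new `elif` branch's shape bookkeeping concrete: in offline mode labels arrive as `[b, s - 1]`, so `labels[:, 1:]` has sequence length `s - 2` and `eagle_logits` (length `s`) must drop its last two positions to align. A minimal sketch with dummy tensors, using plain cross-entropy as a stand-in for `compute_language_model_loss` (which in Megatron returns a per-token loss rather than a scalar):

```python
import torch
import torch.nn.functional as F

s, b, vocab = 8, 2, 32
eagle_logits = torch.randn(s, b, vocab)               # [s, b, vocab]
labels_offline = torch.randint(0, vocab, (b, s - 1))  # offline labels: [b, s - 1]

# labels[:, 1:] -> [b, s - 2]; eagle_logits[:-2] -> [s - 2, b, vocab]: aligned.
trimmed = eagle_logits[:-2, :, :].transpose(0, 1)     # [b, s - 2, vocab]
loss = F.cross_entropy(trimmed.reshape(-1, vocab), labels_offline[:, 1:].reshape(-1))
```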
@@ -1279,26 +1281,20 @@ def forward(
         # all eagle weights have been exercised for quantization calibration purpose.
         if labels is None:
             return logits_sbh.transpose(0, 1).contiguous()
-        elif labels.shape[1] == input_ids.shape[1] - 1:
-            # For offline training, labels may be 1 token shorter than input_ids.
-            # We will just pad a 0 to the labels to make the seq_len the same as
-            # input_ids. This will introduce a small error in training if logit_distillation
-            # is False, and testing accuracy is wrong for the last token.
-            right_token_pad = torch.zeros(
-                (labels.shape[0], 1),
-                dtype=labels.dtype,
-                device=labels.device,
-            )
-            labels = torch.cat((labels, right_token_pad), dim=-1)

         # If eagle_freeze_base_model is set to True,
         # the base model is frozen.
-        loss = self.compute_language_model_loss(labels, logits_sbh)
+        if self.eagle_offline:
+            loss = torch.zeros(input_ids.shape).to(input_ids.device)
+        else:
+            loss = self.compute_language_model_loss(labels, logits_sbh)
         loss = 0.0 * loss

         if self.eagle_config.parallel_draft_step > 1:
             for i in range(self.eagle_config.parallel_draft_step):
-                eagle_logits = eagle_logits_0[i * labels.shape[1] : (i + 1) * labels.shape[1]]
+                eagle_logits = eagle_logits_0[
+                    i * logits_sbh.shape[0] : (i + 1) * logits_sbh.shape[0]
+                ]
                 loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits)
                 loss_ = loss_[:, i:]
                 loss[:, i + 1 :] += 1.0 * loss_
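The base-model loss term here only serves as a correctly shaped accumulator: it is scaled by `0.0`, and each draft step writes its per-token loss in at an increasing offset. A minimal sketch of that accumulation pattern (assumed shapes, not the module itself):

```python
import torch

b, s = 2, 8
loss = torch.zeros((b, s))           # offline path builds this placeholder directly
loss = 0.0 * loss                    # online path zeroes the real LM loss the same way

for i in range(3):                   # e.g. parallel_draft_step == 3
    loss_ = torch.ones((b, s - 1))   # stand-in for _compute_eagle_loss output, [b, s - 1]
    loss_ = loss_[:, i:]             # drop the first i positions
    loss[:, i + 1 :] += 1.0 * loss_  # step i only supervises tokens i + 1 onward
```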
@@ -1311,7 +1307,7 @@ def forward(
             acc = []
             with torch.no_grad():
                 gathered_logits = gather_from_tensor_model_parallel_region(
-                    eagle_logits_0[:-1, :, :]
+                    eagle_logits_0[:-2, :, :] if self.eagle_offline else eagle_logits_0[:-1, :, :]
                 )
                 eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
                 if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
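The accuracy window follows the same offline convention as the loss: with labels one token shorter, one extra trailing position is dropped before the top-1 comparison. A sketch of the slice selection alone (dummy shapes; `eagle_offline` is a plain flag here):

```python
import torch

s, b, vocab = 8, 2, 32
eagle_logits_0 = torch.randn(s, b, vocab)
eagle_offline = True

window = eagle_logits_0[:-2, :, :] if eagle_offline else eagle_logits_0[:-1, :, :]
eagle_top1 = window.transpose(0, 1).argmax(dim=-1)  # [b, s - 2] offline, [b, s - 1] online
```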
@@ -1341,7 +1337,7 @@ def forward(
                 packed_seq_params=packed_seq_params,
                 **(extra_block_kwargs or {}),
             )
-            eagle_logits_1 = eagle_logits_2x[-labels.shape[1] :, :, :]
+            eagle_logits_1 = eagle_logits_2x[logits_sbh.shape[0] :, :, :]

             loss_1 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_1)
             # [b, s - 2]
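The switch from `labels.shape[1]` to `logits_sbh.shape[0]` matters because `eagle_logits_2x` stacks two passes of sequence length `s` along the sequence dimension, and offline labels are shorter than `s`. A sketch of the selection (dummy shapes):

```python
import torch

s, b, vocab = 8, 2, 32
logits_sbh = torch.randn(s, b, vocab)
eagle_logits_2x = torch.randn(2 * s, b, vocab)       # two passes stacked along dim 0

# Take the second pass; with total length 2 * s, [s:] and [-s:] are equivalent.
eagle_logits_1 = eagle_logits_2x[logits_sbh.shape[0] :, :, :]  # [s, b, vocab]
```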
@@ -1352,7 +1348,7 @@ def forward(
             acc = []
             with torch.no_grad():
                 gathered_logits = gather_from_tensor_model_parallel_region(
-                    eagle_logits_1[1:-1, :, :]
+                    eagle_logits_1[1:-2, :, :] if self.eagle_offline else eagle_logits_1[1:-1, :, :]
                 )
                 eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
                 if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
@@ -1383,7 +1379,7 @@ def forward(
                 **(extra_block_kwargs or {}),
             )

-            eagle_logits_2 = eagle_logits_3x[-labels.shape[1] :, :, :]
+            eagle_logits_2 = eagle_logits_3x[-logits_sbh.shape[0] :, :, :]

             loss_2 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_2)
             # [b, s - 3]
@@ -1394,7 +1390,7 @@ def forward(
             acc = []
             with torch.no_grad():
                 gathered_logits = gather_from_tensor_model_parallel_region(
-                    eagle_logits_2[2:-1, :, :]
+                    eagle_logits_2[2:-2, :, :] if self.eagle_offline else eagle_logits_2[2:-1, :, :]
                 )
                 eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
                 if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
@@ -1425,7 +1421,7 @@ def forward(
                 **(extra_block_kwargs or {}),
             )

-            eagle_logits_3 = eagle_logits_4x[-labels.shape[1] :, :, :]
+            eagle_logits_3 = eagle_logits_4x[-logits_sbh.shape[0] :, :, :]

             loss_3 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_3)
             # [b, s - 4]
@@ -1436,7 +1432,7 @@ def forward(
             acc = []
             with torch.no_grad():
                 gathered_logits = gather_from_tensor_model_parallel_region(
-                    eagle_logits_3[3:-1, :, :]
+                    eagle_logits_3[3:-2, :, :] if self.eagle_offline else eagle_logits_3[3:-1, :, :]
                 )
                 eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
                 if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
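The four accuracy hunks all apply one rule: at draft step k, the window is `logits[k:-2]` offline and `logits[k:-1]` online. A hypothetical helper capturing that pattern (not in the source, which keeps the four slices inline):

```python
import torch

def acc_window(eagle_logits_k: torch.Tensor, k: int, eagle_offline: bool) -> torch.Tensor:
    """Hypothetical helper mirroring the repeated accuracy slice in the diff."""
    return eagle_logits_k[k:-2, :, :] if eagle_offline else eagle_logits_k[k:-1, :, :]
```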