Commit 0582079

Revert "pad labels if it's 1 token shorter than input_ids"

This reverts commit 9450e0d.
Signed-off-by: Ye Yu <[email protected]>

1 parent: 912e427

1 file changed (+17, -21 lines)
modelopt/torch/speculative/plugins/megatron_eagle.py

@@ -1073,13 +1073,15 @@ def _compute_eagle_loss(self, logits, labels, eagle_logits):
         """Compute the total loss for EAGLE.
 
         logits: [s, b, vocab // TP]
-        labels: [b, s]
+        labels: [b, s] or [b, s-1] for offline mode
         eagle_logits: [s, b, vocab // TP]
         """
         # Compute lm loss (classification loss) or KLDivergence
         if self.eagle_self_logit_distillation:
             mapping = self.eagle_module.d2t if hasattr(self.eagle_module, "d2t") else None
             token_loss = self.kld(eagle_logits[:-1, :, :], logits[1:, :, :], mapping)
+        elif labels.shape[1] < eagle_logits.shape[0]:
+            token_loss = self.compute_language_model_loss(labels[:, 1:], eagle_logits[:-2, :, :])
         else:
             token_loss = self.compute_language_model_loss(labels[:, 1:], eagle_logits[:-1, :, :])
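
Note on the new elif branch: in offline mode labels can arrive one token shorter than the model's sequence, so the draft logits are trimmed by two trailing positions instead of one, keeping both operands on the same s - 2 target tokens. A minimal shape sketch; the sizes and the helper shifted_pair below are invented for illustration and are not part of the module:

import torch

# Illustrative sizes only: s = sequence length, b = batch, v = (sharded) vocab.
s, b, v = 8, 2, 32
eagle_logits = torch.randn(s, b, v)              # [s, b, vocab // TP]
labels_online = torch.randint(0, v, (b, s))      # [b, s]
labels_offline = torch.randint(0, v, (b, s - 1)) # [b, s - 1] in offline mode

def shifted_pair(labels, eagle_logits):
    """Mirror the branch selection above (sketch only, hypothetical helper)."""
    if labels.shape[1] < eagle_logits.shape[0]:
        # offline: drop two trailing logit positions so both sides cover s - 2 targets
        return labels[:, 1:], eagle_logits[:-2, :, :]
    # online: the usual one-token shift
    return labels[:, 1:], eagle_logits[:-1, :, :]

lbl, lgt = shifted_pair(labels_offline, eagle_logits)
assert lbl.shape[1] == lgt.shape[0]  # both s - 2 in the offline case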

@@ -1279,26 +1281,20 @@ def forward(
         # all eagle weights have been exercised for quantization calibration purpose.
         if labels is None:
             return logits_sbh.transpose(0, 1).contiguous()
-        elif labels.shape[1] == input_ids.shape[1] - 1:
-            # For offline training, labels may be 1 token shorter than input_ids.
-            # We will just pad a 0 to the labels to make the seq_len the same as
-            # input_ids. This will introduce a small error in training if logit_distillation
-            # is False, and testing accuracy is wrong for the last token.
-            right_token_pad = torch.zeros(
-                (labels.shape[0], 1),
-                dtype=labels.dtype,
-                device=labels.device,
-            )
-            labels = torch.cat((labels, right_token_pad), dim=-1)
 
         # If eagle_freeze_base_model is set to True,
         # the base model is frozen .
-        loss = self.compute_language_model_loss(labels, logits_sbh)
+        if self.eagle_offline:
+            loss = torch.zeros(input_ids.shape).to(input_ids.device)
+        else:
+            loss = self.compute_language_model_loss(labels, logits_sbh)
         loss = 0.0 * loss
 
         if self.eagle_config.parallel_draft_step > 1:
             for i in range(self.eagle_config.parallel_draft_step):
-                eagle_logits = eagle_logits_0[i * labels.shape[1] : (i + 1) * labels.shape[1]]
+                eagle_logits = eagle_logits_0[
+                    i * logits_sbh.shape[0] : (i + 1) * logits_sbh.shape[0]
+                ]
                 loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits)
                 loss_ = loss_[:, i:]
                 loss[:, i + 1 :] += 1.0 * loss_
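
Note on the parallel-draft slicing: striding by logits_sbh.shape[0] rather than labels.shape[1] keeps every per-step slice equal to the full model sequence length even when offline labels are one token short, and the eagle_offline branch above replaces the frozen base-model loss with a zero placeholder of the input shape. A rough stand-alone sketch of the slicing with dummy tensors; all sizes are invented:

import torch

s, b, v = 8, 2, 32        # invented sizes
parallel_draft_step = 3
logits_sbh = torch.randn(s, b, v)
# eagle_logits_0 stacks one [s, b, v] block per draft step along dim 0
eagle_logits_0 = torch.randn(parallel_draft_step * s, b, v)

for i in range(parallel_draft_step):
    # stride by the model's sequence length, not by labels.shape[1],
    # which can be s - 1 in offline mode
    eagle_logits = eagle_logits_0[i * logits_sbh.shape[0] : (i + 1) * logits_sbh.shape[0]]
    assert eagle_logits.shape[0] == s
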
@@ -1311,7 +1307,7 @@ def forward(
             acc = []
             with torch.no_grad():
                 gathered_logits = gather_from_tensor_model_parallel_region(
-                    eagle_logits_0[:-1, :, :]
+                    eagle_logits_0[:-2, :, :] if self.eagle_offline else eagle_logits_0[:-1, :, :]
                 )
                 eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
                 if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
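
Note on the accuracy blocks: each one trims one extra trailing position in offline mode before the top-1 comparison. A plain-torch sketch that skips the tensor-parallel gather (effectively a no-op at TP=1); sizes are invented:

import torch

s, b, v = 8, 2, 32       # invented sizes
eagle_offline = True
eagle_logits_0 = torch.randn(s, b, v)

# gather_from_tensor_model_parallel_region needs a Megatron TP group; at TP=1 it
# effectively returns its input, so the sketch applies the slice directly.
gathered_logits = eagle_logits_0[:-2, :, :] if eagle_offline else eagle_logits_0[:-1, :, :]
eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)  # [b, s-2] offline, [b, s-1] online
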
@@ -1341,7 +1337,7 @@ def forward(
                 packed_seq_params=packed_seq_params,
                 **(extra_block_kwargs or {}),
             )
-            eagle_logits_1 = eagle_logits_2x[-labels.shape[1] :, :, :]
+            eagle_logits_1 = eagle_logits_2x[logits_sbh.shape[0] :, :, :]
 
             loss_1 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_1)
             # [b, s - 2]
@@ -1352,7 +1348,7 @@ def forward(
             acc = []
             with torch.no_grad():
                 gathered_logits = gather_from_tensor_model_parallel_region(
-                    eagle_logits_1[1:-1, :, :]
+                    eagle_logits_1[1:-2, :, :] if self.eagle_offline else eagle_logits_1[1:-1, :, :]
                 )
                 eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
                 if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
@@ -1383,7 +1379,7 @@ def forward(
                 **(extra_block_kwargs or {}),
             )
 
-            eagle_logits_2 = eagle_logits_3x[-labels.shape[1] :, :, :]
+            eagle_logits_2 = eagle_logits_3x[-logits_sbh.shape[0] :, :, :]
 
             loss_2 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_2)
             # [b, s - 3]
@@ -1394,7 +1390,7 @@ def forward(
             acc = []
             with torch.no_grad():
                 gathered_logits = gather_from_tensor_model_parallel_region(
-                    eagle_logits_2[2:-1, :, :]
+                    eagle_logits_2[2:-2, :, :] if self.eagle_offline else eagle_logits_2[2:-1, :, :]
                 )
                 eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
                 if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
@@ -1425,7 +1421,7 @@ def forward(
                 **(extra_block_kwargs or {}),
             )
 
-            eagle_logits_3 = eagle_logits_4x[-labels.shape[1] :, :, :]
+            eagle_logits_3 = eagle_logits_4x[-logits_sbh.shape[0] :, :, :]
 
             loss_3 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_3)
             # [b, s - 4]
@@ -1436,7 +1432,7 @@ def forward(
             acc = []
             with torch.no_grad():
                 gathered_logits = gather_from_tensor_model_parallel_region(
-                    eagle_logits_3[3:-1, :, :]
+                    eagle_logits_3[3:-2, :, :] if self.eagle_offline else eagle_logits_3[3:-1, :, :]
                 )
                 eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
                 if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
