Skip to content

Commit fc4551a

Browse files
committed
only compute acc on tokens whose label is not IGNORE_TOKEN_ID
Signed-off-by: Ye Yu <[email protected]>
1 parent ab472f0 commit fc4551a

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

modelopt/torch/speculative/plugins/megatron_eagle.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,13 @@
6868
except ImportError:
6969
warnings.warn("Fail to import megatron.core.post_training! EAGLE feature will be disable!")
7070

71+
try:
72+
from transformers.trainer_pt_utils import LabelSmoother
73+
74+
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
75+
except ImportError:
76+
IGNORE_TOKEN_ID = -100
77+
7178

7279
def dict_to_config(
7380
architecture_config,
@@ -1145,7 +1152,7 @@ def forward(
11451152
eagle_top1 += self.eagle_module.d2t[eagle_top1]
11461153
top1_p = (
11471154
torch.eq(labels[:, i + ttt_step + 1 :], eagle_top1).sum()
1148-
/ eagle_top1.numel()
1155+
/ (labels != IGNORE_TOKEN_ID).sum().item()
11491156
)
11501157
acc.append(top1_p)
11511158

0 commit comments

Comments
 (0)