We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent ab472f0 commit fc4551aCopy full SHA for fc4551a
modelopt/torch/speculative/plugins/megatron_eagle.py
@@ -68,6 +68,13 @@
68
except ImportError:
69
warnings.warn("Fail to import megatron.core.post_training! EAGLE feature will be disable!")
70
71
+try:
72
+ from transformers.trainer_pt_utils import LabelSmoother
73
+
74
+ IGNORE_TOKEN_ID = LabelSmoother.ignore_index
75
+except ImportError:
76
+ IGNORE_TOKEN_ID = -100
77
78
79
def dict_to_config(
80
architecture_config,
@@ -1145,7 +1152,7 @@ def forward(
1145
1152
eagle_top1 += self.eagle_module.d2t[eagle_top1]
1146
1153
top1_p = (
1147
1154
torch.eq(labels[:, i + ttt_step + 1 :], eagle_top1).sum()
1148
- / eagle_top1.numel()
1155
+ / (labels != IGNORE_TOKEN_ID).sum().item()
1149
1156
)
1150
1157
acc.append(top1_p)
1151
1158
0 commit comments