Skip to content

Commit c839006

Browse files
committed
remove redundant index as now eagle_logits_ has a length of seq-len
Signed-off-by: Ye Yu <[email protected]>
1 parent b0edb38 commit c839006

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

modelopt/torch/speculative/plugins/megatron_eagle.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1066,7 +1066,7 @@ def forward(
10661066
**(extra_block_kwargs or {}),
10671067
)
10681068

1069-
eagle_logits_0.append(eagle_logits_[-input_ids.shape[1] :])
1069+
eagle_logits_0.append(eagle_logits_)
10701070
eagle_logits_0 = torch.cat(eagle_logits_0, dim=0)
10711071

10721072
# If labels are not provided, return the original logits. We only return after
@@ -1137,7 +1137,7 @@ def forward(
11371137
inference_context=eagle_inference_context,
11381138
**(extra_block_kwargs or {}),
11391139
)
1140-
eagle_logits_1.append(eagle_logits_[-input_ids.shape[1] :])
1140+
eagle_logits_1.append(eagle_logits_)
11411141
eagle_logits_1 = torch.cat(eagle_logits_1, dim=0)
11421142

11431143
for i in range(self.eagle_config.parallel_draft_step):
@@ -1187,7 +1187,7 @@ def forward(
11871187
inference_context=eagle_inference_context,
11881188
**(extra_block_kwargs or {}),
11891189
)
1190-
eagle_logits_2.append(eagle_logits_[-input_ids.shape[1] :])
1190+
eagle_logits_2.append(eagle_logits_)
11911191
eagle_logits_2 = torch.cat(eagle_logits_2, dim=0)
11921192

11931193
for i in range(self.eagle_config.parallel_draft_step):
@@ -1237,7 +1237,7 @@ def forward(
12371237
inference_context=eagle_inference_context,
12381238
**(extra_block_kwargs or {}),
12391239
)
1240-
eagle_logits_3.append(eagle_logits_[-input_ids.shape[1] :])
1240+
eagle_logits_3.append(eagle_logits_)
12411241
eagle_logits_3 = torch.cat(eagle_logits_3, dim=0)
12421242

12431243
for i in range(self.eagle_config.parallel_draft_step):

0 commit comments

Comments
 (0)