@@ -1066,7 +1066,7 @@ def forward(
1066
1066
** (extra_block_kwargs or {}),
1067
1067
)
1068
1068
1069
- eagle_logits_0 .append (eagle_logits_ [ - input_ids . shape [ 1 ] :] )
1069
+ eagle_logits_0 .append (eagle_logits_ )
1070
1070
eagle_logits_0 = torch .cat (eagle_logits_0 , dim = 0 )
1071
1071
1072
1072
# If labels are not provided, return the original logits. We only return after
@@ -1137,7 +1137,7 @@ def forward(
1137
1137
inference_context = eagle_inference_context ,
1138
1138
** (extra_block_kwargs or {}),
1139
1139
)
1140
- eagle_logits_1 .append (eagle_logits_ [ - input_ids . shape [ 1 ] :] )
1140
+ eagle_logits_1 .append (eagle_logits_ )
1141
1141
eagle_logits_1 = torch .cat (eagle_logits_1 , dim = 0 )
1142
1142
1143
1143
for i in range (self .eagle_config .parallel_draft_step ):
@@ -1187,7 +1187,7 @@ def forward(
1187
1187
inference_context = eagle_inference_context ,
1188
1188
** (extra_block_kwargs or {}),
1189
1189
)
1190
- eagle_logits_2 .append (eagle_logits_ [ - input_ids . shape [ 1 ] :] )
1190
+ eagle_logits_2 .append (eagle_logits_ )
1191
1191
eagle_logits_2 = torch .cat (eagle_logits_2 , dim = 0 )
1192
1192
1193
1193
for i in range (self .eagle_config .parallel_draft_step ):
@@ -1237,7 +1237,7 @@ def forward(
1237
1237
inference_context = eagle_inference_context ,
1238
1238
** (extra_block_kwargs or {}),
1239
1239
)
1240
- eagle_logits_3 .append (eagle_logits_ [ - input_ids . shape [ 1 ] :] )
1240
+ eagle_logits_3 .append (eagle_logits_ )
1241
1241
eagle_logits_3 = torch .cat (eagle_logits_3 , dim = 0 )
1242
1242
1243
1243
for i in range (self .eagle_config .parallel_draft_step ):
0 commit comments