@@ -1066,7 +1066,7 @@ def forward(
10661066 ** (extra_block_kwargs or {}),
10671067 )
10681068
1069- eagle_logits_0 .append (eagle_logits_ [ - input_ids . shape [ 1 ] :] )
1069+ eagle_logits_0 .append (eagle_logits_ )
10701070 eagle_logits_0 = torch .cat (eagle_logits_0 , dim = 0 )
10711071
10721072 # If labels are not provided, return the original logits. We only return after
@@ -1137,7 +1137,7 @@ def forward(
11371137 inference_context = eagle_inference_context ,
11381138 ** (extra_block_kwargs or {}),
11391139 )
1140- eagle_logits_1 .append (eagle_logits_ [ - input_ids . shape [ 1 ] :] )
1140+ eagle_logits_1 .append (eagle_logits_ )
11411141 eagle_logits_1 = torch .cat (eagle_logits_1 , dim = 0 )
11421142
11431143 for i in range (self .eagle_config .parallel_draft_step ):
@@ -1187,7 +1187,7 @@ def forward(
11871187 inference_context = eagle_inference_context ,
11881188 ** (extra_block_kwargs or {}),
11891189 )
1190- eagle_logits_2 .append (eagle_logits_ [ - input_ids . shape [ 1 ] :] )
1190+ eagle_logits_2 .append (eagle_logits_ )
11911191 eagle_logits_2 = torch .cat (eagle_logits_2 , dim = 0 )
11921192
11931193 for i in range (self .eagle_config .parallel_draft_step ):
@@ -1237,7 +1237,7 @@ def forward(
12371237 inference_context = eagle_inference_context ,
12381238 ** (extra_block_kwargs or {}),
12391239 )
1240- eagle_logits_3 .append (eagle_logits_ [ - input_ids . shape [ 1 ] :] )
1240+ eagle_logits_3 .append (eagle_logits_ )
12411241 eagle_logits_3 = torch .cat (eagle_logits_3 , dim = 0 )
12421242
12431243 for i in range (self .eagle_config .parallel_draft_step ):
0 commit comments