@@ -195,48 +195,72 @@ def set_multi_step_attention_mask(attn_mask, step):
195
195
h0 h1 h2 h3 h4 h5 h6 h7 (base hidden_states)
196
196
l0 l1 l2 l3 l4 l5 l6 l7 (base labels)
197
197
198
- ttt_step=2
199
- parallel_draft_step=2
200
- ->step=3
201
-
202
- | i1 i2 i3 i4 i5 i6 i7 -- | m0 m0 m0 m0 m0 m0 m0 -- | i1 i2 i3 i4 i5 i6 i7 -- | m0 m0 m0 m0 m0 m0 m0 -- |
203
- (out) | h0 h1 h2 h3 h4 h5 h6 h7 | h0 h1 h2 h3 h4 h5 h6 h7 | -- -- G2 G3 G4 G5 G6 G7 | -- -- G2 G3 G4 G5 G6 G7 |
204
- =======================================================================================================================
205
- F1 l1 | i1 h0 | x | | | |
206
- F2 l2 | i2 h1 | x x | | | |
207
- F3 l3 | i3 h2 | x x x | | | |
208
- F4 l4 | i4 h3 | x x x x | | | |
209
- F5 l5 | i5 h4 | x x x x x | | | |
210
- F6 l6 | i6 h5 | x x x x x x | | | |
211
- F7 l7 | i7 h6 | x x x x x x x | | | |
212
- -- -- | -- h7 | o o o o o o o o | | | |
198
+ ttt_steps=2
199
+ parallel_draft_step=3
200
+
201
+
202
+ ttt_step=0
203
+ | i1 i2 i3 i4 i5 i6 i7 -- | m0 m0 m0 m0 m0 m0 m0 -- | m1 m1 m1 m1 m1 m1 m1 -- |
204
+ | h0 h1 h2 h3 h4 h5 h6 h7 | h0 h1 h2 h3 h4 h5 h6 h7 | h0 h1 h2 h3 h4 h5 h6 h7 |
205
+ =============================================================================================
206
+ F1 l1 | i1 h0 | x | | |
207
+ F2 l2 | i2 h1 | x x | | |
208
+ F3 l3 | i3 h2 | x x x | | |
209
+ F4 l4 | i4 h3 | x x x x | | |
210
+ F5 l5 | i5 h4 | x x x x x | | |
211
+ F6 l6 | i6 h5 | x x x x x x | | |
212
+ F7 l7 | i7 h6 | x x x x x x x | | |
213
+ -- -- | -- h7 | o o o o o o o o | | |
214
+ =============================================================================================
215
+ -- -- | m0 -- | | | |
216
+ G2 l2 | m0 h1 | x o | x | |
217
+ G3 l3 | m0 h2 | x x o | x | |
218
+ G4 l4 | m0 h3 | x x x o | x | |
219
+ G5 l5 | m0 h4 | x x x x o | x | |
220
+ G6 l6 | m0 h5 | x x x x x o | x | |
221
+ G7 l7 | m0 h6 | x x x x x x o | x | |
222
+ -- -- | -- h7 | | | |
223
+ =============================================================================================
224
+ -- -- | m1 -- | | | |
225
+ -- -- | m1 h1 | | | |
226
+ H3 l3 | m1 h2 | x o o | x o | x |
227
+ H4 l4 | m1 h3 | x x o o | x o | x |
228
+ H5 l5 | m1 h4 | x x x o o | x o | x |
229
+ H6 l6 | m1 h5 | x x x x o o | x o | x |
230
+ H7 l7 | m1 h6 | x x x x x o o | x o | x |
231
+ -- -- | -- h7 | | | |
232
+
233
+
234
+ ttt_step=1
235
+ | i1 i2 i3 i4 i5 i6 i7 -- | i1 i2 i3 i4 i5 i6 i7 -- | m0 m0 m0 m0 m0 m0 m0 -- | m1 m1 m1 m1 m1 m1 m1 -- |
236
+ | h0 h1 h2 h3 h4 h5 h6 h7 | -- F1 F2 F3 F4 F5 F6 F7 | -- F1 F2 F3 F4 F5 F6 F7 | -- F1 F2 F3 F4 F5 F6 F7 |
213
237
=======================================================================================================================
214
238
-- -- | i1 -- | | | | |
215
- G2 l2 | i2 h1 | x o | x | | |
216
- G3 l3 | i3 h2 | x x o | x | | |
217
- G4 l4 | i4 h3 | x x x o | x | | |
218
- G5 l5 | i5 h4 | x x x x o | x | | |
219
- G6 l6 | i6 h5 | x x x x x o | x | | |
220
- G7 l7 | i7 h6 | x x x x x x o | x | | |
221
- -- -- | -- h7 | | | | |
239
+ J2 l2 | i2 F1 | x o | x | | |
240
+ J3 l3 | i3 F2 | x x o | x | | |
241
+ J4 l4 | i4 F3 | x x x o | x | | |
242
+ J5 l5 | i5 F4 | x x x x o | x | | |
243
+ J6 l6 | i6 F5 | x x x x x o | x | | |
244
+ J7 l7 | i7 F6 | x x x x x x o | x | | |
245
+ -- -- | -- F7 | | | | |
222
246
=======================================================================================================================
223
- -- -- | i1 -- | | | | |
224
- -- -- | i2 -- | | | | |
225
- H3 l3 | i3 G2 | x o o | x o | x | |
226
- H4 l4 | i4 G3 | x x o o | x o | x | |
227
- H5 l5 | i5 G4 | x x x o o | x o | x | |
228
- H6 l6 | i6 G5 | x x x x o o | x o | x | |
229
- H7 l7 | i7 G6 | x x x x x o o | x o | x | |
230
- -- -- | -- G7 | | | | |
231
- =======================================================================================================================
232
- -- -- | m0 -- | | | | |
233
247
-- -- | m0 -- | | | | |
234
248
-- -- | m0 -- | | | | |
235
- K4 l4 | m0 G3 | x | x | x | x |
236
- K5 l5 | m0 G4 | x x | x | x | x |
237
- K6 l6 | m0 G5 | x x x | x | x | x |
238
- K7 l7 | m0 G6 | x x x x | x | x | x |
239
- -- -- | -- G7 | | | | |
249
+ K3 l3 | m0 F2 | x o o | x o | x | | |
250
+ K4 l4 | m0 F3 | x x o o | x o | x | |
251
+ K5 l5 | m0 F4 | x x x o o | x o | x | |
252
+ K6 l6 | m0 F5 | x x x x o o | x o | x | |
253
+ K7 l7 | m0 F6 | x x x x x o o | x o | x | |
254
+ -- -- | -- F7 | | | | |
255
+ =======================================================================================================================
256
+ -- -- | m1 -- | | | | |
257
+ -- -- | m1 -- | | | | |
258
+ -- -- | m1 -- | | | | |
259
+ N4 l4 | m1 F3 | x | x | x | x |
260
+ N5 l5 | m1 F4 | x x | x | x | x |
261
+ N6 l6 | m1 F5 | x x x | x | x | x |
262
+ N7 l7 | m1 F6 | x x x x | x | x | x |
263
+ -- -- | -- F7 | | | | |
240
264
=======================================================================================================================
241
265
""" # noqa: E501
242
266
s = attn_mask .shape [- 1 ]
@@ -765,7 +789,6 @@ def _get_eagle_module_inputs(
765
789
"""Getting EAGLE module inputs."""
766
790
b = hidden_states .shape [1 ]
767
791
h = hidden_states .shape [2 ]
768
- s = input_ids .shape [1 ]
769
792
770
793
# [b, 1]
771
794
id_padding = torch .zeros ((b , 1 ), dtype = input_ids .dtype , device = input_ids .device )
@@ -801,8 +824,7 @@ def _get_eagle_module_inputs(
801
824
else :
802
825
gathered_hidden_states = hidden_states
803
826
gathered_features = features
804
- if gathered_features is not None :
805
- feature = gathered_features [- s :]
827
+
806
828
eagle_inputs ["hidden_states" ] = (
807
829
gathered_hidden_states
808
830
if ttt_step == 0
@@ -813,7 +835,7 @@ def _get_eagle_module_inputs(
813
835
dtype = hidden_states .dtype ,
814
836
device = hidden_states .device ,
815
837
),
816
- feature [:- 1 , :, :],
838
+ gathered_features [:- 1 , :, :],
817
839
)
818
840
)
819
841
)
@@ -824,12 +846,11 @@ def _get_eagle_module_inputs(
824
846
)
825
847
826
848
eagle_inputs ["attention_mask" ] = set_multi_step_attention_mask (
827
- attn_mask , ttt_step * self . eagle_config . parallel_draft_step + parallel_draft_index
849
+ attn_mask , ttt_step + parallel_draft_index
828
850
)
829
851
830
852
eagle_inputs ["rotary_pos_emb" ] = torch .cat (
831
- [rotary_pos_emb ]
832
- * (ttt_step * self .eagle_config .parallel_draft_step + parallel_draft_index + 1 ),
853
+ [rotary_pos_emb ] * (ttt_step + parallel_draft_index + 1 ),
833
854
dim = 0 ,
834
855
)
835
856
@@ -1015,7 +1036,7 @@ def forward(
1015
1036
# EAGLE kv cache
1016
1037
eagle_inference_context = StaticInferenceContext (
1017
1038
input_ids .shape [0 ],
1018
- input_ids .shape [1 ] * self .eagle_config .parallel_draft_step * ttt_steps ,
1039
+ input_ids .shape [1 ] * ( self .eagle_config .parallel_draft_step + ttt_steps - 1 ) ,
1019
1040
)
1020
1041
1021
1042
if self .eagle_offline :
@@ -1087,9 +1108,19 @@ def forward(
1087
1108
** (extra_block_kwargs or {}),
1088
1109
)
1089
1110
1111
+ if i == 0 :
1112
+ next_eagle_hidden_states_pre_norm = eagle_hidden_states_pre_norm_
1113
+
1090
1114
eagle_logits .append (eagle_logits_ )
1091
1115
eagle_logits = torch .cat (eagle_logits , dim = 0 )
1092
- eagle_hidden_states_pre_norm = eagle_hidden_states_pre_norm_
1116
+ eagle_hidden_states_pre_norm = next_eagle_hidden_states_pre_norm
1117
+
1118
+ # Discard kv cache for the last parallel_draft_step - 1 tokens
1119
+ # as the next ttt_step will only base on the first token in the
1120
+ # current ttt_step
1121
+ eagle_inference_context .sequence_len_offset -= input_ids .shape [1 ] * (
1122
+ self .eagle_config .parallel_draft_step - 1
1123
+ )
1093
1124
1094
1125
# If labels are not provided, return the original logits. We only return after
1095
1126
# all eagle weights have been exercised for quantization calibration purpose.
0 commit comments