@@ -195,48 +195,72 @@ def set_multi_step_attention_mask(attn_mask, step):
195195 h0 h1 h2 h3 h4 h5 h6 h7 (base hidden_states)
196196 l0 l1 l2 l3 l4 l5 l6 l7 (base labels)
197197
198- ttt_step=2
199- parallel_draft_step=2
200- ->step=3
201-
202- | i1 i2 i3 i4 i5 i6 i7 -- | m0 m0 m0 m0 m0 m0 m0 -- | i1 i2 i3 i4 i5 i6 i7 -- | m0 m0 m0 m0 m0 m0 m0 -- |
203- (out) | h0 h1 h2 h3 h4 h5 h6 h7 | h0 h1 h2 h3 h4 h5 h6 h7 | -- -- G2 G3 G4 G5 G6 G7 | -- -- G2 G3 G4 G5 G6 G7 |
204- =======================================================================================================================
205- F1 l1 | i1 h0 | x | | | |
206- F2 l2 | i2 h1 | x x | | | |
207- F3 l3 | i3 h2 | x x x | | | |
208- F4 l4 | i4 h3 | x x x x | | | |
209- F5 l5 | i5 h4 | x x x x x | | | |
210- F6 l6 | i6 h5 | x x x x x x | | | |
211- F7 l7 | i7 h6 | x x x x x x x | | | |
212- -- -- | -- h7 | o o o o o o o o | | | |
198+ ttt_steps=2
199+ parallel_draft_step=3
200+
201+
202+ ttt_step=0
203+ | i1 i2 i3 i4 i5 i6 i7 -- | m0 m0 m0 m0 m0 m0 m0 -- | m1 m1 m1 m1 m1 m1 m1 -- |
204+ | h0 h1 h2 h3 h4 h5 h6 h7 | h0 h1 h2 h3 h4 h5 h6 h7 | h0 h1 h2 h3 h4 h5 h6 h7 |
205+ =============================================================================================
206+ F1 l1 | i1 h0 | x | | |
207+ F2 l2 | i2 h1 | x x | | |
208+ F3 l3 | i3 h2 | x x x | | |
209+ F4 l4 | i4 h3 | x x x x | | |
210+ F5 l5 | i5 h4 | x x x x x | | |
211+ F6 l6 | i6 h5 | x x x x x x | | |
212+ F7 l7 | i7 h6 | x x x x x x x | | |
213+ -- -- | -- h7 | o o o o o o o o | | |
214+ =============================================================================================
215+ -- -- | m0 -- | | | |
216+ G2 l2 | m0 h1 | x o | x | |
217+ G3 l3 | m0 h2 | x x o | x | |
218+ G4 l4 | m0 h3 | x x x o | x | |
219+ G5 l5 | m0 h4 | x x x x o | x | |
220+ G6 l6 | m0 h5 | x x x x x o | x | |
221+ G7 l7 | m0 h6 | x x x x x x o | x | |
222+ -- -- | -- h7 | | | |
223+ =============================================================================================
224+ -- -- | m1 -- | | | |
225+ -- -- | m1 h1 | | | |
226+ H3 l3 | m1 h2 | x o o | x o | x |
227+ H4 l4 | m1 h3 | x x o o | x o | x |
228+ H5 l5 | m1 h4 | x x x o o | x o | x |
229+ H6 l6 | m1 h5 | x x x x o o | x o | x |
230+ H7 l7 | m1 h6 | x x x x x o o | x o | x |
231+ -- -- | -- h7 | | | |
232+
233+
234+ ttt_step=1
235+ | i1 i2 i3 i4 i5 i6 i7 -- | i1 i2 i3 i4 i5 i6 i7 -- | m0 m0 m0 m0 m0 m0 m0 -- | m1 m1 m1 m1 m1 m1 m1 -- |
236+ | h0 h1 h2 h3 h4 h5 h6 h7 | -- F1 F2 F3 F4 F5 F6 F7 | -- F1 F2 F3 F4 F5 F6 F7 | -- F1 F2 F3 F4 F5 F6 F7 |
213237 =======================================================================================================================
214238 -- -- | i1 -- | | | | |
215- G2 l2 | i2 h1 | x o | x | | |
216- G3 l3 | i3 h2 | x x o | x | | |
217- G4 l4 | i4 h3 | x x x o | x | | |
218- G5 l5 | i5 h4 | x x x x o | x | | |
219- G6 l6 | i6 h5 | x x x x x o | x | | |
220- G7 l7 | i7 h6 | x x x x x x o | x | | |
221- -- -- | -- h7 | | | | |
239+ J2 l2 | i2 F1 | x o | x | | |
240+ J3 l3 | i3 F2 | x x o | x | | |
241+ J4 l4 | i4 F3 | x x x o | x | | |
242+ J5 l5 | i5 F4 | x x x x o | x | | |
243+ J6 l6 | i6 F5 | x x x x x o | x | | |
244+ J7 l7 | i7 F6 | x x x x x x o | x | | |
245+ -- -- | -- F7 | | | | |
222246 =======================================================================================================================
223- -- -- | i1 -- | | | | |
224- -- -- | i2 -- | | | | |
225- H3 l3 | i3 G2 | x o o | x o | x | |
226- H4 l4 | i4 G3 | x x o o | x o | x | |
227- H5 l5 | i5 G4 | x x x o o | x o | x | |
228- H6 l6 | i6 G5 | x x x x o o | x o | x | |
229- H7 l7 | i7 G6 | x x x x x o o | x o | x | |
230- -- -- | -- G7 | | | | |
231- =======================================================================================================================
232- -- -- | m0 -- | | | | |
233247 -- -- | m0 -- | | | | |
234248 -- -- | m0 -- | | | | |
235- K4 l4 | m0 G3 | x | x | x | x |
236- K5 l5 | m0 G4 | x x | x | x | x |
237- K6 l6 | m0 G5 | x x x | x | x | x |
238- K7 l7 | m0 G6 | x x x x | x | x | x |
239- -- -- | -- G7 | | | | |
249+ K3 l3 | m0 F2 | x o o | x o | x | | |
250+ K4 l4 | m0 F3 | x x o o | x o | x | |
251+ K5 l5 | m0 F4 | x x x o o | x o | x | |
252+ K6 l6 | m0 F5 | x x x x o o | x o | x | |
253+ K7 l7 | m0 F6 | x x x x x o o | x o | x | |
254+ -- -- | -- F7 | | | | |
255+ =======================================================================================================================
256+ -- -- | m1 -- | | | | |
257+ -- -- | m1 -- | | | | |
258+ -- -- | m1 -- | | | | |
259+ N4 l4 | m1 F3 | x | x | x | x |
260+ N5 l5 | m1 F4 | x x | x | x | x |
261+ N6 l6 | m1 F5 | x x x | x | x | x |
262+ N7 l7 | m1 F6 | x x x x | x | x | x |
263+ -- -- | -- F7 | | | | |
240264 =======================================================================================================================
241265 """ # noqa: E501
242266 s = attn_mask .shape [- 1 ]
@@ -765,7 +789,6 @@ def _get_eagle_module_inputs(
765789 """Getting EAGLE module inputs."""
766790 b = hidden_states .shape [1 ]
767791 h = hidden_states .shape [2 ]
768- s = input_ids .shape [1 ]
769792
770793 # [b, 1]
771794 id_padding = torch .zeros ((b , 1 ), dtype = input_ids .dtype , device = input_ids .device )
@@ -801,8 +824,7 @@ def _get_eagle_module_inputs(
801824 else :
802825 gathered_hidden_states = hidden_states
803826 gathered_features = features
804- if gathered_features is not None :
805- feature = gathered_features [- s :]
827+
806828 eagle_inputs ["hidden_states" ] = (
807829 gathered_hidden_states
808830 if ttt_step == 0
@@ -813,7 +835,7 @@ def _get_eagle_module_inputs(
813835 dtype = hidden_states .dtype ,
814836 device = hidden_states .device ,
815837 ),
816- feature [:- 1 , :, :],
838+ gathered_features [:- 1 , :, :],
817839 )
818840 )
819841 )
@@ -824,12 +846,11 @@ def _get_eagle_module_inputs(
824846 )
825847
826848 eagle_inputs ["attention_mask" ] = set_multi_step_attention_mask (
827- attn_mask , ttt_step * self . eagle_config . parallel_draft_step + parallel_draft_index
849+ attn_mask , ttt_step + parallel_draft_index
828850 )
829851
830852 eagle_inputs ["rotary_pos_emb" ] = torch .cat (
831- [rotary_pos_emb ]
832- * (ttt_step * self .eagle_config .parallel_draft_step + parallel_draft_index + 1 ),
853+ [rotary_pos_emb ] * (ttt_step + parallel_draft_index + 1 ),
833854 dim = 0 ,
834855 )
835856
@@ -1015,7 +1036,7 @@ def forward(
10151036 # EAGLE kv cache
10161037 eagle_inference_context = StaticInferenceContext (
10171038 input_ids .shape [0 ],
1018- input_ids .shape [1 ] * self .eagle_config .parallel_draft_step * ttt_steps ,
1039+ input_ids .shape [1 ] * ( self .eagle_config .parallel_draft_step + ttt_steps - 1 ) ,
10191040 )
10201041
10211042 if self .eagle_offline :
@@ -1087,9 +1108,19 @@ def forward(
10871108 ** (extra_block_kwargs or {}),
10881109 )
10891110
1111+ if i == 0 :
1112+ next_eagle_hidden_states_pre_norm = eagle_hidden_states_pre_norm_
1113+
10901114 eagle_logits .append (eagle_logits_ )
10911115 eagle_logits = torch .cat (eagle_logits , dim = 0 )
1092- eagle_hidden_states_pre_norm = eagle_hidden_states_pre_norm_
1116+ eagle_hidden_states_pre_norm = next_eagle_hidden_states_pre_norm
1117+
1118+ # Discard kv cache for the last parallel_draft_step - 1 tokens
1119+ # as the next ttt_step will only base on the first token in the
1120+ # current ttt_step
1121+ eagle_inference_context .sequence_len_offset -= input_ids .shape [1 ] * (
1122+ self .eagle_config .parallel_draft_step - 1
1123+ )
10931124
10941125 # If labels are not provided, return the original logits. We only return after
10951126 # all eagle weights have been exercised for quantization calibration purpose.
0 commit comments