@@ -201,7 +201,7 @@ def set_multi_step_attention_mask(attn_mask, step):
 
 ttt_step=0
 | i1 i2 i3 i4 i5 i6 i7 -- | m0 m0 m0 m0 m0 m0 m0 -- | m1 m1 m1 m1 m1 m1 m1 -- |
-| h0 h1 h2 h3 h4 h5 h6 h7 | h0 h1 h2 h3 h4 h5 h6 h7 | h0 h1 h2 h3 h4 h5 h6 h7 |
+| h0 h1 h2 h3 h4 h5 h6 h7 | M0 M0 M0 M0 M0 M0 M0 M0 | M1 M1 M1 M1 M1 M1 M1 M1 |
 =============================================================================================
 F1 l1 | i1 h0 | x | | |
 F2 l2 | i2 h1 | x x | | |
@@ -212,28 +212,28 @@ def set_multi_step_attention_mask(attn_mask, step):
 F7 l7 | i7 h6 | x x x x x x x | | |
 -- -- | -- h7 | o o o o o o o o | | |
 =============================================================================================
--- -- | m0 -- | | | |
-G2 l2 | m0 h1 | x o | x | |
-G3 l3 | m0 h2 | x x o | x | |
-G4 l4 | m0 h3 | x x x o | x | |
-G5 l5 | m0 h4 | x x x x o | x | |
-G6 l6 | m0 h5 | x x x x x o | x | |
-G7 l7 | m0 h6 | x x x x x x o | x | |
--- -- | -- h7 | | | |
+-- -- | m0 M0 | | | |
+G2 l2 | m0 M0 | x o | x | |
+G3 l3 | m0 M0 | x x o | x | |
+G4 l4 | m0 M0 | x x x o | x | |
+G5 l5 | m0 M0 | x x x x o | x | |
+G6 l6 | m0 M0 | x x x x x o | x | |
+G7 l7 | m0 M0 | x x x x x x o | x | |
+-- -- | -- M0 | | | |
 =============================================================================================
--- -- | m1 -- | | | |
--- -- | m1 h1 | | | |
-H3 l3 | m1 h2 | x o o | x o | x |
-H4 l4 | m1 h3 | x x o o | x o | x |
-H5 l5 | m1 h4 | x x x o o | x o | x |
-H6 l6 | m1 h5 | x x x x o o | x o | x |
-H7 l7 | m1 h6 | x x x x x o o | x o | x |
--- -- | -- h7 | | | |
+-- -- | m1 M0 | | | |
+-- -- | m1 M1 | | | |
+H3 l3 | m1 M1 | x o o | x o | x |
+H4 l4 | m1 M1 | x x o o | x o | x |
+H5 l5 | m1 M1 | x x x o o | x o | x |
+H6 l6 | m1 M1 | x x x x o o | x o | x |
+H7 l7 | m1 M1 | x x x x x o o | x o | x |
+-- -- | -- M1 | | | |
 
 
 ttt_step=1
 | i1 i2 i3 i4 i5 i6 i7 -- | i1 i2 i3 i4 i5 i6 i7 -- | m0 m0 m0 m0 m0 m0 m0 -- | m1 m1 m1 m1 m1 m1 m1 -- |
-| h0 h1 h2 h3 h4 h5 h6 h7 | -- F1 F2 F3 F4 F5 F6 F7 | -- F1 F2 F3 F4 F5 F6 F7 | -- F1 F2 F3 F4 F5 F6 F7 |
+| h0 h1 h2 h3 h4 h5 h6 h7 | -- F1 F2 F3 F4 F5 F6 F7 | M0 M0 M0 M0 M0 M0 M0 M0 | M1 M1 M1 M1 M1 M1 M1 M1 |
 =======================================================================================================================
 -- -- | i1 -- | | | | |
 J2 l2 | i2 F1 | x o | x | | |
@@ -244,23 +244,23 @@ def set_multi_step_attention_mask(attn_mask, step):
 J7 l7 | i7 F6 | x x x x x x o | x | | |
 -- -- | -- F7 | | | | |
 =======================================================================================================================
--- -- | m0 -- | | | | |
--- -- | m0 -- | | | | |
-K3 l3 | m0 F2 | x o o | x o | x | | |
-K4 l4 | m0 F3 | x x o o | x o | x | |
-K5 l5 | m0 F4 | x x x o o | x o | x | |
-K6 l6 | m0 F5 | x x x x o o | x o | x | |
-K7 l7 | m0 F6 | x x x x x o o | x o | x | |
--- -- | -- F7 | | | | |
+-- -- | m0 M0 | | | | |
+-- -- | m0 M0 | | | | |
+K3 l3 | m0 M0 | x o o | x o | x | | |
+K4 l4 | m0 M0 | x x o o | x o | x | |
+K5 l5 | m0 M0 | x x x o o | x o | x | |
+K6 l6 | m0 M0 | x x x x o o | x o | x | |
+K7 l7 | m0 M0 | x x x x x o o | x o | x | |
+-- -- | -- M0 | | | | |
 =======================================================================================================================
--- -- | m1 -- | | | | |
--- -- | m1 -- | | | | |
--- -- | m1 -- | | | | |
-N4 l4 | m1 F3 | x | x | x | x |
-N5 l5 | m1 F4 | x x | x | x | x |
-N6 l6 | m1 F5 | x x x | x | x | x |
-N7 l7 | m1 F6 | x x x x | x | x | x |
--- -- | -- F7 | | | | |
+-- -- | m1 M1 | | | | |
+-- -- | m1 M1 | | | | |
+-- -- | m1 M1 | | | | |
+N4 l4 | m1 M1 | x | x | x | x |
+N5 l5 | m1 M1 | x x | x | x | x |
+N6 l6 | m1 M1 | x x x | x | x | x |
+N7 l7 | m1 M1 | x x x x | x | x | x |
+-- -- | -- M1 | | | | |
 =======================================================================================================================
     """  # noqa: E501
     s = attn_mask.shape[-1]
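
Reading the diagrams: `x` marks a position the query may attend to, `o` a masked-out slot, and every additional step appends one `s x s` block row and column. The sketch below re-derives that pattern for sanity-checking; it is a toy illustration of the documented mask shape, not the repository's implementation, and the True-means-blocked boolean convention is assumed from Megatron:

```python
import torch

def toy_multi_step_mask(s: int, step: int) -> torch.Tensor:
    # Boolean mask, True = blocked (Megatron-style convention assumed).
    size = s * (step + 1)
    mask = torch.ones(size, size, dtype=torch.bool)
    # Block 0: the ordinary causal mask over the original sequence.
    mask[:s, :s] = torch.triu(torch.ones(s, s, dtype=torch.bool), diagonal=1)
    for k in range(1, step + 1):
        # The first k rows and the last row of draft block k stay fully
        # blocked, matching the all-blank rows in the diagrams.
        for j in range(k, s - 1):
            row = k * s + j
            # Shifted context: a step-k query at position j sees original
            # tokens 0 .. j-k (one token further back per extra step).
            mask[row, : j - k + 1] = False
            # One diagonal entry per earlier draft block plus its own
            # diagonal (the isolated 'x' columns in the diagrams).
            for d in range(1, k + 1):
                mask[row, d * s + j - (k - d)] = False
    return mask

print(toy_multi_step_mask(8, 1)[8:, :].int())  # compare with the G rows above
```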
@@ -782,16 +782,14 @@ def _get_eagle_module_inputs(
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
         position_ids: torch.Tensor,
-        features: torch.Tensor | None = None,
         ttt_step: int = 0,
         parallel_draft_index: int = 0,
     ):
         """Getting EAGLE module inputs."""
-        b = hidden_states.shape[1]
-        h = hidden_states.shape[2]
-
         # [b, 1]
-        id_padding = torch.zeros((b, 1), dtype=input_ids.dtype, device=input_ids.device)
+        id_padding = torch.zeros(
+            (input_ids.shape[0], 1), dtype=input_ids.dtype, device=input_ids.device
+        )
         padded_input_ids = torch.cat((input_ids[:, 1:], id_padding), dim=-1)
 
         rotary_pos_emb = self.eagle_module.rotary_pos_emb(padded_input_ids.shape[-1])
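
For context, the unchanged lines around `id_padding` implement the usual EAGLE shift: position `t` is conditioned on token `t+1`, so the ids move left by one with a zero pad on the right. A standalone repro with toy values:

```python
import torch

# The draft module conditions position t on token t+1, so the ids are
# shifted left by one and zero-padded on the right ([b, 1] padding).
input_ids = torch.tensor([[11, 12, 13, 14]])
id_padding = torch.zeros((input_ids.shape[0], 1), dtype=input_ids.dtype)
padded_input_ids = torch.cat((input_ids[:, 1:], id_padding), dim=-1)
print(padded_input_ids)  # tensor([[12, 13, 14,  0]])
```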
@@ -816,35 +814,15 @@ def _get_eagle_module_inputs(
             )
         )
 
-        if self.config.sequence_parallel:
-            gathered_hidden_states = gather_from_sequence_parallel_region(hidden_states)
-            gathered_features = (
-                None if features is None else gather_from_sequence_parallel_region(features)
-            )
-        else:
-            gathered_hidden_states = hidden_states
-            gathered_features = features
+        eagle_inputs["embedding"] = self.embedding(
+            input_ids=eagle_inputs["input_ids"],
+            position_ids=eagle_inputs["position_ids"],
+        )
 
         eagle_inputs["hidden_states"] = (
-            gathered_hidden_states
-            if ttt_step == 0
-            else torch.cat(
-                (
-                    torch.zeros(
-                        (1, b, h),
-                        dtype=hidden_states.dtype,
-                        device=hidden_states.device,
-                    ),
-                    gathered_features[:-1, :, :],  # type: ignore[index]
-                )
-            )
+            hidden_states if parallel_draft_index == 0 else eagle_inputs["embedding"]
         )
 
-        if self.config.sequence_parallel:
-            eagle_inputs["hidden_states"] = scatter_to_sequence_parallel_region(
-                eagle_inputs["hidden_states"]
-            )
-
         eagle_inputs["attention_mask"] = set_multi_step_attention_mask(
             attn_mask, ttt_step + parallel_draft_index
         )
@@ -854,11 +832,6 @@ def _get_eagle_module_inputs(
             dim=0,
         )
 
-        eagle_inputs["embedding"] = self.embedding(
-            input_ids=eagle_inputs["input_ids"],
-            position_ids=eagle_inputs["position_ids"],
-        )
-
         return eagle_inputs
 
     def _compute_eagle_loss(self, logits, labels, eagle_logits):
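
Taken together, the hunks above move the embedding computation ahead of the hidden-state selection and drop the `features` shift from this method entirely: the draft input is now either the incoming hidden states or, for parallel draft (mask) tokens, their own embedding. A distilled sketch of that rule, assuming the `[s, b, h]` sequence-first layout (`select_eagle_hidden_states` is a hypothetical name, not a repo function):

```python
import torch

def select_eagle_hidden_states(
    hidden_states: torch.Tensor,  # [s, b, h] from the base model / previous ttt step
    embedding: torch.Tensor,      # [s, b, h] embedding of the padded draft input ids
    parallel_draft_index: int,
) -> torch.Tensor:
    # Index 0 is the ordinary draft query and conditions on hidden states;
    # parallel draft tokens are learned mask tokens, so the only
    # per-position signal they carry is their embedding.
    return hidden_states if parallel_draft_index == 0 else embedding
```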
@@ -1086,7 +1059,6 @@ def forward(
             loss = 0.0 * loss
 
         acc = []
-        eagle_hidden_states_pre_norm = None
         for ttt_step in range(ttt_steps):
             eagle_logits = []
             for i in range(self.eagle_config.parallel_draft_step):
@@ -1095,7 +1067,6 @@ def forward(
                     hidden_states=eagle_module_input_hidden_states,
                     attention_mask=attention_mask,
                     position_ids=position_ids,
-                    features=eagle_hidden_states_pre_norm,
                     ttt_step=ttt_step,
                     parallel_draft_index=i,
                 )
@@ -1114,7 +1085,29 @@ def forward(
 
                 eagle_logits.append(eagle_logits_)
             eagle_logits = torch.cat(eagle_logits, dim=0)
-            eagle_hidden_states_pre_norm = next_eagle_hidden_states_pre_norm
+            eagle_module_input_hidden_states = next_eagle_hidden_states_pre_norm
+            if self.config.sequence_parallel:
+                eagle_module_input_hidden_states = gather_from_sequence_parallel_region(
+                    eagle_module_input_hidden_states
+                )
+            eagle_module_input_hidden_states = torch.cat(
+                (
+                    torch.zeros(
+                        (
+                            1,
+                            eagle_module_input_hidden_states.shape[1],
+                            eagle_module_input_hidden_states.shape[2],
+                        ),
+                        dtype=eagle_module_input_hidden_states.dtype,
+                        device=eagle_module_input_hidden_states.device,
+                    ),
+                    eagle_module_input_hidden_states[:-1, :, :],
+                )
+            )
+            if self.config.sequence_parallel:
+                eagle_module_input_hidden_states = scatter_to_sequence_parallel_region(
+                    eagle_module_input_hidden_states
+                )
 
             # Discard kv cache for the last parallel_draft_step - 1 tokens
             # as the next ttt_step will only base on the first token in the
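
The shift that used to live in `_get_eagle_module_inputs` (zero row in front, last row dropped) now happens here once per ttt step, wrapped in gather/scatter when sequence parallelism is on. The shift in isolation, on toy shapes:

```python
import torch

# Toy [s, b, h] = [4, 1, 2] features from the previous ttt step. The output
# at position i feeds the *next* position, so a zero row is prepended and
# the last row dropped before the tensor is reused as the hidden_states input.
feats = torch.arange(8, dtype=torch.float32).reshape(4, 1, 2)
shifted = torch.cat((torch.zeros_like(feats[:1]), feats[:-1]), dim=0)
print(shifted.squeeze(1))
# tensor([[0., 0.],
#         [0., 1.],
#         [2., 3.],
#         [4., 5.]])
```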
@@ -1393,12 +1386,12 @@ def pseudo_speculative_generate(
                 eagle_ids = torch.cat(
                     (eagle_ids, getattr(self, f"mask_token_{i}").view((1, 1))), dim=-1
                 )
+                # Pad dummy hidden_states for mask tokens.
+                # They will be replaced by embeddings after padding.
                 hidden_states = torch.cat((hidden_states, hidden_states[-1:]), dim=0)
             padded_eagle_ids, seq_len, padded_hidden_states = right_padding(
                 eagle_ids, hidden_states
             )
-            if self.config.sequence_parallel:
-                padded_hidden_states = scatter_to_sequence_parallel_region(padded_hidden_states)
             eagle_attention_mask, eagle_position_ids = get_default_attention_mask_and_position_ids(
                 padded_eagle_ids
             )
@@ -1409,6 +1402,17 @@ def pseudo_speculative_generate(
                 input_ids=padded_eagle_ids,
                 position_ids=eagle_position_ids,
             )
+            if self.config.sequence_parallel:
+                gathered_embedding = gather_from_sequence_parallel_region(eagle_inputs["embedding"])
+            if self.eagle_config.parallel_draft_step > 1:
+                # Replace dummy hidden_states with embedding for mask tokens.
+                padded_hidden_states[
+                    seq_len - self.eagle_config.parallel_draft_step + 1 : seq_len
+                ] = gathered_embedding[
+                    seq_len - self.eagle_config.parallel_draft_step + 1 : seq_len
+                ]
+            if self.config.sequence_parallel:
+                padded_hidden_states = scatter_to_sequence_parallel_region(padded_hidden_states)
             eagle_inputs["hidden_states"] = padded_hidden_states
             eagle_inputs["attention_mask"] = eagle_attention_mask
 
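
A toy check of the replacement slice added above, with made-up sizes (`parallel_draft_step = 3`, `seq_len = 6`, padded length 8): the last `parallel_draft_step - 1` real positions hold mask tokens, and exactly those rows get their dummy hidden states swapped for embeddings:

```python
import torch

seq_len, draft_step = 6, 3
padded_hidden_states = torch.zeros(8, 1, 4)  # dummy rows for every position
gathered_embedding = torch.ones(8, 1, 4)     # stand-in for the gathered embeddings
padded_hidden_states[seq_len - draft_step + 1 : seq_len] = gathered_embedding[
    seq_len - draft_step + 1 : seq_len
]
print(padded_hidden_states[:, 0, 0])  # tensor([0., 0., 0., 0., 1., 1., 0., 0.])
```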