
Commit 8212265

use embedding for mask tokens as hidden_states
Signed-off-by: Ye Yu <[email protected]>
1 parent 13f218c commit 8212265

File tree: 1 file changed (+78 -74 lines)


modelopt/torch/speculative/plugins/megatron_eagle.py

Lines changed: 78 additions & 74 deletions
@@ -201,7 +201,7 @@ def set_multi_step_attention_mask(attn_mask, step):

     ttt_step=0
     | i1 i2 i3 i4 i5 i6 i7 -- | m0 m0 m0 m0 m0 m0 m0 -- | m1 m1 m1 m1 m1 m1 m1 -- |
-    | h0 h1 h2 h3 h4 h5 h6 h7 | h0 h1 h2 h3 h4 h5 h6 h7 | h0 h1 h2 h3 h4 h5 h6 h7 |
+    | h0 h1 h2 h3 h4 h5 h6 h7 | M0 M0 M0 M0 M0 M0 M0 M0 | M1 M1 M1 M1 M1 M1 M1 M1 |
     =============================================================================================
     F1 l1 | i1 h0 | x | | |
     F2 l2 | i2 h1 | x x | | |
@@ -212,28 +212,28 @@ def set_multi_step_attention_mask(attn_mask, step):
     F7 l7 | i7 h6 | x x x x x x x | | |
     -- -- | -- h7 | o o o o o o o o | | |
     =============================================================================================
-    -- -- | m0 -- | | | |
-    G2 l2 | m0 h1 | x o | x | |
-    G3 l3 | m0 h2 | x x o | x | |
-    G4 l4 | m0 h3 | x x x o | x | |
-    G5 l5 | m0 h4 | x x x x o | x | |
-    G6 l6 | m0 h5 | x x x x x o | x | |
-    G7 l7 | m0 h6 | x x x x x x o | x | |
-    -- -- | -- h7 | | | |
+    -- -- | m0 M0 | | | |
+    G2 l2 | m0 M0 | x o | x | |
+    G3 l3 | m0 M0 | x x o | x | |
+    G4 l4 | m0 M0 | x x x o | x | |
+    G5 l5 | m0 M0 | x x x x o | x | |
+    G6 l6 | m0 M0 | x x x x x o | x | |
+    G7 l7 | m0 M0 | x x x x x x o | x | |
+    -- -- | -- M0 | | | |
     =============================================================================================
-    -- -- | m1 -- | | | |
-    -- -- | m1 h1 | | | |
-    H3 l3 | m1 h2 | x o o | x o | x |
-    H4 l4 | m1 h3 | x x o o | x o | x |
-    H5 l5 | m1 h4 | x x x o o | x o | x |
-    H6 l6 | m1 h5 | x x x x o o | x o | x |
-    H7 l7 | m1 h6 | x x x x x o o | x o | x |
-    -- -- | -- h7 | | | |
+    -- -- | m1 M0 | | | |
+    -- -- | m1 M1 | | | |
+    H3 l3 | m1 M1 | x o o | x o | x |
+    H4 l4 | m1 M1 | x x o o | x o | x |
+    H5 l5 | m1 M1 | x x x o o | x o | x |
+    H6 l6 | m1 M1 | x x x x o o | x o | x |
+    H7 l7 | m1 M1 | x x x x x o o | x o | x |
+    -- -- | -- M1 | | | |


     ttt_step=1
     | i1 i2 i3 i4 i5 i6 i7 -- | i1 i2 i3 i4 i5 i6 i7 -- | m0 m0 m0 m0 m0 m0 m0 -- | m1 m1 m1 m1 m1 m1 m1 -- |
-    | h0 h1 h2 h3 h4 h5 h6 h7 | -- F1 F2 F3 F4 F5 F6 F7 | -- F1 F2 F3 F4 F5 F6 F7 | -- F1 F2 F3 F4 F5 F6 F7 |
+    | h0 h1 h2 h3 h4 h5 h6 h7 | -- F1 F2 F3 F4 F5 F6 F7 | M0 M0 M0 M0 M0 M0 M0 M0 | M1 M1 M1 M1 M1 M1 M1 M1 |
     =======================================================================================================================
     -- -- | i1 -- | | | | |
     J2 l2 | i2 F1 | x o | x | | |
@@ -244,23 +244,23 @@ def set_multi_step_attention_mask(attn_mask, step):
     J7 l7 | i7 F6 | x x x x x x o | x | | |
     -- -- | -- F7 | | | | |
     =======================================================================================================================
-    -- -- | m0 -- | | | | |
-    -- -- | m0 -- | | | | |
-    K3 l3 | m0 F2 | x o o | x o | x | | |
-    K4 l4 | m0 F3 | x x o o | x o | x | |
-    K5 l5 | m0 F4 | x x x o o | x o | x | |
-    K6 l6 | m0 F5 | x x x x o o | x o | x | |
-    K7 l7 | m0 F6 | x x x x x o o | x o | x | |
-    -- -- | -- F7 | | | | |
+    -- -- | m0 M0 | | | | |
+    -- -- | m0 M0 | | | | |
+    K3 l3 | m0 M0 | x o o | x o | x | | |
+    K4 l4 | m0 M0 | x x o o | x o | x | |
+    K5 l5 | m0 M0 | x x x o o | x o | x | |
+    K6 l6 | m0 M0 | x x x x o o | x o | x | |
+    K7 l7 | m0 M0 | x x x x x o o | x o | x | |
+    -- -- | -- M0 | | | | |
     =======================================================================================================================
-    -- -- | m1 -- | | | | |
-    -- -- | m1 -- | | | | |
-    -- -- | m1 -- | | | | |
-    N4 l4 | m1 F3 | x | x | x | x |
-    N5 l5 | m1 F4 | x x | x | x | x |
-    N6 l6 | m1 F5 | x x x | x | x | x |
-    N7 l7 | m1 F6 | x x x x | x | x | x |
-    -- -- | -- F7 | | | | |
+    -- -- | m1 M1 | | | | |
+    -- -- | m1 M1 | | | | |
+    -- -- | m1 M1 | | | | |
+    N4 l4 | m1 M1 | x | x | x | x |
+    N5 l5 | m1 M1 | x x | x | x | x |
+    N6 l6 | m1 M1 | x x x | x | x | x |
+    N7 l7 | m1 M1 | x x x x | x | x | x |
+    -- -- | -- M1 | | | | |
     =======================================================================================================================
     """  # noqa: E501
     s = attn_mask.shape[-1]
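
For orientation while reading the diagram: set_multi_step_attention_mask starts from an ordinary causal mask over the s base positions and expands it into the block layout sketched above, with step = ttt_step + parallel_draft_index (as called later in this diff). Below is a minimal sketch of the kind of base mask it consumes; the toy size and the True-means-masked convention are assumptions for illustration, not taken from this file.

import torch

s = 8  # base positions in the diagram (h0..h7)
# Assumed convention: entry [i, j] is True when position i must NOT attend to position j.
base_attn_mask = torch.triu(torch.ones(s, s, dtype=torch.bool), diagonal=1)
# set_multi_step_attention_mask(base_attn_mask, step) then tiles this pattern into the
# multi-block mask drawn in the docstring above.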
@@ -782,16 +782,14 @@ def _get_eagle_module_inputs(
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
         position_ids: torch.Tensor,
-        features: torch.Tensor | None = None,
         ttt_step: int = 0,
         parallel_draft_index: int = 0,
     ):
         """Getting EAGLE module inputs."""
-        b = hidden_states.shape[1]
-        h = hidden_states.shape[2]
-
         # [b, 1]
-        id_padding = torch.zeros((b, 1), dtype=input_ids.dtype, device=input_ids.device)
+        id_padding = torch.zeros(
+            (input_ids.shape[0], 1), dtype=input_ids.dtype, device=input_ids.device
+        )
         padded_input_ids = torch.cat((input_ids[:, 1:], id_padding), dim=-1)

         rotary_pos_emb = self.eagle_module.rotary_pos_emb(padded_input_ids.shape[-1])
@@ -816,35 +814,15 @@ def _get_eagle_module_inputs(
             )
         )

-        if self.config.sequence_parallel:
-            gathered_hidden_states = gather_from_sequence_parallel_region(hidden_states)
-            gathered_features = (
-                None if features is None else gather_from_sequence_parallel_region(features)
-            )
-        else:
-            gathered_hidden_states = hidden_states
-            gathered_features = features
+        eagle_inputs["embedding"] = self.embedding(
+            input_ids=eagle_inputs["input_ids"],
+            position_ids=eagle_inputs["position_ids"],
+        )

         eagle_inputs["hidden_states"] = (
-            gathered_hidden_states
-            if ttt_step == 0
-            else torch.cat(
-                (
-                    torch.zeros(
-                        (1, b, h),
-                        dtype=hidden_states.dtype,
-                        device=hidden_states.device,
-                    ),
-                    gathered_features[:-1, :, :],  # type: ignore[index]
-                )
-            )
+            hidden_states if parallel_draft_index == 0 else eagle_inputs["embedding"]
         )

-        if self.config.sequence_parallel:
-            eagle_inputs["hidden_states"] = scatter_to_sequence_parallel_region(
-                eagle_inputs["hidden_states"]
-            )
-
         eagle_inputs["attention_mask"] = set_multi_step_attention_mask(
             attn_mask, ttt_step + parallel_draft_index
         )
@@ -854,11 +832,6 @@ def _get_eagle_module_inputs(
             dim=0,
         )

-        eagle_inputs["embedding"] = self.embedding(
-            input_ids=eagle_inputs["input_ids"],
-            position_ids=eagle_inputs["position_ids"],
-        )
-
         return eagle_inputs

     def _compute_eagle_loss(self, logits, labels, eagle_logits):
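
The net effect of the hunks above is that _get_eagle_module_inputs now computes the embedding first and picks the EAGLE hidden-state input per parallel-draft index, instead of rebuilding it from gathered features. A minimal sketch of that selection follows, assuming Megatron's sequence-first [s, b, h] layout; the tensor names are illustrative and only the parallel_draft_index rule comes from the diff.

import torch

s, b, h = 8, 2, 16  # assumed toy sizes
base_hidden_states = torch.randn(s, b, h)  # features from the base model
token_embedding = torch.randn(s, b, h)     # stand-in for self.embedding(input_ids, position_ids)

def pick_eagle_hidden_states(parallel_draft_index: int) -> torch.Tensor:
    # Draft index 0 still conditions on the base model's hidden states;
    # each additional parallel draft (mask token) conditions on its embedding instead.
    return base_hidden_states if parallel_draft_index == 0 else token_embedding

assert pick_eagle_hidden_states(0) is base_hidden_states
assert pick_eagle_hidden_states(1) is token_embedding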
@@ -1086,7 +1059,6 @@ def forward(
            loss = 0.0 * loss

        acc = []
-        eagle_hidden_states_pre_norm = None
        for ttt_step in range(ttt_steps):
            eagle_logits = []
            for i in range(self.eagle_config.parallel_draft_step):
@@ -1095,7 +1067,6 @@ def forward(
                     hidden_states=eagle_module_input_hidden_states,
                     attention_mask=attention_mask,
                     position_ids=position_ids,
-                    features=eagle_hidden_states_pre_norm,
                     ttt_step=ttt_step,
                     parallel_draft_index=i,
                 )
@@ -1114,7 +1085,29 @@ def forward(

                 eagle_logits.append(eagle_logits_)
             eagle_logits = torch.cat(eagle_logits, dim=0)
-            eagle_hidden_states_pre_norm = next_eagle_hidden_states_pre_norm
+            eagle_module_input_hidden_states = next_eagle_hidden_states_pre_norm
+            if self.config.sequence_parallel:
+                eagle_module_input_hidden_states = gather_from_sequence_parallel_region(
+                    eagle_module_input_hidden_states
+                )
+            eagle_module_input_hidden_states = torch.cat(
+                (
+                    torch.zeros(
+                        (
+                            1,
+                            eagle_module_input_hidden_states.shape[1],
+                            eagle_module_input_hidden_states.shape[2],
+                        ),
+                        dtype=eagle_module_input_hidden_states.dtype,
+                        device=eagle_module_input_hidden_states.device,
+                    ),
+                    eagle_module_input_hidden_states[:-1, :, :],
+                )
+            )
+            if self.config.sequence_parallel:
+                eagle_module_input_hidden_states = scatter_to_sequence_parallel_region(
+                    eagle_module_input_hidden_states
+                )

             # Discard kv cache for the last parallel_draft_step - 1 tokens
             # as the next ttt_step will only base on the first token in the
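
Between ttt steps, the new code above shifts the draft features right by one position, so position t is conditioned on the feature produced for t - 1 and the first slot is zero-filled. A self-contained sketch of just that shift, with assumed toy sizes and the same sequence-first layout:

import torch

s, b, h = 4, 1, 3  # assumed toy sizes, layout [s, b, h]
features = torch.arange(s * b * h, dtype=torch.float32).reshape(s, b, h)

# Prepend a zero row and drop the last row, mirroring the torch.cat(...) block in the diff.
shifted = torch.cat(
    (torch.zeros((1, b, h), dtype=features.dtype, device=features.device), features[:-1, :, :])
)

assert torch.equal(shifted[1:], features[:-1])  # shifted[t] == features[t - 1] for t >= 1
assert shifted[0].abs().sum() == 0              # first position is zero-filled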
@@ -1393,12 +1386,12 @@ def pseudo_speculative_generate(
            eagle_ids = torch.cat(
                (eagle_ids, getattr(self, f"mask_token_{i}").view((1, 1))), dim=-1
            )
+            # Pad dummy hidden_states for mask tokens
+            # They will be replaced by embeddings after padding
            hidden_states = torch.cat((hidden_states, hidden_states[-1:]), dim=0)
        padded_eagle_ids, seq_len, padded_hidden_states = right_padding(
            eagle_ids, hidden_states
        )
-        if self.config.sequence_parallel:
-            padded_hidden_states = scatter_to_sequence_parallel_region(padded_hidden_states)
        eagle_attention_mask, eagle_position_ids = get_default_attention_mask_and_position_ids(
            padded_eagle_ids
        )
@@ -1409,6 +1402,17 @@ def pseudo_speculative_generate(
            input_ids=padded_eagle_ids,
            position_ids=eagle_position_ids,
        )
+        if self.config.sequence_parallel:
+            gathered_embedding = gather_from_sequence_parallel_region(eagle_inputs["embedding"])
+        if self.eagle_config.parallel_draft_step > 1:
+            # Replace dummy hidden_states with embedding for mask tokens
+            padded_hidden_states[
+                seq_len - self.eagle_config.parallel_draft_step + 1 : seq_len
+            ] = gathered_embedding[
+                seq_len - self.eagle_config.parallel_draft_step + 1 : seq_len
+            ]
+        if self.config.sequence_parallel:
+            padded_hidden_states = scatter_to_sequence_parallel_region(padded_hidden_states)
        eagle_inputs["hidden_states"] = padded_hidden_states
        eagle_inputs["attention_mask"] = eagle_attention_mask

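The last hunk replaces the dummy rows appended for mask tokens with their embeddings after right padding, rather than leaving them as copied hidden states. A toy sketch of that slice replacement; seq_len, parallel_draft_step, and the tensor contents are made-up stand-ins.

import torch

parallel_draft_step = 3     # assumed: the last 2 draft positions are mask tokens
seq_len = 6                 # assumed unpadded eagle sequence length
padded_len, b, h = 8, 1, 4  # assumed padded shape [padded_len, b, h]

padded_hidden_states = torch.zeros(padded_len, b, h)
gathered_embedding = torch.ones(padded_len, b, h)

# Same slice as in the diff: the parallel_draft_step - 1 positions right before seq_len
# are the mask tokens; their dummy hidden states are overwritten with embeddings.
lo, hi = seq_len - parallel_draft_step + 1, seq_len
padded_hidden_states[lo:hi] = gathered_embedding[lo:hi]

assert torch.equal(padded_hidden_states[lo:hi], gathered_embedding[lo:hi])
assert padded_hidden_states[:lo].abs().sum() == 0
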
0 commit comments