From 3e53bb8ca03524ce3e9ace71e3300ccd1c0db29f Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Fri, 19 Sep 2025 10:52:17 -0700 Subject: [PATCH 01/26] refactor_multi_step_attn_mask_for_arbitrary_step Signed-off-by: Ye Yu --- .../speculative/plugins/megatron_eagle.py | 62 ++++--------------- 1 file changed, 11 insertions(+), 51 deletions(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index 2a0e63a3c..348f8b5c7 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -305,65 +305,25 @@ def set_multi_step_attention_mask(attn_mask, step): ======================================================================================================================= """ # noqa: E501 assert step > 1, "step should be larger than 1 in multi-step attention mask." - assert step <= 4, "Currently only a step of 4 or smaller is supported!" s = attn_mask.shape[-1] - zero_mask = torch.ones_like(attn_mask).bool() - mask_2_1 = attn_mask.clone().detach() - mask_2_1[:, :, :, :-1] = mask_2_1[:, :, :, 1:] - mask_2_2 = torch.ones_like(attn_mask).bool() - for i in range(1, s - 1): - mask_2_2[:, :, i, i] = False - - if step == 2: - attn_mask = torch.cat( - ( - torch.cat((attn_mask, zero_mask), dim=-1), - torch.cat((mask_2_1, mask_2_2), dim=-1), - ), - dim=-2, - ) - return attn_mask - - mask_3_1 = mask_2_1.clone().detach() - mask_3_1[:, :, :, :-1] = mask_3_1[:, :, :, 1:] - mask_3_2 = mask_2_2.clone().detach() - mask_3_2[:, :, :, :-1] = mask_3_2[:, :, :, 1:] - mask_3_2[:, :, 1, 0] = True - mask_3_3 = mask_2_2.clone().detach() - mask_3_3[:, :, 1, 1] = True + for iter in range(2, step + 1): + # iter starts from 2nd step + zero_mask = torch.ones(attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s).bool() + mask_0 = attn_mask.clone().detach()[:, :, -s:, :] + mask_0[:, :, iter - 2] = True + mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:] + mask_1 = torch.ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool() + for i in range(iter - 1, s - 1): + mask_1[:, :, i, i] = False - if step == 3: attn_mask = torch.cat( ( - torch.cat((attn_mask, zero_mask, zero_mask), dim=-1), - torch.cat((mask_2_1, mask_2_2, zero_mask), dim=-1), - torch.cat((mask_3_1, mask_3_2, mask_3_3), dim=-1), + torch.cat((attn_mask, zero_mask), dim=-1), + torch.cat((mask_0, mask_1), dim=-1), ), dim=-2, ) - return attn_mask - - mask_4_1 = mask_3_1.clone().detach() - mask_4_1[:, :, :, :-1] = mask_4_1[:, :, :, 1:] - mask_4_2 = mask_3_2.clone().detach() - mask_4_2[:, :, :, :-1] = mask_4_2[:, :, :, 1:] - mask_4_2[:, :, 2, 0] = True - mask_4_3 = mask_3_3.clone().detach() - mask_4_3[:, :, :, :-1] = mask_4_3[:, :, :, 1:] - mask_4_3[:, :, 2, 1] = True - mask_4_4 = mask_3_3.clone().detach() - mask_4_4[:, :, 2, 2] = True - - attn_mask = torch.cat( - ( - torch.cat((attn_mask, zero_mask, zero_mask, zero_mask), dim=-1), - torch.cat((mask_2_1, mask_2_2, zero_mask, zero_mask), dim=-1), - torch.cat((mask_3_1, mask_3_2, mask_3_3, zero_mask), dim=-1), - torch.cat((mask_4_1, mask_4_2, mask_4_3, mask_4_4), dim=-1), - ), - dim=-2, - ) return attn_mask From 0b531c446d0fd484476d5de175b575446abddc26 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Fri, 19 Sep 2025 11:14:29 -0700 Subject: [PATCH 02/26] minor Signed-off-by: Ye Yu --- modelopt/torch/speculative/plugins/megatron_eagle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index 348f8b5c7..ba7177a67 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -843,7 +843,7 @@ def _get_eagle_module_inputs( rotary_pos_emb = self.eagle_module.rotary_pos_emb(padded_input_ids.shape[-1]) attn_mask = attention_mask.clone().detach() - attn_mask[:, :, :-1, :-1] = attention_mask[:, :, 1:, 1:] + attn_mask[:, :, :-1, :-1] = attn_mask[:, :, 1:, 1:] attn_mask[:, :, -1, :] = True attn_mask[:, :, :, -1] = True From a123371d6ec99a6b117c3d564777ff0f077c1e9d Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Fri, 19 Sep 2025 11:21:18 -0700 Subject: [PATCH 03/26] make new mask the same dtype and device as attn_mask Signed-off-by: Ye Yu --- modelopt/torch/speculative/plugins/megatron_eagle.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index ba7177a67..a0978fd46 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -309,11 +309,13 @@ def set_multi_step_attention_mask(attn_mask, step): s = attn_mask.shape[-1] for iter in range(2, step + 1): # iter starts from 2nd step - zero_mask = torch.ones(attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s).bool() + zero_mask = attn_mask.new_ones( + attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s + ).bool() mask_0 = attn_mask.clone().detach()[:, :, -s:, :] - mask_0[:, :, iter - 2] = True + mask_0[:, :, iter - 2, :] = True mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:] - mask_1 = torch.ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool() + mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool() for i in range(iter - 1, s - 1): mask_1[:, :, i, i] = False From c5a2e361a70dba080caabb975e5739d1d2705671 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Mon, 22 Sep 2025 12:17:54 -0700 Subject: [PATCH 04/26] minor Signed-off-by: Ye Yu --- .../speculative/plugins/megatron_eagle.py | 51 ++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index a0978fd46..466cbb5ad 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -845,7 +845,7 @@ def _get_eagle_module_inputs( rotary_pos_emb = self.eagle_module.rotary_pos_emb(padded_input_ids.shape[-1]) attn_mask = attention_mask.clone().detach() - attn_mask[:, :, :-1, :-1] = attn_mask[:, :, 1:, 1:] + attn_mask[:, :, :-1, :-1] = attention_mask[:, :, 1:, 1:] attn_mask[:, :, -1, :] = True attn_mask[:, :, :, -1] = True @@ -914,6 +914,55 @@ def _get_eagle_module_inputs( eagle_inputs["attention_mask"] = attn_mask eagle_inputs["position_ids"] = position_ids eagle_inputs["rotary_pos_emb"] = rotary_pos_emb + + if self.config.sequence_parallel: + gathered_hidden_states = gather_from_sequence_parallel_region(hidden_states) + else: + gathered_hidden_states = hidden_states + eagle_inputs["hidden_states"] = gathered_hidden_states + + for i in range(self.eagle_config.parallel_draft_step - 1): + eagle_inputs["input_ids"] = torch.cat( + ( + eagle_inputs["input_ids"], + torch.full( + padded_input_ids.shape, + getattr(self, f"mask_token_{i}"), + device=padded_input_ids.device, + dtype=padded_input_ids.dtype, + ), + ), + dim=-1, + ) + + eagle_inputs["hidden_states"] = torch.cat( + ( + eagle_inputs["hidden_states"], + torch.zeros( + (1 + i, b, h), dtype=hidden_states.dtype, device=hidden_states.device + ), + gathered_hidden_states[: -(1 + i)], + ), + dim=0, + ) + + eagle_inputs["position_ids"] = torch.cat( + (eagle_inputs["position_ids"], position_ids), dim=-1 + ) + + if rotary_pos_emb is not None: + eagle_inputs["rotary_pos_emb"] = torch.cat( + (eagle_inputs["rotary_pos_emb"], rotary_pos_emb), dim=0 + ) + + if self.config.sequence_parallel: + eagle_inputs["hidden_states"] = scatter_to_sequence_parallel_region( + eagle_inputs["hidden_states"] + ) + + eagle_inputs["attention_mask"] = set_multi_step_attention_mask( + attn_mask, self.eagle_config.parallel_draft_step + ) elif features.shape[0] == hidden_states.shape[0]: eagle_inputs["input_ids"] = torch.cat( (padded_input_ids, padded_input_ids), From c6731f878d7e7d7405689da7c6994dd492fa975c Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Mon, 22 Sep 2025 12:20:31 -0700 Subject: [PATCH 05/26] revert Signed-off-by: Ye Yu --- .../speculative/plugins/megatron_eagle.py | 49 ------------------- 1 file changed, 49 deletions(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index 466cbb5ad..139a5a648 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -914,55 +914,6 @@ def _get_eagle_module_inputs( eagle_inputs["attention_mask"] = attn_mask eagle_inputs["position_ids"] = position_ids eagle_inputs["rotary_pos_emb"] = rotary_pos_emb - - if self.config.sequence_parallel: - gathered_hidden_states = gather_from_sequence_parallel_region(hidden_states) - else: - gathered_hidden_states = hidden_states - eagle_inputs["hidden_states"] = gathered_hidden_states - - for i in range(self.eagle_config.parallel_draft_step - 1): - eagle_inputs["input_ids"] = torch.cat( - ( - eagle_inputs["input_ids"], - torch.full( - padded_input_ids.shape, - getattr(self, f"mask_token_{i}"), - device=padded_input_ids.device, - dtype=padded_input_ids.dtype, - ), - ), - dim=-1, - ) - - eagle_inputs["hidden_states"] = torch.cat( - ( - eagle_inputs["hidden_states"], - torch.zeros( - (1 + i, b, h), dtype=hidden_states.dtype, device=hidden_states.device - ), - gathered_hidden_states[: -(1 + i)], - ), - dim=0, - ) - - eagle_inputs["position_ids"] = torch.cat( - (eagle_inputs["position_ids"], position_ids), dim=-1 - ) - - if rotary_pos_emb is not None: - eagle_inputs["rotary_pos_emb"] = torch.cat( - (eagle_inputs["rotary_pos_emb"], rotary_pos_emb), dim=0 - ) - - if self.config.sequence_parallel: - eagle_inputs["hidden_states"] = scatter_to_sequence_parallel_region( - eagle_inputs["hidden_states"] - ) - - eagle_inputs["attention_mask"] = set_multi_step_attention_mask( - attn_mask, self.eagle_config.parallel_draft_step - ) elif features.shape[0] == hidden_states.shape[0]: eagle_inputs["input_ids"] = torch.cat( (padded_input_ids, padded_input_ids), From 8aa0e12875777b3a773472be232181c7d9ea33cf Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Tue, 23 Sep 2025 12:21:30 -0700 Subject: [PATCH 06/26] integrate parallel draft to eagle auto regression Signed-off-by: Ye Yu --- .../speculative/plugins/megatron_eagle.py | 419 ++++++------------ 1 file changed, 133 insertions(+), 286 deletions(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index 139a5a648..f62d8e76c 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -194,78 +194,11 @@ def set_multi_step_attention_mask(attn_mask, step): h0 h1 h2 h3 h4 h5 h6 h7 (base hidden_states) l0 l1 l2 l3 l4 l5 l6 l7 (base labels) + ttt_step=2 + parallel_draft_step=2 - (1st) | i1 i2 i3 i4 i5 i6 i7 -- | - (out) | h0 h1 h2 h3 h4 h5 h6 h7 | - ========================================= - f1 l1 | i1 h0 | x | - f2 l2 | i2 h1 | x x | - f3 l3 | i3 h2 | x x x | - f4 l4 | i4 h3 | x x x x | - f5 l5 | i5 h4 | x x x x x | - f6 l6 | i6 h5 | x x x x x x | - f7 l7 | i7 h6 | x x x x x x x | - -- -- | -- h7 | o o o o o o o o | - ========================================= - - - (2nd) | i1 i2 i3 i4 i5 i6 i7 -- | i1 i2 i3 i4 i5 i6 i7 -- | - (out) | h0 h1 h2 h3 h4 h5 h6 h7 | -- F1 F2 F3 F4 F5 F6 F7 | - =================================================================== - F1 l1 | i1 h0 | x | | - F2 l2 | i2 h1 | x x | | - F3 l3 | i3 h2 | x x x | | - F4 l4 | i4 h3 | x x x x | | - F5 l5 | i5 h4 | x x x x x | | - F6 l6 | i6 h5 | x x x x x x | | - F7 l7 | i7 h6 | x x x x x x x | | - -- -- | -- h7 | o o o o o o o o | | - =================================================================== - -- -- | i1 -- | | | - G2 l2 | i2 F1 | x o | x | - G3 l3 | i3 F2 | x x o | x | - G4 l4 | i4 F3 | x x x o | x | - G5 l5 | i5 F4 | x x x x o | x | - G6 l6 | i6 F5 | x x x x x o | x | - G7 l7 | i7 F6 | x x x x x x o | x | - -- -- | -- F7 | | | - =================================================================== - - - (3rd) | i1 i2 i3 i4 i5 i6 i7 -- | i1 i2 i3 i4 i5 i6 i7 -- | i1 i2 i3 i4 i5 i6 i7 -- | - (out) | h0 h1 h2 h3 h4 h5 h6 h7 | -- F1 F2 F3 F4 F5 F6 F7 | -- -- G2 G3 G4 G5 G6 G7 | - ============================================================================================= - F1 l1 | i1 h0 | x | | | - F2 l2 | i2 h1 | x x | | | - F3 l3 | i3 h2 | x x x | | | - F4 l4 | i4 h3 | x x x x | | | - F5 l5 | i5 h4 | x x x x x | | | - F6 l6 | i6 h5 | x x x x x x | | | - F7 l7 | i7 h6 | x x x x x x x | | | - -- -- | -- h7 | o o o o o o o o | | | - ============================================================================================= - -- -- | i1 -- | | | | - G2 l2 | i2 F1 | x o | x | | - G3 l3 | i3 F2 | x x o | x | | - G4 l4 | i4 F3 | x x x o | x | | - G5 l5 | i5 F4 | x x x x o | x | | - G6 l6 | i6 F5 | x x x x x o | x | | - G7 l7 | i7 F6 | x x x x x x o | x | | - -- -- | -- F7 | | | | - ============================================================================================= - -- -- | i1 -- | | | | - -- -- | i2 -- | | | | - H3 l3 | i3 G2 | x o o | x o | x | - H4 l4 | i4 G3 | x x o o | x o | x | - H5 l5 | i5 G4 | x x x o o | x o | x | - H6 l6 | i6 G5 | x x x x o o | x o | x | - H7 l7 | i7 G6 | x x x x x o o | x o | x | - -- -- | -- G7 | | | | - ============================================================================================= - - - (4th) | i1 i2 i3 i4 i5 i6 i7 -- | i1 i2 i3 i4 i5 i6 i7 -- | i1 i2 i3 i4 i5 i6 i7 -- | i1 i2 i3 i4 i5 i6 i7 -- | - (out) | h0 h1 h2 h3 h4 h5 h6 h7 | -- F1 F2 F3 F4 F5 F6 F7 | -- -- G2 G3 G4 G5 G6 G7 | -- -- -- H3 H4 H5 H6 H7 | + | i1 i2 i3 i4 i5 i6 i7 -- | m0 m0 m0 m0 m0 m0 m0 -- | i1 i2 i3 i4 i5 i6 i7 -- | m0 m0 m0 m0 m0 m0 m0 -- | + (out) | h0 h1 h2 h3 h4 h5 h6 h7 | h0 h1 h2 h3 h4 h5 h6 h7 | -- -- G2 G3 G4 G5 G6 G7 | -- -- G2 G3 G4 G5 G6 G7 | ======================================================================================================================= F1 l1 | i1 h0 | x | | | | F2 l2 | i2 h1 | x x | | | | @@ -277,13 +210,13 @@ def set_multi_step_attention_mask(attn_mask, step): -- -- | -- h7 | o o o o o o o o | | | | ======================================================================================================================= -- -- | i1 -- | | | | | - G2 l2 | i2 F1 | x o | x | | | - G3 l3 | i3 F2 | x x o | x | | | - G4 l4 | i4 F3 | x x x o | x | | | - G5 l5 | i5 F4 | x x x x o | x | | | - G6 l6 | i6 F5 | x x x x x o | x | | | - G7 l7 | i7 F6 | x x x x x x o | x | | | - -- -- | -- F7 | | | | | + G2 l2 | i2 h1 | x o | x | | | + G3 l3 | i3 h2 | x x o | x | | | + G4 l4 | i4 h3 | x x x o | x | | | + G5 l5 | i5 h4 | x x x x o | x | | | + G6 l6 | i6 h5 | x x x x x o | x | | | + G7 l7 | i7 h6 | x x x x x x o | x | | | + -- -- | -- h7 | | | | | ======================================================================================================================= -- -- | i1 -- | | | | | -- -- | i2 -- | | | | | @@ -294,18 +227,16 @@ def set_multi_step_attention_mask(attn_mask, step): H7 l7 | i7 G6 | x x x x x o o | x o | x | | -- -- | -- G7 | | | | | ======================================================================================================================= - -- -- | i1 -- | | | | | - -- -- | i2 -- | | | | | - -- -- | i3 -- | | | | | - K4 l4 | i4 H3 | x | x | x | x | - K5 l5 | i5 H4 | x x | x | x | x | - K6 l6 | i6 H5 | x x x | x | x | x | - K7 l7 | i7 H6 | x x x x | x | x | x | - -- -- | -- H7 | | | | | + -- -- | m0 -- | | | | | + -- -- | m0 -- | | | | | + -- -- | m0 -- | | | | | + K4 l4 | m0 G3 | x | x | x | x | + K5 l5 | m0 G4 | x x | x | x | x | + K6 l6 | m0 G5 | x x x | x | x | x | + K7 l7 | m0 G6 | x x x x | x | x | x | + -- -- | -- G7 | | | | | ======================================================================================================================= """ # noqa: E501 - assert step > 1, "step should be larger than 1 in multi-step attention mask." - s = attn_mask.shape[-1] for iter in range(2, step + 1): # iter starts from 2nd step @@ -833,10 +764,12 @@ def _get_eagle_module_inputs( attention_mask: torch.Tensor, position_ids: torch.Tensor, features: torch.Tensor | None = None, + ttt_step: int = 1, ): """Getting EAGLE module inputs.""" b = hidden_states.shape[1] h = hidden_states.shape[2] + s = input_ids.shape[1] # [b, 1] id_padding = torch.zeros((b, 1), dtype=input_ids.dtype, device=input_ids.device) @@ -851,28 +784,35 @@ def _get_eagle_module_inputs( eagle_inputs = {} - if self.eagle_config.parallel_draft_step > 1: - eagle_inputs["input_ids"] = padded_input_ids - eagle_inputs["position_ids"] = position_ids - if rotary_pos_emb is not None: - eagle_inputs["rotary_pos_emb"] = rotary_pos_emb - else: - # [TODO] (yeyu): there will be problem here with MLA - eagle_inputs["rotary_pos_emb"] = None - - if self.config.sequence_parallel: - gathered_hidden_states = gather_from_sequence_parallel_region(hidden_states) - else: - gathered_hidden_states = hidden_states - eagle_inputs["hidden_states"] = gathered_hidden_states + eagle_inputs["input_ids"] = torch.empty( + 0, dtype=padded_input_ids.dtype, device=padded_input_ids.device + ) + eagle_inputs["position_ids"] = torch.empty( + 0, dtype=position_ids.dtype, device=position_ids.device + ) + eagle_inputs["rotary_pos_emb"] = rotary_pos_emb + if self.config.sequence_parallel: + gathered_hidden_states = gather_from_sequence_parallel_region(hidden_states) + gathered_features = ( + None if features is None else gather_from_sequence_parallel_region(features) + ) + else: + gathered_hidden_states = hidden_states + gathered_features = features + eagle_inputs["hidden_states"] = torch.empty( + 0, dtype=gathered_hidden_states.dtype, device=gathered_hidden_states.device + ) - for i in range(self.eagle_config.parallel_draft_step - 1): + for step in range(ttt_step): + for i in range(self.eagle_config.parallel_draft_step): eagle_inputs["input_ids"] = torch.cat( ( eagle_inputs["input_ids"], - torch.full( + padded_input_ids + if i == 0 + else torch.full( padded_input_ids.shape, - getattr(self, f"mask_token_{i}"), + getattr(self, f"mask_token_{i - 1}"), device=padded_input_ids.device, dtype=padded_input_ids.dtype, ), @@ -880,13 +820,25 @@ def _get_eagle_module_inputs( dim=-1, ) + if step > 0: + feature = gathered_features[ + (step * self.eagle_config.parallel_draft_step - 1) * s : step + * self.eagle_config.parallel_draft_step + * s + ] eagle_inputs["hidden_states"] = torch.cat( ( eagle_inputs["hidden_states"], - torch.zeros( - (1 + i, b, h), dtype=hidden_states.dtype, device=hidden_states.device + gathered_hidden_states + if step == 0 + else ( + torch.zeros( + (1, b, h), + dtype=hidden_states.dtype, + device=hidden_states.device, + ), + feature[:-1, :, :], ), - gathered_hidden_states[: -(1 + i)], ), dim=0, ) @@ -900,129 +852,14 @@ def _get_eagle_module_inputs( (eagle_inputs["rotary_pos_emb"], rotary_pos_emb), dim=0 ) - if self.config.sequence_parallel: - eagle_inputs["hidden_states"] = scatter_to_sequence_parallel_region( - eagle_inputs["hidden_states"] - ) - - eagle_inputs["attention_mask"] = set_multi_step_attention_mask( - attn_mask, self.eagle_config.parallel_draft_step - ) - elif features is None: - eagle_inputs["input_ids"] = padded_input_ids - eagle_inputs["hidden_states"] = hidden_states - eagle_inputs["attention_mask"] = attn_mask - eagle_inputs["position_ids"] = position_ids - eagle_inputs["rotary_pos_emb"] = rotary_pos_emb - elif features.shape[0] == hidden_states.shape[0]: - eagle_inputs["input_ids"] = torch.cat( - (padded_input_ids, padded_input_ids), - dim=-1, - ) - - if self.config.sequence_parallel: - gathered_hidden_states = gather_from_sequence_parallel_region(hidden_states) - gathered_features = gather_from_sequence_parallel_region(features) - else: - gathered_hidden_states = hidden_states - gathered_features = features - eagle_inputs["hidden_states"] = torch.cat( - ( - gathered_hidden_states, - torch.zeros((1, b, h), dtype=hidden_states.dtype, device=hidden_states.device), - gathered_features[:-1, :, :], - ), - dim=0, - ) - if self.config.sequence_parallel: - eagle_inputs["hidden_states"] = scatter_to_sequence_parallel_region( - eagle_inputs["hidden_states"] - ) - - eagle_inputs["attention_mask"] = set_multi_step_attention_mask(attn_mask, 2) - eagle_inputs["position_ids"] = torch.cat((position_ids, position_ids), dim=-1) - - if rotary_pos_emb is not None: - eagle_inputs["rotary_pos_emb"] = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=0) - else: - # [TODO] (yeyu): there will be problem here with MLA - eagle_inputs["rotary_pos_emb"] = None - elif features.shape[0] == hidden_states.shape[0] * 2: - eagle_inputs["input_ids"] = torch.cat( - (padded_input_ids, padded_input_ids, padded_input_ids), - dim=-1, - ) - - if self.config.sequence_parallel: - gathered_hidden_states = gather_from_sequence_parallel_region(hidden_states) - gathered_features = gather_from_sequence_parallel_region(features) - else: - gathered_hidden_states = hidden_states - gathered_features = features - eagle_inputs["hidden_states"] = torch.cat( - ( - gathered_hidden_states, - torch.zeros((1, b, h), dtype=hidden_states.dtype, device=hidden_states.device), - gathered_features[:-1, :, :], - ), - dim=0, - ) - if self.config.sequence_parallel: - eagle_inputs["hidden_states"] = scatter_to_sequence_parallel_region( - eagle_inputs["hidden_states"] - ) - - eagle_inputs["attention_mask"] = set_multi_step_attention_mask(attn_mask, 3) - eagle_inputs["position_ids"] = torch.cat( - (position_ids, position_ids, position_ids), dim=-1 - ) - - if rotary_pos_emb is not None: - eagle_inputs["rotary_pos_emb"] = torch.cat( - (rotary_pos_emb, rotary_pos_emb, rotary_pos_emb), - dim=0, - ) - else: - # [TODO] (yeyu): there will be problem here with MLA - eagle_inputs["rotary_pos_emb"] = None - else: - eagle_inputs["input_ids"] = torch.cat( - (padded_input_ids, padded_input_ids, padded_input_ids, padded_input_ids), - dim=-1, - ) - - if self.config.sequence_parallel: - gathered_hidden_states = gather_from_sequence_parallel_region(hidden_states) - gathered_features = gather_from_sequence_parallel_region(features) - else: - gathered_hidden_states = hidden_states - gathered_features = features - eagle_inputs["hidden_states"] = torch.cat( - ( - gathered_hidden_states, - torch.zeros((1, b, h), dtype=hidden_states.dtype, device=hidden_states.device), - gathered_features[:-1, :, :], - ), - dim=0, - ) - if self.config.sequence_parallel: - eagle_inputs["hidden_states"] = scatter_to_sequence_parallel_region( - eagle_inputs["hidden_states"] - ) - - eagle_inputs["attention_mask"] = set_multi_step_attention_mask(attn_mask, 4) - eagle_inputs["position_ids"] = torch.cat( - (position_ids, position_ids, position_ids, position_ids), dim=-1 + if self.config.sequence_parallel: + eagle_inputs["hidden_states"] = scatter_to_sequence_parallel_region( + eagle_inputs["hidden_states"] ) - if rotary_pos_emb is not None: - eagle_inputs["rotary_pos_emb"] = torch.cat( - (rotary_pos_emb, rotary_pos_emb, rotary_pos_emb, rotary_pos_emb), - dim=0, - ) - else: - # [TODO] (yeyu): there will be problem here with MLA - eagle_inputs["rotary_pos_emb"] = None + eagle_inputs["attention_mask"] = set_multi_step_attention_mask( + attn_mask, ttt_step * self.eagle_config.parallel_draft_step + ) eagle_inputs["embedding"] = self.embedding( input_ids=eagle_inputs["input_ids"], @@ -1258,28 +1095,26 @@ def forward( loss = self.compute_language_model_loss(labels, logits_sbh) loss = 0.0 * loss - if self.eagle_config.parallel_draft_step > 1: - for i in range(self.eagle_config.parallel_draft_step): - eagle_logits = eagle_logits_0[i * labels.shape[1] : (i + 1) * labels.shape[1]] - loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits) - loss_ = loss_[:, i:] - loss[:, i + 1 :] += 1.0 * loss_ - return loss - - loss_0 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_0) - loss[:, 1:] += self.eagle_loss_decay_factor * loss_0 + eagle_logits_0 = eagle_logits_0[-labels.shape[1] * self.eagle_config.parallel_draft_step :] + for i in range(self.eagle_config.parallel_draft_step): + eagle_logits = eagle_logits_0[i * labels.shape[1] : (i + 1) * labels.shape[1]] + loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits) + loss_ = loss_[:, i:] + loss[:, i + 1 :] += self.eagle_loss_decay_factor * loss_ if self.eagle_report_acc and not self.training: acc = [] with torch.no_grad(): - gathered_logits = gather_from_tensor_model_parallel_region( - eagle_logits_0[:-1, :, :] - ) - eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1) - if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: - eagle_top1 += self.eagle_module.d2t[eagle_top1] - top1_p = torch.eq(labels[:, 1:], eagle_top1).sum() / eagle_top1.numel() - acc.append(top1_p) + for i in range(self.eagle_config.parallel_draft_step): + gathered_logits = gather_from_tensor_model_parallel_region( + eagle_logits_0[i * labels.shape[1] : (i + 1) * labels.shape[1]] + ) + gathered_logits = gathered_logits[i:-1] + eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1) + if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: + eagle_top1 += self.eagle_module.d2t[eagle_top1] + top1_p = torch.eq(labels[:, i + 1 :], eagle_top1).sum() / eagle_top1.numel() + acc.append(top1_p) if get_tensor_model_parallel_rank() == 0: print( @@ -1294,6 +1129,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, features=eagle_hidden_states_0_pre_norm, + ttt_step=2, ) _, eagle_logits_2x, eagle_hidden_states_2x_pre_norm = self._eagle_forward( @@ -1303,24 +1139,27 @@ def forward( packed_seq_params=packed_seq_params, **(extra_block_kwargs or {}), ) - eagle_logits_1 = eagle_logits_2x[-labels.shape[1] :, :, :] + eagle_logits_1 = eagle_logits_2x[-labels.shape[1] * self.eagle_config.parallel_draft_step :] - loss_1 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_1) - # [b, s - 2] - loss_1 = loss_1[:, 1:] - loss[:, 2:] += self.eagle_loss_decay_factor**2 * loss_1 + for i in range(self.eagle_config.parallel_draft_step): + eagle_logits = eagle_logits_1[i * labels.shape[1] : (i + 1) * labels.shape[1]] + loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits) + loss_ = loss_[:, i + 1 :] + loss[:, i + 2 :] += self.eagle_loss_decay_factor**2 * loss_ if self.eagle_report_acc and not self.training: acc = [] with torch.no_grad(): - gathered_logits = gather_from_tensor_model_parallel_region( - eagle_logits_1[1:-1, :, :] - ) - eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1) - if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: - eagle_top1 += self.eagle_module.d2t[eagle_top1] - top1_p = torch.eq(labels[:, 2:], eagle_top1).sum() / eagle_top1.numel() - acc.append(top1_p) + for i in range(self.eagle_config.parallel_draft_step): + gathered_logits = gather_from_tensor_model_parallel_region( + eagle_logits_1[i * labels.shape[1] : (i + 1) * labels.shape[1]] + ) + gathered_logits = gathered_logits[i + 1 : -1] + eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1) + if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: + eagle_top1 += self.eagle_module.d2t[eagle_top1] + top1_p = torch.eq(labels[:, i + 2 :], eagle_top1).sum() / eagle_top1.numel() + acc.append(top1_p) if get_tensor_model_parallel_rank() == 0: print( @@ -1335,6 +1174,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, features=eagle_hidden_states_2x_pre_norm, + ttt_step=3, ) _, eagle_logits_3x, eagle_hidden_states_3x_pre_norm = self._eagle_forward( @@ -1345,24 +1185,27 @@ def forward( **(extra_block_kwargs or {}), ) - eagle_logits_2 = eagle_logits_3x[-labels.shape[1] :, :, :] + eagle_logits_2 = eagle_logits_3x[-labels.shape[1] * self.eagle_config.parallel_draft_step :] - loss_2 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_2) - # [b, s - 3] - loss_2 = loss_2[:, 2:] - loss[:, 3:] += self.eagle_loss_decay_factor**3 * loss_2 + for i in range(self.eagle_config.parallel_draft_step): + eagle_logits = eagle_logits_2[i * labels.shape[1] : (i + 1) * labels.shape[1]] + loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits) + loss_ = loss_[:, i + 2 :] + loss[:, i + 3 :] += self.eagle_loss_decay_factor**3 * loss_ if self.eagle_report_acc and not self.training: acc = [] with torch.no_grad(): - gathered_logits = gather_from_tensor_model_parallel_region( - eagle_logits_2[2:-1, :, :] - ) - eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1) - if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: - eagle_top1 += self.eagle_module.d2t[eagle_top1] - top1_p = torch.eq(labels[:, 3:], eagle_top1).sum() / eagle_top1.numel() - acc.append(top1_p) + for i in range(self.eagle_config.parallel_draft_step): + gathered_logits = gather_from_tensor_model_parallel_region( + eagle_logits_2[i * labels.shape[1] : (i + 1) * labels.shape[1]] + ) + gathered_logits = gathered_logits[i + 2 : -1] + eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1) + if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: + eagle_top1 += self.eagle_module.d2t[eagle_top1] + top1_p = torch.eq(labels[:, i + 3 :], eagle_top1).sum() / eagle_top1.numel() + acc.append(top1_p) if get_tensor_model_parallel_rank() == 0: print( @@ -1377,6 +1220,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, features=eagle_hidden_states_3x_pre_norm, + ttt_step=4, ) _, eagle_logits_4x, eagle_hidden_states_4x_pre_norm = self._eagle_forward( @@ -1387,24 +1231,27 @@ def forward( **(extra_block_kwargs or {}), ) - eagle_logits_3 = eagle_logits_4x[-labels.shape[1] :, :, :] + eagle_logits_3 = eagle_logits_4x[-labels.shape[1] * self.eagle_config.parallel_draft_step :] - loss_3 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_3) - # [b, s - 4] - loss_3 = loss_3[:, 3:] - loss[:, 4:] += self.eagle_loss_decay_factor**4 * loss_3 + for i in range(self.eagle_config.parallel_draft_step): + eagle_logits = eagle_logits_3[i * labels.shape[1] : (i + 1) * labels.shape[1]] + loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits) + loss_ = loss_[:, i + 3 :] + loss[:, i + 4 :] += self.eagle_loss_decay_factor**4 * loss_ if self.eagle_report_acc and not self.training: acc = [] with torch.no_grad(): - gathered_logits = gather_from_tensor_model_parallel_region( - eagle_logits_3[3:-1, :, :] - ) - eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1) - if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: - eagle_top1 += self.eagle_module.d2t[eagle_top1] - top1_p = torch.eq(labels[:, 4:], eagle_top1).sum() / eagle_top1.numel() - acc.append(top1_p) + for i in range(self.eagle_config.parallel_draft_step): + gathered_logits = gather_from_tensor_model_parallel_region( + eagle_logits_3[i * labels.shape[1] : (i + 1) * labels.shape[1]] + ) + gathered_logits = gathered_logits[i + 3 : -1] + eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1) + if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: + eagle_top1 += self.eagle_module.d2t[eagle_top1] + top1_p = torch.eq(labels[:, i + 4 :], eagle_top1).sum() / eagle_top1.numel() + acc.append(top1_p) if get_tensor_model_parallel_rank() == 0: print( From ff04f43386ff301ac98f9038890db1c709adc511 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Tue, 23 Sep 2025 14:00:37 -0700 Subject: [PATCH 07/26] debug Signed-off-by: Ye Yu --- .../speculative/plugins/megatron_eagle.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index f62d8e76c..deefe698b 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -790,7 +790,9 @@ def _get_eagle_module_inputs( eagle_inputs["position_ids"] = torch.empty( 0, dtype=position_ids.dtype, device=position_ids.device ) - eagle_inputs["rotary_pos_emb"] = rotary_pos_emb + eagle_inputs["rotary_pos_emb"] = torch.empty( + 0, dtype=rotary_pos_emb.dtype, device=rotary_pos_emb.device + ) if self.config.sequence_parallel: gathered_hidden_states = gather_from_sequence_parallel_region(hidden_states) gathered_features = ( @@ -831,13 +833,15 @@ def _get_eagle_module_inputs( eagle_inputs["hidden_states"], gathered_hidden_states if step == 0 - else ( - torch.zeros( - (1, b, h), - dtype=hidden_states.dtype, - device=hidden_states.device, - ), - feature[:-1, :, :], + else torch.cat( + ( + torch.zeros( + (1, b, h), + dtype=hidden_states.dtype, + device=hidden_states.device, + ), + feature[:-1, :, :], + ) ), ), dim=0, From e5f2403624c628438f1617e3a4cf1c17d5dda95b Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Tue, 23 Sep 2025 14:24:01 -0700 Subject: [PATCH 08/26] fix pseudo spec generate for parallel draft Signed-off-by: Ye Yu --- .../speculative/plugins/megatron_eagle.py | 41 ++++++++----------- 1 file changed, 17 insertions(+), 24 deletions(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index deefe698b..da4c25cf5 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -1491,12 +1491,11 @@ def pseudo_speculative_generate( draft_tokens = [] for _ in range(steps): - if self.eagle_config.parallel_draft_step > 1: - for i in range(self.eagle_config.parallel_draft_step - 1): - eagle_ids = torch.cat( - (eagle_ids, getattr(self, f"mask_token_{i}").view((1, 1))), dim=-1 - ) - hidden_states = torch.cat((hidden_states, hidden_states[-1:]), dim=0) + for i in range(self.eagle_config.parallel_draft_step - 1): + eagle_ids = torch.cat( + (eagle_ids, getattr(self, f"mask_token_{i}").view((1, 1))), dim=-1 + ) + hidden_states = torch.cat((hidden_states, hidden_states[-1:]), dim=0) padded_eagle_ids, seq_len, padded_hidden_states = right_padding( eagle_ids, hidden_states ) @@ -1530,31 +1529,25 @@ def pseudo_speculative_generate( ) eagle_next_hidden_states_input = eagle_next_hidden_states_input[:seq_len, :, :] - if self.eagle_config.parallel_draft_step > 1: - draft_token = ( - gather_from_tensor_model_parallel_region(eagle_logits)[ - -self.eagle_config.parallel_draft_step :, :, : - ] - .argmax(dim=-1) - .transpose(0, 1) - ) - else: - draft_token = ( - gather_from_tensor_model_parallel_region(eagle_logits)[-1:, :, :] - .argmax(dim=-1) - .transpose(0, 1) - ) + draft_token = ( + gather_from_tensor_model_parallel_region(eagle_logits)[ + -self.eagle_config.parallel_draft_step :, :, : + ] + .argmax(dim=-1) + .transpose(0, 1) + ) if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: draft_token += self.eagle_module.d2t[draft_token] - if self.eagle_config.parallel_draft_step > 1: - return base_token, draft_token - draft_tokens.append(draft_token) eagle_ids = torch.cat((eagle_ids, draft_token), dim=-1) hidden_states = torch.cat( - (hidden_states, eagle_next_hidden_states_input[-1:, :, :]), dim=0 + ( + hidden_states, + eagle_next_hidden_states_input[-self.eagle_config.parallel_draft_step :, :, :], + ), + dim=0, ) draft_tokens = torch.cat(draft_tokens, dim=-1) From e032a527cb32309d74b5454a3e624c861a5588c2 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Wed, 24 Sep 2025 13:52:21 -0700 Subject: [PATCH 09/26] implement kv cache (inference_context) for eagle training Signed-off-by: Ye Yu --- .../speculative/plugins/megatron_eagle.py | 39 +++++++++++++------ 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index da4c25cf5..e49c0e7c5 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -26,6 +26,7 @@ from megatron.core.dist_checkpointing.mapping import ShardedStateDict from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding from megatron.core.extensions.transformer_engine import TENorm +from megatron.core.inference.contexts import StaticInferenceContext from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding from megatron.core.models.gpt import GPTModel @@ -240,23 +241,15 @@ def set_multi_step_attention_mask(attn_mask, step): s = attn_mask.shape[-1] for iter in range(2, step + 1): # iter starts from 2nd step - zero_mask = attn_mask.new_ones( - attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s - ).bool() - mask_0 = attn_mask.clone().detach()[:, :, -s:, :] + mask_0 = attn_mask.clone().detach() mask_0[:, :, iter - 2, :] = True mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:] mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool() for i in range(iter - 1, s - 1): mask_1[:, :, i, i] = False - attn_mask = torch.cat( - ( - torch.cat((attn_mask, zero_mask), dim=-1), - torch.cat((mask_0, mask_1), dim=-1), - ), - dim=-2, - ) + attn_mask = torch.cat((mask_0, mask_1), dim=-1) + return attn_mask @@ -516,6 +509,7 @@ def forward( rotary_pos_emb: torch.Tensor = None, inference_params: InferenceParams = None, packed_seq_params: PackedSeqParams = None, + inference_context: StaticInferenceContext | None = None, extra_block_kwargs: dict | None = None, ) -> torch.Tensor: """Forward function.""" @@ -556,6 +550,7 @@ def forward( inference_params=inference_params, rotary_pos_emb=rotary_pos_emb, packed_seq_params=packed_seq_params, + inference_context=inference_context, **(extra_block_kwargs or {}), ) @@ -962,6 +957,7 @@ def _eagle_forward( output_weight, inference_params: InferenceParams = None, packed_seq_params: PackedSeqParams = None, + inference_context: StaticInferenceContext | None = None, extra_block_kwargs: dict | None = None, ): eagle_hidden_states, eagle_hidden_states_pre_final_layernorm = self.eagle_module( @@ -971,15 +967,23 @@ def _eagle_forward( eagle_inputs["rotary_pos_emb"], inference_params=inference_params, packed_seq_params=packed_seq_params, + inference_context=inference_context, **(extra_block_kwargs or {}), ) + # Update inference_context.sequence_len_offset after each call of eagle_module + inference_context.sequence_len_offset += eagle_inputs["input_ids"].shape[1] + if hasattr(self.eagle_module, "eagle_output_layer"): eagle_logits, _ = self.eagle_module.eagle_output_layer(eagle_hidden_states) else: eagle_logits, _ = self.output_layer(eagle_hidden_states, weight=output_weight) - return eagle_hidden_states, eagle_logits, eagle_hidden_states_pre_final_layernorm + return ( + eagle_hidden_states, + eagle_logits, + eagle_hidden_states_pre_final_layernorm, + ) def forward( self, @@ -1033,6 +1037,11 @@ def forward( output_weight = self.shared_embedding_or_output_weight() logits_sbh, _ = self.output_layer(hidden_states, weight=output_weight) + # EAGLE kv cache + eagle_inference_context = StaticInferenceContext( + input_ids.shape[0], input_ids.shape[1] * self.eagle_config.parallel_draft_step * 4 + ) + if self.eagle_offline: eagle_module_input_hidden_states = self._get_eagle_input_hidden_states( aux_hidden_states, apply_fc=self.eagle_config.use_aux_hidden_state @@ -1075,6 +1084,7 @@ def forward( output_weight, inference_params=inference_params, packed_seq_params=packed_seq_params, + inference_context=eagle_inference_context, **(extra_block_kwargs or {}), ) @@ -1141,6 +1151,7 @@ def forward( output_weight, inference_params=inference_params, packed_seq_params=packed_seq_params, + inference_context=eagle_inference_context, **(extra_block_kwargs or {}), ) eagle_logits_1 = eagle_logits_2x[-labels.shape[1] * self.eagle_config.parallel_draft_step :] @@ -1186,6 +1197,7 @@ def forward( output_weight, inference_params=inference_params, packed_seq_params=packed_seq_params, + inference_context=eagle_inference_context, **(extra_block_kwargs or {}), ) @@ -1232,6 +1244,7 @@ def forward( output_weight, inference_params=inference_params, packed_seq_params=packed_seq_params, + inference_context=eagle_inference_context, **(extra_block_kwargs or {}), ) @@ -1443,6 +1456,8 @@ def pseudo_speculative_generate( ): """Pseudo generate of the EAGLE GPTModel. + This function does not support kv cache as sequence parallel may be enabled. + Returns: base_token (torch.Tensor): token from base model draft_tokens (torch.Tensor): draft tokens from eagle module From 5328b55c201a3628f87abd20bded45460d532a58 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Thu, 25 Sep 2025 14:03:49 -0700 Subject: [PATCH 10/26] allow groundtruth mismatch in AcceptanceRateValidation Signed-off-by: Ye Yu --- modelopt/torch/speculative/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/speculative/utils.py b/modelopt/torch/speculative/utils.py index 648cc8163..116b4eabc 100644 --- a/modelopt/torch/speculative/utils.py +++ b/modelopt/torch/speculative/utils.py @@ -331,7 +331,9 @@ def validate( if ground_truth is None: ground_truth = self.get_ground_truth(input_ids, osl) - ground_truth = self.check_data_consistency_across_ranks(ground_truth) + ground_truth = self.check_data_consistency_across_ranks( + ground_truth, fail_when_mismatch=False + ) cnt = 0 draft_tokens = None From be260f6ab90018212374a65d1125fccd7f36204d Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Thu, 25 Sep 2025 14:12:39 -0700 Subject: [PATCH 11/26] set fail_when_mismatch to False Signed-off-by: Ye Yu --- modelopt/torch/speculative/utils.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/modelopt/torch/speculative/utils.py b/modelopt/torch/speculative/utils.py index 116b4eabc..96fb56243 100644 --- a/modelopt/torch/speculative/utils.py +++ b/modelopt/torch/speculative/utils.py @@ -290,7 +290,7 @@ def check_draft(self, ground_truth, input_ids, draft_tokens): return input_ids - def check_data_consistency_across_ranks(self, data, group=None, fail_when_mismatch=True): + def check_data_consistency_across_ranks(self, data, group=None, fail_when_mismatch=False): """This function checks the data consistency across all ranks in the group. Use rank 0 data as the golden set to broadcast to all ranks. @@ -331,9 +331,7 @@ def validate( if ground_truth is None: ground_truth = self.get_ground_truth(input_ids, osl) - ground_truth = self.check_data_consistency_across_ranks( - ground_truth, fail_when_mismatch=False - ) + ground_truth = self.check_data_consistency_across_ranks(ground_truth) cnt = 0 draft_tokens = None @@ -348,16 +346,12 @@ def validate( if tree_paths: input_id, draft_tokens, pred_tokens = self.model.tree_decode(input_ids, tree=tree) - pred_tokens = self.check_data_consistency_across_ranks( - pred_tokens, fail_when_mismatch=False - ) + pred_tokens = self.check_data_consistency_across_ranks(pred_tokens) else: input_id, draft_tokens = self.model.pseudo_speculative_generate( input_ids, steps=steps ) - draft_tokens = self.check_data_consistency_across_ranks( - draft_tokens, fail_when_mismatch=False - ) + draft_tokens = self.check_data_consistency_across_ranks(draft_tokens) input_id = self.check_data_consistency_across_ranks(input_id) input_ids = torch.cat((input_ids, input_id), dim=-1) From 035d5ef6dffa4ec0296fecc78118d024c089136a Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Fri, 26 Sep 2025 09:51:35 -0700 Subject: [PATCH 12/26] debug Signed-off-by: Ye Yu --- .../speculative/plugins/megatron_eagle.py | 179 ++++++++++-------- 1 file changed, 96 insertions(+), 83 deletions(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index e49c0e7c5..7b3cfeb47 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -760,6 +760,7 @@ def _get_eagle_module_inputs( position_ids: torch.Tensor, features: torch.Tensor | None = None, ttt_step: int = 1, + parallel_draft_step: int = 1, ): """Getting EAGLE module inputs.""" b = hidden_states.shape[1] @@ -801,7 +802,7 @@ def _get_eagle_module_inputs( ) for step in range(ttt_step): - for i in range(self.eagle_config.parallel_draft_step): + for i in range(parallel_draft_step): eagle_inputs["input_ids"] = torch.cat( ( eagle_inputs["input_ids"], @@ -818,11 +819,7 @@ def _get_eagle_module_inputs( ) if step > 0: - feature = gathered_features[ - (step * self.eagle_config.parallel_draft_step - 1) * s : step - * self.eagle_config.parallel_draft_step - * s - ] + feature = gathered_features[-s:] eagle_inputs["hidden_states"] = torch.cat( ( eagle_inputs["hidden_states"], @@ -857,7 +854,7 @@ def _get_eagle_module_inputs( ) eagle_inputs["attention_mask"] = set_multi_step_attention_mask( - attn_mask, ttt_step * self.eagle_config.parallel_draft_step + attn_mask, (ttt_step - 1) * self.eagle_config.parallel_draft_step + parallel_draft_step ) eagle_inputs["embedding"] = self.embedding( @@ -1072,21 +1069,28 @@ def forward( # In calibration mode, we want to make sure all weights have been exercised. # This makes sure all quantized weights have amax calibrated if inference_params is None or self.calibration_mode: - eagle_inputs_0 = self._get_eagle_module_inputs( - input_ids=input_ids, - hidden_states=eagle_module_input_hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - ) + eagle_logits_0 = [] + for i in range(self.eagle_config.parallel_draft_step): + eagle_inputs_0 = self._get_eagle_module_inputs( + input_ids=input_ids, + hidden_states=eagle_module_input_hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + ttt_step=1, + parallel_draft_step=i + 1, + ) - _, eagle_logits_0, eagle_hidden_states_0_pre_norm = self._eagle_forward( - eagle_inputs_0, - output_weight, - inference_params=inference_params, - packed_seq_params=packed_seq_params, - inference_context=eagle_inference_context, - **(extra_block_kwargs or {}), - ) + _, eagle_logits_, eagle_hidden_states_0_pre_norm = self._eagle_forward( + eagle_inputs_0, + output_weight, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + inference_context=eagle_inference_context, + **(extra_block_kwargs or {}), + ) + + eagle_logits_0.append(eagle_logits_[-input_ids.shape[1] :]) + eagle_logits_0 = torch.cat(eagle_logits_0, dim=0) # If labels are not provided, return the original logits. We only return after # all eagle weights have been exercised for quantization calibration purpose. @@ -1109,9 +1113,8 @@ def forward( loss = self.compute_language_model_loss(labels, logits_sbh) loss = 0.0 * loss - eagle_logits_0 = eagle_logits_0[-labels.shape[1] * self.eagle_config.parallel_draft_step :] for i in range(self.eagle_config.parallel_draft_step): - eagle_logits = eagle_logits_0[i * labels.shape[1] : (i + 1) * labels.shape[1]] + eagle_logits = eagle_logits_0[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]] loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits) loss_ = loss_[:, i:] loss[:, i + 1 :] += self.eagle_loss_decay_factor * loss_ @@ -1121,7 +1124,7 @@ def forward( with torch.no_grad(): for i in range(self.eagle_config.parallel_draft_step): gathered_logits = gather_from_tensor_model_parallel_region( - eagle_logits_0[i * labels.shape[1] : (i + 1) * labels.shape[1]] + eagle_logits_0[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]] ) gathered_logits = gathered_logits[i:-1] eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1) @@ -1137,27 +1140,31 @@ def forward( ) # Second round of EAGLE loss - eagle_inputs_1 = self._get_eagle_module_inputs( - input_ids=input_ids, - hidden_states=eagle_module_input_hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - features=eagle_hidden_states_0_pre_norm, - ttt_step=2, - ) + eagle_logits_1 = [] + for i in range(self.eagle_config.parallel_draft_step): + eagle_inputs_1 = self._get_eagle_module_inputs( + input_ids=input_ids, + hidden_states=eagle_module_input_hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + features=eagle_hidden_states_0_pre_norm, + ttt_step=2, + parallel_draft_step=i + 1, + ) - _, eagle_logits_2x, eagle_hidden_states_2x_pre_norm = self._eagle_forward( - eagle_inputs_1, - output_weight, - inference_params=inference_params, - packed_seq_params=packed_seq_params, - inference_context=eagle_inference_context, - **(extra_block_kwargs or {}), - ) - eagle_logits_1 = eagle_logits_2x[-labels.shape[1] * self.eagle_config.parallel_draft_step :] + _, eagle_logits_, eagle_hidden_states_2x_pre_norm = self._eagle_forward( + eagle_inputs_1, + output_weight, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + inference_context=eagle_inference_context, + **(extra_block_kwargs or {}), + ) + eagle_logits_1.append(eagle_logits_[-input_ids.shape[1] :]) + eagle_logits_1 = torch.cat(eagle_logits_1, dim=0) for i in range(self.eagle_config.parallel_draft_step): - eagle_logits = eagle_logits_1[i * labels.shape[1] : (i + 1) * labels.shape[1]] + eagle_logits = eagle_logits_1[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]] loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits) loss_ = loss_[:, i + 1 :] loss[:, i + 2 :] += self.eagle_loss_decay_factor**2 * loss_ @@ -1167,7 +1174,7 @@ def forward( with torch.no_grad(): for i in range(self.eagle_config.parallel_draft_step): gathered_logits = gather_from_tensor_model_parallel_region( - eagle_logits_1[i * labels.shape[1] : (i + 1) * labels.shape[1]] + eagle_logits_1[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]] ) gathered_logits = gathered_logits[i + 1 : -1] eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1) @@ -1183,28 +1190,31 @@ def forward( ) # Third EAGLE loss - eagle_inputs_2 = self._get_eagle_module_inputs( - input_ids=input_ids, - hidden_states=eagle_module_input_hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - features=eagle_hidden_states_2x_pre_norm, - ttt_step=3, - ) - - _, eagle_logits_3x, eagle_hidden_states_3x_pre_norm = self._eagle_forward( - eagle_inputs_2, - output_weight, - inference_params=inference_params, - packed_seq_params=packed_seq_params, - inference_context=eagle_inference_context, - **(extra_block_kwargs or {}), - ) + eagle_logits_2 = [] + for i in range(self.eagle_config.parallel_draft_step): + eagle_inputs_2 = self._get_eagle_module_inputs( + input_ids=input_ids, + hidden_states=eagle_module_input_hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + features=eagle_hidden_states_2x_pre_norm, + ttt_step=3, + parallel_draft_step=i + 1, + ) - eagle_logits_2 = eagle_logits_3x[-labels.shape[1] * self.eagle_config.parallel_draft_step :] + _, eagle_logits_, eagle_hidden_states_3x_pre_norm = self._eagle_forward( + eagle_inputs_2, + output_weight, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + inference_context=eagle_inference_context, + **(extra_block_kwargs or {}), + ) + eagle_logits_2.append(eagle_logits_[-input_ids.shape[1] :]) + eagle_logits_2 = torch.cat(eagle_logits_2, dim=0) for i in range(self.eagle_config.parallel_draft_step): - eagle_logits = eagle_logits_2[i * labels.shape[1] : (i + 1) * labels.shape[1]] + eagle_logits = eagle_logits_2[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]] loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits) loss_ = loss_[:, i + 2 :] loss[:, i + 3 :] += self.eagle_loss_decay_factor**3 * loss_ @@ -1214,7 +1224,7 @@ def forward( with torch.no_grad(): for i in range(self.eagle_config.parallel_draft_step): gathered_logits = gather_from_tensor_model_parallel_region( - eagle_logits_2[i * labels.shape[1] : (i + 1) * labels.shape[1]] + eagle_logits_2[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]] ) gathered_logits = gathered_logits[i + 2 : -1] eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1) @@ -1230,28 +1240,31 @@ def forward( ) # Forth EAGLE loss - eagle_inputs_3 = self._get_eagle_module_inputs( - input_ids=input_ids, - hidden_states=eagle_module_input_hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - features=eagle_hidden_states_3x_pre_norm, - ttt_step=4, - ) - - _, eagle_logits_4x, eagle_hidden_states_4x_pre_norm = self._eagle_forward( - eagle_inputs_3, - output_weight, - inference_params=inference_params, - packed_seq_params=packed_seq_params, - inference_context=eagle_inference_context, - **(extra_block_kwargs or {}), - ) + eagle_logits_3 = [] + for i in range(self.eagle_config.parallel_draft_step): + eagle_inputs_3 = self._get_eagle_module_inputs( + input_ids=input_ids, + hidden_states=eagle_module_input_hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + features=eagle_hidden_states_3x_pre_norm, + ttt_step=4, + parallel_draft_step=i + 1, + ) - eagle_logits_3 = eagle_logits_4x[-labels.shape[1] * self.eagle_config.parallel_draft_step :] + _, eagle_logits_, eagle_hidden_states_4x_pre_norm = self._eagle_forward( + eagle_inputs_3, + output_weight, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + inference_context=eagle_inference_context, + **(extra_block_kwargs or {}), + ) + eagle_logits_3.append(eagle_logits_[-input_ids.shape[1] :]) + eagle_logits_3 = torch.cat(eagle_logits_3, dim=0) for i in range(self.eagle_config.parallel_draft_step): - eagle_logits = eagle_logits_3[i * labels.shape[1] : (i + 1) * labels.shape[1]] + eagle_logits = eagle_logits_3[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]] loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits) loss_ = loss_[:, i + 3 :] loss[:, i + 4 :] += self.eagle_loss_decay_factor**4 * loss_ @@ -1261,7 +1274,7 @@ def forward( with torch.no_grad(): for i in range(self.eagle_config.parallel_draft_step): gathered_logits = gather_from_tensor_model_parallel_region( - eagle_logits_3[i * labels.shape[1] : (i + 1) * labels.shape[1]] + eagle_logits_3[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]] ) gathered_logits = gathered_logits[i + 3 : -1] eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1) From 72b3aa0928f187099f7627b526ff2bc5428a60a4 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Fri, 26 Sep 2025 13:42:54 -0700 Subject: [PATCH 13/26] debug Signed-off-by: Ye Yu --- .../speculative/plugins/megatron_eagle.py | 82 ++++++------------- 1 file changed, 27 insertions(+), 55 deletions(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index 7b3cfeb47..fc8b00f59 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -780,15 +780,9 @@ def _get_eagle_module_inputs( eagle_inputs = {} - eagle_inputs["input_ids"] = torch.empty( - 0, dtype=padded_input_ids.dtype, device=padded_input_ids.device - ) - eagle_inputs["position_ids"] = torch.empty( - 0, dtype=position_ids.dtype, device=position_ids.device - ) - eagle_inputs["rotary_pos_emb"] = torch.empty( - 0, dtype=rotary_pos_emb.dtype, device=rotary_pos_emb.device - ) + eagle_inputs["position_ids"] = position_ids + eagle_inputs["rotary_pos_emb"] = rotary_pos_emb + if self.config.sequence_parallel: gathered_hidden_states = gather_from_sequence_parallel_region(hidden_states) gathered_features = ( @@ -797,56 +791,34 @@ def _get_eagle_module_inputs( else: gathered_hidden_states = hidden_states gathered_features = features - eagle_inputs["hidden_states"] = torch.empty( - 0, dtype=gathered_hidden_states.dtype, device=gathered_hidden_states.device - ) - for step in range(ttt_step): - for i in range(parallel_draft_step): - eagle_inputs["input_ids"] = torch.cat( - ( - eagle_inputs["input_ids"], - padded_input_ids - if i == 0 - else torch.full( - padded_input_ids.shape, - getattr(self, f"mask_token_{i - 1}"), - device=padded_input_ids.device, - dtype=padded_input_ids.dtype, - ), - ), - dim=-1, - ) + eagle_inputs["input_ids"] = ( + padded_input_ids + if parallel_draft_step == 1 + else torch.full( + padded_input_ids.shape, + getattr(self, f"mask_token_{parallel_draft_step - 2}"), + device=padded_input_ids.device, + dtype=padded_input_ids.dtype, + ) + ) - if step > 0: - feature = gathered_features[-s:] - eagle_inputs["hidden_states"] = torch.cat( - ( - eagle_inputs["hidden_states"], - gathered_hidden_states - if step == 0 - else torch.cat( - ( - torch.zeros( - (1, b, h), - dtype=hidden_states.dtype, - device=hidden_states.device, - ), - feature[:-1, :, :], - ) - ), + if gathered_features is not None: + feature = gathered_features[-s:] + eagle_inputs["hidden_states"] = ( + gathered_hidden_states + if ttt_step == 1 + else torch.cat( + ( + torch.zeros( + (1, b, h), + dtype=hidden_states.dtype, + device=hidden_states.device, ), - dim=0, - ) - - eagle_inputs["position_ids"] = torch.cat( - (eagle_inputs["position_ids"], position_ids), dim=-1 + feature[:-1, :, :], ) - - if rotary_pos_emb is not None: - eagle_inputs["rotary_pos_emb"] = torch.cat( - (eagle_inputs["rotary_pos_emb"], rotary_pos_emb), dim=0 - ) + ) + ) if self.config.sequence_parallel: eagle_inputs["hidden_states"] = scatter_to_sequence_parallel_region( From 460bc4e73c4fc9ac5a317c34d7b057e6b52aed0b Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Fri, 26 Sep 2025 14:08:36 -0700 Subject: [PATCH 14/26] debug Signed-off-by: Ye Yu --- .../speculative/plugins/megatron_eagle.py | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index fc8b00f59..32ce59e6d 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -781,16 +781,6 @@ def _get_eagle_module_inputs( eagle_inputs = {} eagle_inputs["position_ids"] = position_ids - eagle_inputs["rotary_pos_emb"] = rotary_pos_emb - - if self.config.sequence_parallel: - gathered_hidden_states = gather_from_sequence_parallel_region(hidden_states) - gathered_features = ( - None if features is None else gather_from_sequence_parallel_region(features) - ) - else: - gathered_hidden_states = hidden_states - gathered_features = features eagle_inputs["input_ids"] = ( padded_input_ids @@ -803,6 +793,14 @@ def _get_eagle_module_inputs( ) ) + if self.config.sequence_parallel: + gathered_hidden_states = gather_from_sequence_parallel_region(hidden_states) + gathered_features = ( + None if features is None else gather_from_sequence_parallel_region(features) + ) + else: + gathered_hidden_states = hidden_states + gathered_features = features if gathered_features is not None: feature = gathered_features[-s:] eagle_inputs["hidden_states"] = ( @@ -829,6 +827,12 @@ def _get_eagle_module_inputs( attn_mask, (ttt_step - 1) * self.eagle_config.parallel_draft_step + parallel_draft_step ) + eagle_inputs["rotary_pos_emb"] = torch.cat( + [rotary_pos_emb] + * ((ttt_step - 1) * self.eagle_config.parallel_draft_step + parallel_draft_step), + dim=0, + ) + eagle_inputs["embedding"] = self.embedding( input_ids=eagle_inputs["input_ids"], position_ids=eagle_inputs["position_ids"], From 72067dd777ec1786ef6a2fed154cd1bdedeeab7c Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Fri, 26 Sep 2025 14:18:01 -0700 Subject: [PATCH 15/26] debug Signed-off-by: Ye Yu --- modelopt/torch/speculative/plugins/megatron_eagle.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index 32ce59e6d..71e3019c6 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -945,7 +945,8 @@ def _eagle_forward( ) # Update inference_context.sequence_len_offset after each call of eagle_module - inference_context.sequence_len_offset += eagle_inputs["input_ids"].shape[1] + if inference_context is not None: + inference_context.sequence_len_offset += eagle_inputs["input_ids"].shape[1] if hasattr(self.eagle_module, "eagle_output_layer"): eagle_logits, _ = self.eagle_module.eagle_output_layer(eagle_hidden_states) From 890a7185e3d69ce3c647cb1f2da2d9fee8f22bff Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Mon, 29 Sep 2025 09:55:48 -0700 Subject: [PATCH 16/26] remove redundant index as now eagle_logits_ has a length of seq-len Signed-off-by: Ye Yu --- modelopt/torch/speculative/plugins/megatron_eagle.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index 71e3019c6..955c799f5 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -1066,7 +1066,7 @@ def forward( **(extra_block_kwargs or {}), ) - eagle_logits_0.append(eagle_logits_[-input_ids.shape[1] :]) + eagle_logits_0.append(eagle_logits_) eagle_logits_0 = torch.cat(eagle_logits_0, dim=0) # If labels are not provided, return the original logits. We only return after @@ -1137,7 +1137,7 @@ def forward( inference_context=eagle_inference_context, **(extra_block_kwargs or {}), ) - eagle_logits_1.append(eagle_logits_[-input_ids.shape[1] :]) + eagle_logits_1.append(eagle_logits_) eagle_logits_1 = torch.cat(eagle_logits_1, dim=0) for i in range(self.eagle_config.parallel_draft_step): @@ -1187,7 +1187,7 @@ def forward( inference_context=eagle_inference_context, **(extra_block_kwargs or {}), ) - eagle_logits_2.append(eagle_logits_[-input_ids.shape[1] :]) + eagle_logits_2.append(eagle_logits_) eagle_logits_2 = torch.cat(eagle_logits_2, dim=0) for i in range(self.eagle_config.parallel_draft_step): @@ -1237,7 +1237,7 @@ def forward( inference_context=eagle_inference_context, **(extra_block_kwargs or {}), ) - eagle_logits_3.append(eagle_logits_[-input_ids.shape[1] :]) + eagle_logits_3.append(eagle_logits_) eagle_logits_3 = torch.cat(eagle_logits_3, dim=0) for i in range(self.eagle_config.parallel_draft_step): From 9195c1866a44ac9a5bf0de9139c6acefde965713 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Mon, 29 Sep 2025 10:31:42 -0700 Subject: [PATCH 17/26] minor edit based on coderabbit's suggestions Signed-off-by: Ye Yu --- modelopt/torch/speculative/plugins/megatron_eagle.py | 10 +++++----- modelopt/torch/speculative/utils.py | 4 +++- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index 955c799f5..d7cf591e8 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -140,7 +140,7 @@ def dict_to_config( def mcore_version_higher_than(target_version: str): - """Check if megatron-core is least this version.""" + """Check if megatron-core is greater than this version.""" return Version(megatron.core.__version__) > Version(target_version) @@ -239,13 +239,13 @@ def set_multi_step_attention_mask(attn_mask, step): ======================================================================================================================= """ # noqa: E501 s = attn_mask.shape[-1] - for iter in range(2, step + 1): - # iter starts from 2nd step + for step_idx in range(2, step + 1): + # step_idx starts from 2nd step mask_0 = attn_mask.clone().detach() - mask_0[:, :, iter - 2, :] = True + mask_0[:, :, step_idx - 2, :] = True mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:] mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool() - for i in range(iter - 1, s - 1): + for i in range(step_idx - 1, s - 1): mask_1[:, :, i, i] = False attn_mask = torch.cat((mask_0, mask_1), dim=-1) diff --git a/modelopt/torch/speculative/utils.py b/modelopt/torch/speculative/utils.py index 96fb56243..1d9a2d5a6 100644 --- a/modelopt/torch/speculative/utils.py +++ b/modelopt/torch/speculative/utils.py @@ -294,7 +294,9 @@ def check_data_consistency_across_ranks(self, data, group=None, fail_when_mismat """This function checks the data consistency across all ranks in the group. Use rank 0 data as the golden set to broadcast to all ranks. - Each rank will then compare to this data and through error if different. + Each rank compares its data against this golden set and either raises + (when fail_when_mismatch=True) or emits a warning while forcing every + rank to adopt rank 0's data. """ if not torch.distributed.is_initialized(): return data From 32919ff03c5e365ec5718ca00ca6783624549006 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Wed, 1 Oct 2025 10:30:52 -0700 Subject: [PATCH 18/26] make ttt_step configurable in forward Signed-off-by: Ye Yu --- .../speculative/plugins/megatron_eagle.py | 292 +++++------------- 1 file changed, 77 insertions(+), 215 deletions(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index d7cf591e8..2faff69a2 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -197,6 +197,7 @@ def set_multi_step_attention_mask(attn_mask, step): ttt_step=2 parallel_draft_step=2 + ->step=3 | i1 i2 i3 i4 i5 i6 i7 -- | m0 m0 m0 m0 m0 m0 m0 -- | i1 i2 i3 i4 i5 i6 i7 -- | m0 m0 m0 m0 m0 m0 m0 -- | (out) | h0 h1 h2 h3 h4 h5 h6 h7 | h0 h1 h2 h3 h4 h5 h6 h7 | -- -- G2 G3 G4 G5 G6 G7 | -- -- G2 G3 G4 G5 G6 G7 | @@ -239,13 +240,12 @@ def set_multi_step_attention_mask(attn_mask, step): ======================================================================================================================= """ # noqa: E501 s = attn_mask.shape[-1] - for step_idx in range(2, step + 1): - # step_idx starts from 2nd step + for step_idx in range(step): mask_0 = attn_mask.clone().detach() - mask_0[:, :, step_idx - 2, :] = True + mask_0[:, :, step_idx, :] = True mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:] mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool() - for i in range(step_idx - 1, s - 1): + for i in range(step_idx + 1, s - 1): mask_1[:, :, i, i] = False attn_mask = torch.cat((mask_0, mask_1), dim=-1) @@ -759,8 +759,8 @@ def _get_eagle_module_inputs( attention_mask: torch.Tensor, position_ids: torch.Tensor, features: torch.Tensor | None = None, - ttt_step: int = 1, - parallel_draft_step: int = 1, + ttt_step: int = 0, + parallel_draft_step: int = 0, ): """Getting EAGLE module inputs.""" b = hidden_states.shape[1] @@ -784,10 +784,10 @@ def _get_eagle_module_inputs( eagle_inputs["input_ids"] = ( padded_input_ids - if parallel_draft_step == 1 + if parallel_draft_step == 0 else torch.full( padded_input_ids.shape, - getattr(self, f"mask_token_{parallel_draft_step - 2}"), + getattr(self, f"mask_token_{parallel_draft_step - 1}"), device=padded_input_ids.device, dtype=padded_input_ids.dtype, ) @@ -805,7 +805,7 @@ def _get_eagle_module_inputs( feature = gathered_features[-s:] eagle_inputs["hidden_states"] = ( gathered_hidden_states - if ttt_step == 1 + if ttt_step == 0 else torch.cat( ( torch.zeros( @@ -824,12 +824,12 @@ def _get_eagle_module_inputs( ) eagle_inputs["attention_mask"] = set_multi_step_attention_mask( - attn_mask, (ttt_step - 1) * self.eagle_config.parallel_draft_step + parallel_draft_step + attn_mask, ttt_step * self.eagle_config.parallel_draft_step + parallel_draft_step ) eagle_inputs["rotary_pos_emb"] = torch.cat( [rotary_pos_emb] - * ((ttt_step - 1) * self.eagle_config.parallel_draft_step + parallel_draft_step), + * (ttt_step * self.eagle_config.parallel_draft_step + parallel_draft_step + 1), dim=0, ) @@ -970,6 +970,7 @@ def forward( packed_seq_params: PackedSeqParams = None, extra_block_kwargs: dict | None = None, return_eagle_inputs: bool = False, + ttt_steps=4, **kwargs, ) -> torch.Tensor: if position_ids is None or attention_mask is None: @@ -1013,7 +1014,8 @@ def forward( # EAGLE kv cache eagle_inference_context = StaticInferenceContext( - input_ids.shape[0], input_ids.shape[1] * self.eagle_config.parallel_draft_step * 4 + input_ids.shape[0], + input_ids.shape[1] * self.eagle_config.parallel_draft_step * ttt_steps, ) if self.eagle_offline: @@ -1043,22 +1045,41 @@ def forward( hidden_states, apply_fc=True ) - # In calibration mode, we want to make sure all weights have been exercised. - # This makes sure all quantized weights have amax calibrated - if inference_params is None or self.calibration_mode: - eagle_logits_0 = [] + if labels is not None: + if labels.shape[1] == input_ids.shape[1] - 1: + # For offline training, labels may be 1 token shorter than input_ids. + # We will just pad a 0 to the labels to make the seq_len the same as + # input_ids. This will introduce a small error in training if logit_distillation + # is False, and testing accuracy is wrong for the last token. + right_token_pad = torch.zeros( + (labels.shape[0], 1), + dtype=labels.dtype, + device=labels.device, + ) + labels = torch.cat((labels, right_token_pad), dim=-1) + + # If eagle_freeze_base_model is set to True, + # the base model is frozen . + loss = self.compute_language_model_loss(labels, logits_sbh) + if self.eagle_freeze_base_model: + loss = 0.0 * loss + + eagle_hidden_states_pre_norm = None + for ttt_step in range(ttt_steps): + eagle_logits = [] for i in range(self.eagle_config.parallel_draft_step): - eagle_inputs_0 = self._get_eagle_module_inputs( + eagle_inputs = self._get_eagle_module_inputs( input_ids=input_ids, hidden_states=eagle_module_input_hidden_states, attention_mask=attention_mask, position_ids=position_ids, - ttt_step=1, - parallel_draft_step=i + 1, + features=eagle_hidden_states_pre_norm, + ttt_step=ttt_step, + parallel_draft_step=i, ) - _, eagle_logits_, eagle_hidden_states_0_pre_norm = self._eagle_forward( - eagle_inputs_0, + _, eagle_logits_, eagle_hidden_states_pre_norm_ = self._eagle_forward( + eagle_inputs, output_weight, inference_params=inference_params, packed_seq_params=packed_seq_params, @@ -1066,205 +1087,46 @@ def forward( **(extra_block_kwargs or {}), ) - eagle_logits_0.append(eagle_logits_) - eagle_logits_0 = torch.cat(eagle_logits_0, dim=0) - - # If labels are not provided, return the original logits. We only return after - # all eagle weights have been exercised for quantization calibration purpose. - if labels is None: - return logits_sbh.transpose(0, 1).contiguous() - elif labels.shape[1] == input_ids.shape[1] - 1: - # For offline training, labels may be 1 token shorter than input_ids. - # We will just pad a 0 to the labels to make the seq_len the same as - # input_ids. This will introduce a small error in training if logit_distillation - # is False, and testing accuracy is wrong for the last token. - right_token_pad = torch.zeros( - (labels.shape[0], 1), - dtype=labels.dtype, - device=labels.device, - ) - labels = torch.cat((labels, right_token_pad), dim=-1) - - # If eagle_freeze_base_model is set to True, - # the base model is frozen . - loss = self.compute_language_model_loss(labels, logits_sbh) - loss = 0.0 * loss - - for i in range(self.eagle_config.parallel_draft_step): - eagle_logits = eagle_logits_0[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]] - loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits) - loss_ = loss_[:, i:] - loss[:, i + 1 :] += self.eagle_loss_decay_factor * loss_ - - if self.eagle_report_acc and not self.training: - acc = [] - with torch.no_grad(): - for i in range(self.eagle_config.parallel_draft_step): - gathered_logits = gather_from_tensor_model_parallel_region( - eagle_logits_0[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]] - ) - gathered_logits = gathered_logits[i:-1] - eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1) - if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: - eagle_top1 += self.eagle_module.d2t[eagle_top1] - top1_p = torch.eq(labels[:, i + 1 :], eagle_top1).sum() / eagle_top1.numel() - acc.append(top1_p) - - if get_tensor_model_parallel_rank() == 0: - print( - f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3} EAGLE 1st Top-1: {acc}", - flush=True, - ) - - # Second round of EAGLE loss - eagle_logits_1 = [] - for i in range(self.eagle_config.parallel_draft_step): - eagle_inputs_1 = self._get_eagle_module_inputs( - input_ids=input_ids, - hidden_states=eagle_module_input_hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - features=eagle_hidden_states_0_pre_norm, - ttt_step=2, - parallel_draft_step=i + 1, - ) + eagle_logits.append(eagle_logits_) + eagle_logits = torch.cat(eagle_logits, dim=0) + eagle_hidden_states_pre_norm = eagle_hidden_states_pre_norm_ - _, eagle_logits_, eagle_hidden_states_2x_pre_norm = self._eagle_forward( - eagle_inputs_1, - output_weight, - inference_params=inference_params, - packed_seq_params=packed_seq_params, - inference_context=eagle_inference_context, - **(extra_block_kwargs or {}), - ) - eagle_logits_1.append(eagle_logits_) - eagle_logits_1 = torch.cat(eagle_logits_1, dim=0) - - for i in range(self.eagle_config.parallel_draft_step): - eagle_logits = eagle_logits_1[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]] - loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits) - loss_ = loss_[:, i + 1 :] - loss[:, i + 2 :] += self.eagle_loss_decay_factor**2 * loss_ - - if self.eagle_report_acc and not self.training: - acc = [] - with torch.no_grad(): - for i in range(self.eagle_config.parallel_draft_step): - gathered_logits = gather_from_tensor_model_parallel_region( - eagle_logits_1[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]] - ) - gathered_logits = gathered_logits[i + 1 : -1] - eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1) - if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: - eagle_top1 += self.eagle_module.d2t[eagle_top1] - top1_p = torch.eq(labels[:, i + 2 :], eagle_top1).sum() / eagle_top1.numel() - acc.append(top1_p) - - if get_tensor_model_parallel_rank() == 0: - print( - f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3} EAGLE 2nd Top-1: {acc}", - flush=True, - ) + # If labels are not provided, return the original logits. We only return after + # all eagle weights have been exercised for quantization calibration purpose. + if labels is None: + return logits_sbh.transpose(0, 1).contiguous() - # Third EAGLE loss - eagle_logits_2 = [] - for i in range(self.eagle_config.parallel_draft_step): - eagle_inputs_2 = self._get_eagle_module_inputs( - input_ids=input_ids, - hidden_states=eagle_module_input_hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - features=eagle_hidden_states_2x_pre_norm, - ttt_step=3, - parallel_draft_step=i + 1, - ) - - _, eagle_logits_, eagle_hidden_states_3x_pre_norm = self._eagle_forward( - eagle_inputs_2, - output_weight, - inference_params=inference_params, - packed_seq_params=packed_seq_params, - inference_context=eagle_inference_context, - **(extra_block_kwargs or {}), - ) - eagle_logits_2.append(eagle_logits_) - eagle_logits_2 = torch.cat(eagle_logits_2, dim=0) - - for i in range(self.eagle_config.parallel_draft_step): - eagle_logits = eagle_logits_2[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]] - loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits) - loss_ = loss_[:, i + 2 :] - loss[:, i + 3 :] += self.eagle_loss_decay_factor**3 * loss_ - - if self.eagle_report_acc and not self.training: - acc = [] - with torch.no_grad(): - for i in range(self.eagle_config.parallel_draft_step): - gathered_logits = gather_from_tensor_model_parallel_region( - eagle_logits_2[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]] - ) - gathered_logits = gathered_logits[i + 2 : -1] - eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1) - if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: - eagle_top1 += self.eagle_module.d2t[eagle_top1] - top1_p = torch.eq(labels[:, i + 3 :], eagle_top1).sum() / eagle_top1.numel() - acc.append(top1_p) - - if get_tensor_model_parallel_rank() == 0: - print( - f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3} EAGLE 3rd Top-1: {acc}", - flush=True, + for i in range(self.eagle_config.parallel_draft_step): + eagle_logit = eagle_logits[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]] + loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logit) + loss_ = loss_[:, i + ttt_step :] + loss[:, i + ttt_step + 1 :] += ( + self.eagle_loss_decay_factor ** (ttt_step + i) * loss_ ) - # Forth EAGLE loss - eagle_logits_3 = [] - for i in range(self.eagle_config.parallel_draft_step): - eagle_inputs_3 = self._get_eagle_module_inputs( - input_ids=input_ids, - hidden_states=eagle_module_input_hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - features=eagle_hidden_states_3x_pre_norm, - ttt_step=4, - parallel_draft_step=i + 1, - ) - - _, eagle_logits_, eagle_hidden_states_4x_pre_norm = self._eagle_forward( - eagle_inputs_3, - output_weight, - inference_params=inference_params, - packed_seq_params=packed_seq_params, - inference_context=eagle_inference_context, - **(extra_block_kwargs or {}), - ) - eagle_logits_3.append(eagle_logits_) - eagle_logits_3 = torch.cat(eagle_logits_3, dim=0) - - for i in range(self.eagle_config.parallel_draft_step): - eagle_logits = eagle_logits_3[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]] - loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits) - loss_ = loss_[:, i + 3 :] - loss[:, i + 4 :] += self.eagle_loss_decay_factor**4 * loss_ - - if self.eagle_report_acc and not self.training: - acc = [] - with torch.no_grad(): - for i in range(self.eagle_config.parallel_draft_step): - gathered_logits = gather_from_tensor_model_parallel_region( - eagle_logits_3[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]] + if self.eagle_report_acc and not self.training: + acc = [] + with torch.no_grad(): + for i in range(self.eagle_config.parallel_draft_step): + gathered_logits = gather_from_tensor_model_parallel_region( + eagle_logits[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]] + ) + gathered_logits = gathered_logits[i + ttt_step : -1] + eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1) + if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: + eagle_top1 += self.eagle_module.d2t[eagle_top1] + top1_p = ( + torch.eq(labels[:, i + ttt_step + 1 :], eagle_top1).sum() + / eagle_top1.numel() + ) + acc.append(top1_p) + + if get_tensor_model_parallel_rank() == 0: + print( + f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3}" + f"EAGLE 1st Top-1: {acc}", + flush=True, ) - gathered_logits = gathered_logits[i + 3 : -1] - eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1) - if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: - eagle_top1 += self.eagle_module.d2t[eagle_top1] - top1_p = torch.eq(labels[:, i + 4 :], eagle_top1).sum() / eagle_top1.numel() - acc.append(top1_p) - - if get_tensor_model_parallel_rank() == 0: - print( - f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3} EAGLE 4th Top-1: {acc}", - flush=True, - ) return loss From 4abfb52b4d649bb94c50e343b2b8d6d9f6ed3431 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Wed, 1 Oct 2025 10:46:02 -0700 Subject: [PATCH 19/26] fix the bug in pseudo_speculative_generate Signed-off-by: Ye Yu --- modelopt/torch/speculative/plugins/megatron_eagle.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index 2faff69a2..288bbaa2f 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -1408,6 +1408,11 @@ def pseudo_speculative_generate( draft_tokens.append(draft_token) + # Remove mask tokens from eagle_ids before adding draft_token + # Remove added hidden_states before + eagle_ids = eagle_ids[:, : -self.eagle_config.parallel_draft_step + 1] + hidden_states = hidden_states[: -self.eagle_config.parallel_draft_step + 1] + eagle_ids = torch.cat((eagle_ids, draft_token), dim=-1) hidden_states = torch.cat( ( From cb9282e82fab2bc6f3ebeb7e17343fe9b6bc92c5 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Wed, 1 Oct 2025 10:49:11 -0700 Subject: [PATCH 20/26] change variable name to make it clear Signed-off-by: Ye Yu --- modelopt/torch/speculative/plugins/megatron_eagle.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index 288bbaa2f..9918a25d3 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -760,7 +760,7 @@ def _get_eagle_module_inputs( position_ids: torch.Tensor, features: torch.Tensor | None = None, ttt_step: int = 0, - parallel_draft_step: int = 0, + parallel_draft_index: int = 0, ): """Getting EAGLE module inputs.""" b = hidden_states.shape[1] @@ -784,10 +784,10 @@ def _get_eagle_module_inputs( eagle_inputs["input_ids"] = ( padded_input_ids - if parallel_draft_step == 0 + if parallel_draft_index == 0 else torch.full( padded_input_ids.shape, - getattr(self, f"mask_token_{parallel_draft_step - 1}"), + getattr(self, f"mask_token_{parallel_draft_index - 1}"), device=padded_input_ids.device, dtype=padded_input_ids.dtype, ) @@ -824,12 +824,12 @@ def _get_eagle_module_inputs( ) eagle_inputs["attention_mask"] = set_multi_step_attention_mask( - attn_mask, ttt_step * self.eagle_config.parallel_draft_step + parallel_draft_step + attn_mask, ttt_step * self.eagle_config.parallel_draft_step + parallel_draft_index ) eagle_inputs["rotary_pos_emb"] = torch.cat( [rotary_pos_emb] - * (ttt_step * self.eagle_config.parallel_draft_step + parallel_draft_step + 1), + * (ttt_step * self.eagle_config.parallel_draft_step + parallel_draft_index + 1), dim=0, ) @@ -1075,7 +1075,7 @@ def forward( position_ids=position_ids, features=eagle_hidden_states_pre_norm, ttt_step=ttt_step, - parallel_draft_step=i, + parallel_draft_index=i, ) _, eagle_logits_, eagle_hidden_states_pre_norm_ = self._eagle_forward( From b47928e025a273d648357fd2582fefc9c5839132 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Wed, 1 Oct 2025 13:32:26 -0700 Subject: [PATCH 21/26] debug: reduce kv cache size from ttt*parallel to ttt+parallel-1; in each ttt step, only the non_parallel tokens from previous ttt are used as context Signed-off-by: Ye Yu --- .../speculative/plugins/megatron_eagle.py | 123 +++++++++++------- 1 file changed, 77 insertions(+), 46 deletions(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index 9918a25d3..d0b11b84d 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -195,48 +195,72 @@ def set_multi_step_attention_mask(attn_mask, step): h0 h1 h2 h3 h4 h5 h6 h7 (base hidden_states) l0 l1 l2 l3 l4 l5 l6 l7 (base labels) - ttt_step=2 - parallel_draft_step=2 - ->step=3 - - | i1 i2 i3 i4 i5 i6 i7 -- | m0 m0 m0 m0 m0 m0 m0 -- | i1 i2 i3 i4 i5 i6 i7 -- | m0 m0 m0 m0 m0 m0 m0 -- | - (out) | h0 h1 h2 h3 h4 h5 h6 h7 | h0 h1 h2 h3 h4 h5 h6 h7 | -- -- G2 G3 G4 G5 G6 G7 | -- -- G2 G3 G4 G5 G6 G7 | - ======================================================================================================================= - F1 l1 | i1 h0 | x | | | | - F2 l2 | i2 h1 | x x | | | | - F3 l3 | i3 h2 | x x x | | | | - F4 l4 | i4 h3 | x x x x | | | | - F5 l5 | i5 h4 | x x x x x | | | | - F6 l6 | i6 h5 | x x x x x x | | | | - F7 l7 | i7 h6 | x x x x x x x | | | | - -- -- | -- h7 | o o o o o o o o | | | | + ttt_steps=2 + parallel_draft_step=3 + + + ttt_step=0 + | i1 i2 i3 i4 i5 i6 i7 -- | m0 m0 m0 m0 m0 m0 m0 -- | m1 m1 m1 m1 m1 m1 m1 -- | + | h0 h1 h2 h3 h4 h5 h6 h7 | h0 h1 h2 h3 h4 h5 h6 h7 | h0 h1 h2 h3 h4 h5 h6 h7 | + ============================================================================================= + F1 l1 | i1 h0 | x | | | + F2 l2 | i2 h1 | x x | | | + F3 l3 | i3 h2 | x x x | | | + F4 l4 | i4 h3 | x x x x | | | + F5 l5 | i5 h4 | x x x x x | | | + F6 l6 | i6 h5 | x x x x x x | | | + F7 l7 | i7 h6 | x x x x x x x | | | + -- -- | -- h7 | o o o o o o o o | | | + ============================================================================================= + -- -- | m0 -- | | | | + G2 l2 | m0 h1 | x o | x | | + G3 l3 | m0 h2 | x x o | x | | + G4 l4 | m0 h3 | x x x o | x | | + G5 l5 | m0 h4 | x x x x o | x | | + G6 l6 | m0 h5 | x x x x x o | x | | + G7 l7 | m0 h6 | x x x x x x o | x | | + -- -- | -- h7 | | | | + ============================================================================================= + -- -- | m1 -- | | | | + -- -- | m1 h1 | | | | + H3 l3 | m1 h2 | x o o | x o | x | + H4 l4 | m1 h3 | x x o o | x o | x | + H5 l5 | m1 h4 | x x x o o | x o | x | + H6 l6 | m1 h5 | x x x x o o | x o | x | + H7 l7 | m1 h6 | x x x x x o o | x o | x | + -- -- | -- h7 | | | | + + + ttt_step=1 + | i1 i2 i3 i4 i5 i6 i7 -- | i1 i2 i3 i4 i5 i6 i7 -- | m0 m0 m0 m0 m0 m0 m0 -- | m1 m1 m1 m1 m1 m1 m1 -- | + | h0 h1 h2 h3 h4 h5 h6 h7 | -- F1 F2 F3 F4 F5 F6 F7 | -- F1 F2 F3 F4 F5 F6 F7 | -- F1 F2 F3 F4 F5 F6 F7 | ======================================================================================================================= -- -- | i1 -- | | | | | - G2 l2 | i2 h1 | x o | x | | | - G3 l3 | i3 h2 | x x o | x | | | - G4 l4 | i4 h3 | x x x o | x | | | - G5 l5 | i5 h4 | x x x x o | x | | | - G6 l6 | i6 h5 | x x x x x o | x | | | - G7 l7 | i7 h6 | x x x x x x o | x | | | - -- -- | -- h7 | | | | | + J2 l2 | i2 F1 | x o | x | | | + J3 l3 | i3 F2 | x x o | x | | | + J4 l4 | i4 F3 | x x x o | x | | | + J5 l5 | i5 F4 | x x x x o | x | | | + J6 l6 | i6 F5 | x x x x x o | x | | | + J7 l7 | i7 F6 | x x x x x x o | x | | | + -- -- | -- F7 | | | | | ======================================================================================================================= - -- -- | i1 -- | | | | | - -- -- | i2 -- | | | | | - H3 l3 | i3 G2 | x o o | x o | x | | - H4 l4 | i4 G3 | x x o o | x o | x | | - H5 l5 | i5 G4 | x x x o o | x o | x | | - H6 l6 | i6 G5 | x x x x o o | x o | x | | - H7 l7 | i7 G6 | x x x x x o o | x o | x | | - -- -- | -- G7 | | | | | - ======================================================================================================================= - -- -- | m0 -- | | | | | -- -- | m0 -- | | | | | -- -- | m0 -- | | | | | - K4 l4 | m0 G3 | x | x | x | x | - K5 l5 | m0 G4 | x x | x | x | x | - K6 l6 | m0 G5 | x x x | x | x | x | - K7 l7 | m0 G6 | x x x x | x | x | x | - -- -- | -- G7 | | | | | + K3 l3 | m0 F2 | x o o | x o | x | | | + K4 l4 | m0 F3 | x x o o | x o | x | | + K5 l5 | m0 F4 | x x x o o | x o | x | | + K6 l6 | m0 F5 | x x x x o o | x o | x | | + K7 l7 | m0 F6 | x x x x x o o | x o | x | | + -- -- | -- F7 | | | | | + ======================================================================================================================= + -- -- | m1 -- | | | | | + -- -- | m1 -- | | | | | + -- -- | m1 -- | | | | | + N4 l4 | m1 F3 | x | x | x | x | + N5 l5 | m1 F4 | x x | x | x | x | + N6 l6 | m1 F5 | x x x | x | x | x | + N7 l7 | m1 F6 | x x x x | x | x | x | + -- -- | -- F7 | | | | | ======================================================================================================================= """ # noqa: E501 s = attn_mask.shape[-1] @@ -765,7 +789,6 @@ def _get_eagle_module_inputs( """Getting EAGLE module inputs.""" b = hidden_states.shape[1] h = hidden_states.shape[2] - s = input_ids.shape[1] # [b, 1] id_padding = torch.zeros((b, 1), dtype=input_ids.dtype, device=input_ids.device) @@ -801,8 +824,7 @@ def _get_eagle_module_inputs( else: gathered_hidden_states = hidden_states gathered_features = features - if gathered_features is not None: - feature = gathered_features[-s:] + eagle_inputs["hidden_states"] = ( gathered_hidden_states if ttt_step == 0 @@ -813,7 +835,7 @@ def _get_eagle_module_inputs( dtype=hidden_states.dtype, device=hidden_states.device, ), - feature[:-1, :, :], + gathered_features[:-1, :, :], ) ) ) @@ -824,12 +846,11 @@ def _get_eagle_module_inputs( ) eagle_inputs["attention_mask"] = set_multi_step_attention_mask( - attn_mask, ttt_step * self.eagle_config.parallel_draft_step + parallel_draft_index + attn_mask, ttt_step + parallel_draft_index ) eagle_inputs["rotary_pos_emb"] = torch.cat( - [rotary_pos_emb] - * (ttt_step * self.eagle_config.parallel_draft_step + parallel_draft_index + 1), + [rotary_pos_emb] * (ttt_step + parallel_draft_index + 1), dim=0, ) @@ -1015,7 +1036,7 @@ def forward( # EAGLE kv cache eagle_inference_context = StaticInferenceContext( input_ids.shape[0], - input_ids.shape[1] * self.eagle_config.parallel_draft_step * ttt_steps, + input_ids.shape[1] * (self.eagle_config.parallel_draft_step + ttt_steps - 1), ) if self.eagle_offline: @@ -1087,9 +1108,19 @@ def forward( **(extra_block_kwargs or {}), ) + if i == 0: + next_eagle_hidden_states_pre_norm = eagle_hidden_states_pre_norm_ + eagle_logits.append(eagle_logits_) eagle_logits = torch.cat(eagle_logits, dim=0) - eagle_hidden_states_pre_norm = eagle_hidden_states_pre_norm_ + eagle_hidden_states_pre_norm = next_eagle_hidden_states_pre_norm + + # Discard kv cache for the last parallel_draft_step - 1 tokens + # as the next ttt_step will only base on the first token in the + # current ttt_step + eagle_inference_context.sequence_len_offset -= input_ids.shape[1] * ( + self.eagle_config.parallel_draft_step - 1 + ) # If labels are not provided, return the original logits. We only return after # all eagle weights have been exercised for quantization calibration purpose. From b528ccc874184663818bb163d011186677a84f2f Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Wed, 1 Oct 2025 13:59:09 -0700 Subject: [PATCH 22/26] gitnore type for precommit Signed-off-by: Ye Yu --- modelopt/torch/speculative/plugins/megatron_eagle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index d0b11b84d..25a1bef68 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -835,7 +835,7 @@ def _get_eagle_module_inputs( dtype=hidden_states.dtype, device=hidden_states.device, ), - gathered_features[:-1, :, :], + gathered_features[:-1, :, :], # type: ignore[index] ) ) ) From 44e0eb1818d345a6beb815097440e1cea40493d1 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Wed, 1 Oct 2025 14:17:02 -0700 Subject: [PATCH 23/26] consolidate acc printout Signed-off-by: Ye Yu --- .../torch/speculative/plugins/megatron_eagle.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index 25a1bef68..2e97b6f49 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -1085,6 +1085,7 @@ def forward( if self.eagle_freeze_base_model: loss = 0.0 * loss + acc = [] eagle_hidden_states_pre_norm = None for ttt_step in range(ttt_steps): eagle_logits = [] @@ -1136,7 +1137,6 @@ def forward( ) if self.eagle_report_acc and not self.training: - acc = [] with torch.no_grad(): for i in range(self.eagle_config.parallel_draft_step): gathered_logits = gather_from_tensor_model_parallel_region( @@ -1152,12 +1152,12 @@ def forward( ) acc.append(top1_p) - if get_tensor_model_parallel_rank() == 0: - print( - f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3}" - f"EAGLE 1st Top-1: {acc}", - flush=True, - ) + if self.eagle_report_acc and not self.training and get_tensor_model_parallel_rank() == 0: + print( + f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3}" + f"EAGLE Top-1: {acc}", + flush=True, + ) return loss From 13f218c65954b98535320dc3aecb1cf0683650b0 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Wed, 1 Oct 2025 20:16:28 -0700 Subject: [PATCH 24/26] fix the bug in pseudo_speculative_generate Signed-off-by: Ye Yu --- modelopt/torch/speculative/plugins/megatron_eagle.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index 2e97b6f49..bddff1a0e 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -1441,8 +1441,9 @@ def pseudo_speculative_generate( # Remove mask tokens from eagle_ids before adding draft_token # Remove added hidden_states before - eagle_ids = eagle_ids[:, : -self.eagle_config.parallel_draft_step + 1] - hidden_states = hidden_states[: -self.eagle_config.parallel_draft_step + 1] + if self.eagle_config.parallel_draft_step > 1: + eagle_ids = eagle_ids[:, : -self.eagle_config.parallel_draft_step + 1] + hidden_states = hidden_states[: -self.eagle_config.parallel_draft_step + 1] eagle_ids = torch.cat((eagle_ids, draft_token), dim=-1) hidden_states = torch.cat( From 8212265f510fe340a1fd7c172f1b4db7bad12315 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Fri, 3 Oct 2025 14:17:00 -0700 Subject: [PATCH 25/26] use embedding for mask tokens as hidden_states Signed-off-by: Ye Yu --- .../speculative/plugins/megatron_eagle.py | 152 +++++++++--------- 1 file changed, 78 insertions(+), 74 deletions(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index bddff1a0e..259155008 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -201,7 +201,7 @@ def set_multi_step_attention_mask(attn_mask, step): ttt_step=0 | i1 i2 i3 i4 i5 i6 i7 -- | m0 m0 m0 m0 m0 m0 m0 -- | m1 m1 m1 m1 m1 m1 m1 -- | - | h0 h1 h2 h3 h4 h5 h6 h7 | h0 h1 h2 h3 h4 h5 h6 h7 | h0 h1 h2 h3 h4 h5 h6 h7 | + | h0 h1 h2 h3 h4 h5 h6 h7 | M0 M0 M0 M0 M0 M0 M0 M0 | M1 M1 M1 M1 M1 M1 M1 M1 | ============================================================================================= F1 l1 | i1 h0 | x | | | F2 l2 | i2 h1 | x x | | | @@ -212,28 +212,28 @@ def set_multi_step_attention_mask(attn_mask, step): F7 l7 | i7 h6 | x x x x x x x | | | -- -- | -- h7 | o o o o o o o o | | | ============================================================================================= - -- -- | m0 -- | | | | - G2 l2 | m0 h1 | x o | x | | - G3 l3 | m0 h2 | x x o | x | | - G4 l4 | m0 h3 | x x x o | x | | - G5 l5 | m0 h4 | x x x x o | x | | - G6 l6 | m0 h5 | x x x x x o | x | | - G7 l7 | m0 h6 | x x x x x x o | x | | - -- -- | -- h7 | | | | + -- -- | m0 M0 | | | | + G2 l2 | m0 M0 | x o | x | | + G3 l3 | m0 M0 | x x o | x | | + G4 l4 | m0 M0 | x x x o | x | | + G5 l5 | m0 M0 | x x x x o | x | | + G6 l6 | m0 M0 | x x x x x o | x | | + G7 l7 | m0 M0 | x x x x x x o | x | | + -- -- | -- M0 | | | | ============================================================================================= - -- -- | m1 -- | | | | - -- -- | m1 h1 | | | | - H3 l3 | m1 h2 | x o o | x o | x | - H4 l4 | m1 h3 | x x o o | x o | x | - H5 l5 | m1 h4 | x x x o o | x o | x | - H6 l6 | m1 h5 | x x x x o o | x o | x | - H7 l7 | m1 h6 | x x x x x o o | x o | x | - -- -- | -- h7 | | | | + -- -- | m1 M0 | | | | + -- -- | m1 M1 | | | | + H3 l3 | m1 M1 | x o o | x o | x | + H4 l4 | m1 M1 | x x o o | x o | x | + H5 l5 | m1 M1 | x x x o o | x o | x | + H6 l6 | m1 M1 | x x x x o o | x o | x | + H7 l7 | m1 M1 | x x x x x o o | x o | x | + -- -- | -- M1 | | | | ttt_step=1 | i1 i2 i3 i4 i5 i6 i7 -- | i1 i2 i3 i4 i5 i6 i7 -- | m0 m0 m0 m0 m0 m0 m0 -- | m1 m1 m1 m1 m1 m1 m1 -- | - | h0 h1 h2 h3 h4 h5 h6 h7 | -- F1 F2 F3 F4 F5 F6 F7 | -- F1 F2 F3 F4 F5 F6 F7 | -- F1 F2 F3 F4 F5 F6 F7 | + | h0 h1 h2 h3 h4 h5 h6 h7 | -- F1 F2 F3 F4 F5 F6 F7 | M0 M0 M0 M0 M0 M0 M0 M0 | M1 M1 M1 M1 M1 M1 M1 M1 | ======================================================================================================================= -- -- | i1 -- | | | | | J2 l2 | i2 F1 | x o | x | | | @@ -244,23 +244,23 @@ def set_multi_step_attention_mask(attn_mask, step): J7 l7 | i7 F6 | x x x x x x o | x | | | -- -- | -- F7 | | | | | ======================================================================================================================= - -- -- | m0 -- | | | | | - -- -- | m0 -- | | | | | - K3 l3 | m0 F2 | x o o | x o | x | | | - K4 l4 | m0 F3 | x x o o | x o | x | | - K5 l5 | m0 F4 | x x x o o | x o | x | | - K6 l6 | m0 F5 | x x x x o o | x o | x | | - K7 l7 | m0 F6 | x x x x x o o | x o | x | | - -- -- | -- F7 | | | | | + -- -- | m0 M0 | | | | | + -- -- | m0 M0 | | | | | + K3 l3 | m0 M0 | x o o | x o | x | | | + K4 l4 | m0 M0 | x x o o | x o | x | | + K5 l5 | m0 M0 | x x x o o | x o | x | | + K6 l6 | m0 M0 | x x x x o o | x o | x | | + K7 l7 | m0 M0 | x x x x x o o | x o | x | | + -- -- | -- M0 | | | | | ======================================================================================================================= - -- -- | m1 -- | | | | | - -- -- | m1 -- | | | | | - -- -- | m1 -- | | | | | - N4 l4 | m1 F3 | x | x | x | x | - N5 l5 | m1 F4 | x x | x | x | x | - N6 l6 | m1 F5 | x x x | x | x | x | - N7 l7 | m1 F6 | x x x x | x | x | x | - -- -- | -- F7 | | | | | + -- -- | m1 M1 | | | | | + -- -- | m1 M1 | | | | | + -- -- | m1 M1 | | | | | + N4 l4 | m1 M1 | x | x | x | x | + N5 l5 | m1 M1 | x x | x | x | x | + N6 l6 | m1 M1 | x x x | x | x | x | + N7 l7 | m1 M1 | x x x x | x | x | x | + -- -- | -- M1 | | | | | ======================================================================================================================= """ # noqa: E501 s = attn_mask.shape[-1] @@ -782,16 +782,14 @@ def _get_eagle_module_inputs( hidden_states: torch.Tensor, attention_mask: torch.Tensor, position_ids: torch.Tensor, - features: torch.Tensor | None = None, ttt_step: int = 0, parallel_draft_index: int = 0, ): """Getting EAGLE module inputs.""" - b = hidden_states.shape[1] - h = hidden_states.shape[2] - # [b, 1] - id_padding = torch.zeros((b, 1), dtype=input_ids.dtype, device=input_ids.device) + id_padding = torch.zeros( + (input_ids.shape[0], 1), dtype=input_ids.dtype, device=input_ids.device + ) padded_input_ids = torch.cat((input_ids[:, 1:], id_padding), dim=-1) rotary_pos_emb = self.eagle_module.rotary_pos_emb(padded_input_ids.shape[-1]) @@ -816,35 +814,15 @@ def _get_eagle_module_inputs( ) ) - if self.config.sequence_parallel: - gathered_hidden_states = gather_from_sequence_parallel_region(hidden_states) - gathered_features = ( - None if features is None else gather_from_sequence_parallel_region(features) - ) - else: - gathered_hidden_states = hidden_states - gathered_features = features + eagle_inputs["embedding"] = self.embedding( + input_ids=eagle_inputs["input_ids"], + position_ids=eagle_inputs["position_ids"], + ) eagle_inputs["hidden_states"] = ( - gathered_hidden_states - if ttt_step == 0 - else torch.cat( - ( - torch.zeros( - (1, b, h), - dtype=hidden_states.dtype, - device=hidden_states.device, - ), - gathered_features[:-1, :, :], # type: ignore[index] - ) - ) + hidden_states if parallel_draft_index == 0 else eagle_inputs["embedding"] ) - if self.config.sequence_parallel: - eagle_inputs["hidden_states"] = scatter_to_sequence_parallel_region( - eagle_inputs["hidden_states"] - ) - eagle_inputs["attention_mask"] = set_multi_step_attention_mask( attn_mask, ttt_step + parallel_draft_index ) @@ -854,11 +832,6 @@ def _get_eagle_module_inputs( dim=0, ) - eagle_inputs["embedding"] = self.embedding( - input_ids=eagle_inputs["input_ids"], - position_ids=eagle_inputs["position_ids"], - ) - return eagle_inputs def _compute_eagle_loss(self, logits, labels, eagle_logits): @@ -1086,7 +1059,6 @@ def forward( loss = 0.0 * loss acc = [] - eagle_hidden_states_pre_norm = None for ttt_step in range(ttt_steps): eagle_logits = [] for i in range(self.eagle_config.parallel_draft_step): @@ -1095,7 +1067,6 @@ def forward( hidden_states=eagle_module_input_hidden_states, attention_mask=attention_mask, position_ids=position_ids, - features=eagle_hidden_states_pre_norm, ttt_step=ttt_step, parallel_draft_index=i, ) @@ -1114,7 +1085,29 @@ def forward( eagle_logits.append(eagle_logits_) eagle_logits = torch.cat(eagle_logits, dim=0) - eagle_hidden_states_pre_norm = next_eagle_hidden_states_pre_norm + eagle_module_input_hidden_states = next_eagle_hidden_states_pre_norm + if self.config.sequence_parallel: + eagle_module_input_hidden_states = gather_from_sequence_parallel_region( + eagle_module_input_hidden_states + ) + eagle_module_input_hidden_states = torch.cat( + ( + torch.zeros( + ( + 1, + eagle_module_input_hidden_states.shape[1], + eagle_module_input_hidden_states.shape[2], + ), + dtype=eagle_module_input_hidden_states.dtype, + device=eagle_module_input_hidden_states.device, + ), + eagle_module_input_hidden_states[:-1, :, :], + ) + ) + if self.config.sequence_parallel: + eagle_module_input_hidden_states = scatter_to_sequence_parallel_region( + eagle_module_input_hidden_states + ) # Discard kv cache for the last parallel_draft_step - 1 tokens # as the next ttt_step will only base on the first token in the @@ -1393,12 +1386,12 @@ def pseudo_speculative_generate( eagle_ids = torch.cat( (eagle_ids, getattr(self, f"mask_token_{i}").view((1, 1))), dim=-1 ) + # Pad dummy hidden_states for mask tokens + # They will be replaced by embeddings after padding hidden_states = torch.cat((hidden_states, hidden_states[-1:]), dim=0) padded_eagle_ids, seq_len, padded_hidden_states = right_padding( eagle_ids, hidden_states ) - if self.config.sequence_parallel: - padded_hidden_states = scatter_to_sequence_parallel_region(padded_hidden_states) eagle_attention_mask, eagle_position_ids = get_default_attention_mask_and_position_ids( padded_eagle_ids ) @@ -1409,6 +1402,17 @@ def pseudo_speculative_generate( input_ids=padded_eagle_ids, position_ids=eagle_position_ids, ) + if self.config.sequence_parallel: + gathered_embedding = gather_from_sequence_parallel_region(eagle_inputs["embedding"]) + if self.eagle_config.parallel_draft_step > 1: + # Replace dummy hidden_states with embedding for mask tokens + padded_hidden_states[ + seq_len - self.eagle_config.parallel_draft_step + 1 : seq_len + ] = gathered_embedding[ + seq_len - self.eagle_config.parallel_draft_step + 1 : seq_len + ] + if self.config.sequence_parallel: + padded_hidden_states = scatter_to_sequence_parallel_region(padded_hidden_states) eagle_inputs["hidden_states"] = padded_hidden_states eagle_inputs["attention_mask"] = eagle_attention_mask From aa328ed57cbaaf80129fb5296e99e0ec239d026f Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Fri, 3 Oct 2025 15:07:08 -0700 Subject: [PATCH 26/26] debug Signed-off-by: Ye Yu --- modelopt/torch/speculative/plugins/megatron_eagle.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index 259155008..b4b9ffeac 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -1404,6 +1404,8 @@ def pseudo_speculative_generate( ) if self.config.sequence_parallel: gathered_embedding = gather_from_sequence_parallel_region(eagle_inputs["embedding"]) + else: + gathered_embedding = eagle_inputs["embedding"] if self.eagle_config.parallel_draft_step > 1: # Replace dummy hidden_states with embedding for mask tokens padded_hidden_states[