
Commit 79c6a77

minor
1 parent 66f5c67 commit 79c6a77

File tree

1 file changed: +50 -1 lines changed

modelopt/torch/speculative/plugins/megatron_eagle.py

Lines changed: 50 additions & 1 deletion
```diff
@@ -845,7 +845,7 @@ def _get_eagle_module_inputs(
         rotary_pos_emb = self.eagle_module.rotary_pos_emb(padded_input_ids.shape[-1])
 
         attn_mask = attention_mask.clone().detach()
-        attn_mask[:, :, :-1, :-1] = attn_mask[:, :, 1:, 1:]
+        attn_mask[:, :, :-1, :-1] = attention_mask[:, :, 1:, 1:]
         attn_mask[:, :, -1, :] = True
         attn_mask[:, :, :, -1] = True
 
```
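The one-line change above is an aliasing fix: in the old code the source and destination of the in-place shift were overlapping views of the same tensor, and PyTorch leaves the result of copies between overlapping views undefined. Reading the shifted window from the untouched `attention_mask` makes the shift well defined. A minimal sketch of the pattern, with an assumed toy shape standing in for the `[batch, 1, seq, seq]` mask:

```python
import torch

# Assumed toy shape; the real mask in _get_eagle_module_inputs is [b, 1, s, s].
attention_mask = torch.rand(1, 1, 5, 5) < 0.5

attn_mask = attention_mask.clone().detach()

# Before the fix: source and destination alias the same storage, so elements
# may be read after they have already been overwritten (undefined result):
# attn_mask[:, :, :-1, :-1] = attn_mask[:, :, 1:, 1:]

# After the fix: the shifted window is read from the untouched original.
attn_mask[:, :, :-1, :-1] = attention_mask[:, :, 1:, 1:]
attn_mask[:, :, -1, :] = True
attn_mask[:, :, :, -1] = True
```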
```diff
@@ -914,6 +914,55 @@ def _get_eagle_module_inputs(
             eagle_inputs["attention_mask"] = attn_mask
             eagle_inputs["position_ids"] = position_ids
             eagle_inputs["rotary_pos_emb"] = rotary_pos_emb
+
+            if self.config.sequence_parallel:
+                gathered_hidden_states = gather_from_sequence_parallel_region(hidden_states)
+            else:
+                gathered_hidden_states = hidden_states
+            eagle_inputs["hidden_states"] = gathered_hidden_states
+
+            for i in range(self.eagle_config.parallel_draft_step - 1):
+                eagle_inputs["input_ids"] = torch.cat(
+                    (
+                        eagle_inputs["input_ids"],
+                        torch.full(
+                            padded_input_ids.shape,
+                            getattr(self, f"mask_token_{i}"),
+                            device=padded_input_ids.device,
+                            dtype=padded_input_ids.dtype,
+                        ),
+                    ),
+                    dim=-1,
+                )
+
+                eagle_inputs["hidden_states"] = torch.cat(
+                    (
+                        eagle_inputs["hidden_states"],
+                        torch.zeros(
+                            (1 + i, b, h), dtype=hidden_states.dtype, device=hidden_states.device
+                        ),
+                        gathered_hidden_states[: -(1 + i)],
+                    ),
+                    dim=0,
+                )
+
+                eagle_inputs["position_ids"] = torch.cat(
+                    (eagle_inputs["position_ids"], position_ids), dim=-1
+                )
+
+                if rotary_pos_emb is not None:
+                    eagle_inputs["rotary_pos_emb"] = torch.cat(
+                        (eagle_inputs["rotary_pos_emb"], rotary_pos_emb), dim=0
+                    )
+
+            if self.config.sequence_parallel:
+                eagle_inputs["hidden_states"] = scatter_to_sequence_parallel_region(
+                    eagle_inputs["hidden_states"]
+                )
+
+            eagle_inputs["attention_mask"] = set_multi_step_attention_mask(
+                attn_mask, self.eagle_config.parallel_draft_step
+            )
         elif features.shape[0] == hidden_states.shape[0]:
             eagle_inputs["input_ids"] = torch.cat(
                 (padded_input_ids, padded_input_ids),
```
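The new branch builds EAGLE inputs for parallel (multi-token) drafting: for each extra draft step `i` it appends a full sequence of the learned token `mask_token_i` to `input_ids`, appends the base hidden states shifted right by `1 + i` positions (zero-padded at the front) to `hidden_states`, and repeats `position_ids` and the rotary embeddings. A minimal sketch of the tensor bookkeeping, with made-up shapes and plain tensors standing in for the Megatron `[seq, batch, hidden]` activations:

```python
import torch

s, b, h = 4, 2, 8        # assumed toy sequence/batch/hidden sizes
parallel_draft_step = 3  # draft 3 tokens per step

padded_input_ids = torch.randint(0, 100, (b, s))
hidden_states = torch.randn(s, b, h)  # Megatron layout: [seq, batch, hidden]
position_ids = torch.arange(s).unsqueeze(0).expand(b, -1)

# Stand-ins for the registered draft tokens mask_token_0, mask_token_1, ...
mask_tokens = [1000 + i for i in range(parallel_draft_step - 1)]

input_ids = padded_input_ids
draft_hidden = hidden_states
draft_positions = position_ids

for i in range(parallel_draft_step - 1):
    # Append one full sequence made of the i-th mask token.
    input_ids = torch.cat(
        (input_ids, torch.full_like(padded_input_ids, mask_tokens[i])), dim=-1
    )
    # Append the base hidden states shifted right by (1 + i) positions and
    # zero-padded at the front, so draft step i is conditioned on features
    # from (1 + i) tokens earlier.
    draft_hidden = torch.cat(
        (draft_hidden, torch.zeros(1 + i, b, h), hidden_states[: -(1 + i)]), dim=0
    )
    # Position ids simply repeat for every draft step.
    draft_positions = torch.cat((draft_positions, position_ids), dim=-1)

# Each stream ends up parallel_draft_step times the base sequence length.
assert input_ids.shape == (b, s * parallel_draft_step)
assert draft_hidden.shape == (s * parallel_draft_step, b, h)
assert draft_positions.shape == (b, s * parallel_draft_step)
```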
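Two details around the loop: because Megatron sequence parallelism shards activations along the sequence dimension, and the shift-and-concatenate above crosses shard boundaries, the diff gathers the full-sequence `hidden_states` first (`gather_from_sequence_parallel_region`) and scatters the concatenated result back afterward (`scatter_to_sequence_parallel_region`). The attention mask is not concatenated per step; it is rebuilt once for the whole extended sequence by `set_multi_step_attention_mask(attn_mask, parallel_draft_step)`.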
