Skip to content

Commit b176cba

Browse files
authored
support mtp in ep64 (#4280)
1 parent dcf633c commit b176cba

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

fastdeploy/spec_decode/base.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -38,13 +38,21 @@ def __init__(self, cfg: FDConfig):
38 | 38 | Init Speculative proposer
39 | 39 | """
40 | 40 | cfg.parallel_config.tp_group = None
 | 41 | + cfg.parallel_config.ep_group = None
41 | 42 | self.cfg = deepcopy(cfg)
42 | 43 | cfg.parallel_config.tp_group = dist.get_group(
43 | 44 | cfg.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
44 | 45 | )
 | 46 | + cfg.parallel_config.ep_group = dist.get_group(
 | 47 | + cfg.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET
 | 48 | + )
45 | 49 | self.cfg.parallel_config.tp_group = dist.get_group(
46 | 50 | cfg.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
47 | 51 | )
 | 52 | + self.cfg.parallel_config.ep_group = dist.get_group(
 | 53 | + cfg.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET
 | 54 | + )
 | 55 | +
48 | 56 | self.parallel_config = self.cfg.parallel_config
49 | 57 | self.model_config = self.cfg.model_config
50 | 58 | self.speculative_config = self.cfg.speculative_config

fastdeploy/spec_decode/mtp.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -695,6 +695,9 @@ def _propose(self, target_hidden_states):
695 | 695 |
696 | 696 | if substep != self.num_model_steps - 1:
697 | 697 | target_hidden_states = self._get_self_hidden_states(hidden_states)
 | 698 | + else:
 | 699 | + if hasattr(self.model, "empty_input_forward"):
 | 700 | + self.model.empty_input_forward()
698 | 701 |
699 | 702 | def _get_self_hidden_states(self, hidden_states):
700 | 703 | target_hidden_states = eagle_get_self_hidden_states(

0 commit comments

Comments
 (0)