File tree Expand file tree Collapse file tree 2 files changed +11
-0
lines changed Expand file tree Collapse file tree 2 files changed +11
-0
lines changed Original file line number Diff line number Diff line change @@ -38,13 +38,21 @@ def __init__(self, cfg: FDConfig):
38
38
Init Speculative proposer
39
39
"""
40
40
cfg .parallel_config .tp_group = None
41
+ cfg .parallel_config .ep_group = None
41
42
self .cfg = deepcopy (cfg )
42
43
cfg .parallel_config .tp_group = dist .get_group (
43
44
cfg .parallel_config .data_parallel_rank + envs .FD_TP_GROUP_GID_OFFSET
44
45
)
46
+ cfg .parallel_config .ep_group = dist .get_group (
47
+ cfg .parallel_config .data_parallel_size + envs .FD_TP_GROUP_GID_OFFSET
48
+ )
45
49
self .cfg .parallel_config .tp_group = dist .get_group (
46
50
cfg .parallel_config .data_parallel_rank + envs .FD_TP_GROUP_GID_OFFSET
47
51
)
52
+ self .cfg .parallel_config .ep_group = dist .get_group (
53
+ cfg .parallel_config .data_parallel_size + envs .FD_TP_GROUP_GID_OFFSET
54
+ )
55
+
48
56
self .parallel_config = self .cfg .parallel_config
49
57
self .model_config = self .cfg .model_config
50
58
self .speculative_config = self .cfg .speculative_config
Original file line number Diff line number Diff line change @@ -695,6 +695,9 @@ def _propose(self, target_hidden_states):
695
695
696
696
if substep != self .num_model_steps - 1 :
697
697
target_hidden_states = self ._get_self_hidden_states (hidden_states )
698
+ else :
699
+ if hasattr (self .model , "empty_input_forward" ):
700
+ self .model .empty_input_forward ()
698
701
699
702
def _get_self_hidden_states (self , hidden_states ):
700
703
target_hidden_states = eagle_get_self_hidden_states (
You can’t perform that action at this time.
0 commit comments