
Commit f5c64a0

[EP] Refactor DeepEP Engine Organization for Mixed Mode & Buffer Management Optimization (#3182)
* Add support for mixed-EP across multiple nodes
* Code refinement

Co-authored-by: yuanxiaolan <[email protected]>
1 parent 14ed75f commit f5c64a0
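The core of the refactor is visible in the first hunk of ep.py: the separate prefill and decode DeepEP buffers collapse into a single engine handle, and mixed mode now allocates one low-latency buffer up front instead of combining a lazily built decode buffer with a throughput-mode prefill buffer. A condensed before/after sketch, assuming DeepEP's Buffer(group, num_nvl_bytes, num_rdma_bytes, ...) positional layout; everything outside the diff hunks is elided:

# Before (paraphrased from the removed lines): two per-phase engines.
self.prefill_deepep_engine = deep_ep.Buffer(
    self.group, int(5e8), 0, low_latency_mode=False, num_qps_per_rank=1
)
self.decode_deepep_engine = None  # allocated lazily by get_low_latency_buffer()

# After: one shared engine, sized for both NVLink and RDMA traffic, always in low-latency mode.
self.deepep_engine = deep_ep.Buffer(
    self.group,
    int(2e9),               # NVLink buffer bytes (per the DeepEP Buffer API)
    int(5e9),               # RDMA buffer bytes
    low_latency_mode=True,
    num_qps_per_rank=24,
)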

File tree

2 files changed: 29 additions, 30 deletions

fastdeploy/model_executor/layers/moe/ep.py

Lines changed: 25 additions & 30 deletions
@@ -68,25 +68,20 @@ def __init__(
         self.num_local_experts = num_experts // ep_size
         self.async_finish = async_finish

-        self.prefill_deepep_engine = None
-        self.decode_deepep_engine = None
+        self.deepep_engine = None

         self.ep_config = Config(24, 6, 256)
         self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank

         # In mixed EP mode on a single node, we dynamically switch between
         # high throughput and low latency modes.
         if splitwise_role == "mixed":
-            # decode engine
-            logger.info("Initializing Low Latency Buffer")
-            self.get_low_latency_buffer()
-            # prefill engine
-            self.prefill_deepep_engine = deep_ep.Buffer(
+            self.deepep_engine = deep_ep.Buffer(
                 self.group,
-                int(5e8),
-                0,
-                low_latency_mode=False,
-                num_qps_per_rank=1,
+                int(2e9),
+                int(5e9),
+                low_latency_mode=True,
+                num_qps_per_rank=24,
             )
         # In disaggregated mode on mutiple nodes, we either use
         # high throughput mode or low latency mode.
@@ -95,7 +90,7 @@ def __init__(
                 logger.info("Initializing Low Latency Buffer")
                 self.get_low_latency_buffer()
             elif moe_phase.phase == "prefill":
-                self.prefill_deepep_engine = deep_ep.Buffer(
+                self.deepep_engine = deep_ep.Buffer(
                     self.group,
                     int(5e8),
                     0,
@@ -124,14 +119,14 @@ def get_low_latency_buffer(self):
         )
         # Allocate a buffer if not existed or not enough buffer size
         if (
-            self.decode_deepep_engine is None
-            or self.decode_deepep_engine.group != self.group
-            or not self.decode_deepep_engine.low_latency_mode
-            or self.decode_deepep_engine.num_rdma_bytes < num_rdma_bytes
+            self.deepep_engine is None
+            or self.deepep_engine.group != self.group
+            or not self.deepep_engine.low_latency_mode
+            or self.deepep_engine.num_rdma_bytes < num_rdma_bytes
         ):
             # NOTES: for best performance, the QP number **must** be equal to the number of the local experts
             assert self.num_experts % self.ep_size == 0
-            self.decode_deepep_engine = deep_ep.Buffer(
+            self.deepep_engine = deep_ep.Buffer(
                 self.group,
                 0,
                 num_rdma_bytes,
@@ -168,7 +163,7 @@ def low_latency_dispatch(
             handle,
             _,
             dispatch_hook,
-        ) = self.decode_deepep_engine.low_latency_dispatch(
+        ) = self.deepep_engine.low_latency_dispatch(
             hidden_states,
             topk_idx,
             expertwise_scale,
@@ -210,7 +205,7 @@ def low_latency_combine(
             num_experts,
         )

-        combined_hidden_states, _, combine_hook = self.decode_deepep_engine.low_latency_combine(
+        combined_hidden_states, _, combine_hook = self.deepep_engine.low_latency_combine(
             hidden_states,
             topk_idx,
             topk_weights,
@@ -224,19 +219,15 @@ def clean_low_latency_buffer(self):
         """
         clean_low_latency_buffer
         """
-        self.decode_deepep_engine.clean_low_latency_buffer(
+        self.deepep_engine.clean_low_latency_buffer(
             self.num_max_dispatch_tokens_per_rank, self.hidden, self.num_experts
         )

     def barrier_all(self):
         """
         barrier_all
         """
-        if self.prefill_deepep_engine is not None:
-            self.prefill_deepep_engine.barrier_all()
-
-        if self.decode_deepep_engine is not None:
-            self.decode_deepep_engine.barrier_all()
+        self.deepep_engine.barrier_all()


 class EPRunner:
@@ -316,6 +307,9 @@ def combine(self, *args, **kwargs):
         """
         raise NotImplementedError

+    def clean_low_latency_buffer(self):
+        self.ep_engine.clean_low_latency_buffer()
+

 class EPPrefillRunner(EPRunner):
     """
@@ -328,6 +322,7 @@ def __init__(
         hidden: int,
         num_experts: int,
         splitwise_role: str,
+        num_max_dispatch_tokens_per_rank: int,
         ep_size: int = 1,
         ep_rank: int = 0,
         redundant_experts_num: int = 0,
@@ -339,7 +334,7 @@ def __init__(
             num_experts,
             splitwise_role,
             moe_phase,
-            num_max_dispatch_tokens_per_rank=256,
+            num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
             ep_size=ep_size,
             ep_rank=ep_rank,
             redundant_experts_num=redundant_experts_num,
@@ -359,7 +354,7 @@ def dispatch(
             num_tokens_per_expert,
             is_token_in_rank,
             _,
-        ) = self.ep_engine.prefill_deepep_engine.get_dispatch_layout(topk_idx, self.num_experts)
+        ) = self.ep_engine.deepep_engine.get_dispatch_layout(topk_idx, self.num_experts)

         x_scale_tensor = kwargs.get("x_scale_tensor", None)
         dispatch_args = {
@@ -372,7 +367,7 @@ def dispatch(
             "topk_idx": topk_idx,
             "topk_weights": topk_weights,
         }
-        return self.ep_engine.prefill_deepep_engine.dispatch(**dispatch_args)
+        return self.ep_engine.deepep_engine.dispatch(**dispatch_args)

     def combine(
         self,
@@ -387,14 +382,14 @@ def combine(
             "async_finish": self.ep_engine.async_finish,
             "topk_weights": recv_topk_weights,
         }
-        fused_moe_out, _, _ = self.ep_engine.prefill_deepep_engine.combine(**combine_args)
+        fused_moe_out, _, _ = self.ep_engine.deepep_engine.combine(**combine_args)

         return fused_moe_out


 class EPDecoderRunner(EPRunner):
     """
-    EPPrefillRunner
+    EPDecoderRunner
     """

     def __init__(

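For callers of ep.py, the runner API is otherwise unchanged; the visible signature change is that EPPrefillRunner now receives num_max_dispatch_tokens_per_rank explicitly instead of hard-coding 256, and every collective routes through the single deepep_engine handle. A rough construction sketch mirroring the call sites updated in fused_moe_backend_base.py below; only the arguments visible in the diff context are shown, and any arguments preceding them in the real call are omitted:

# Hypothetical call site, positional order as in the updated init_ep() hunks.
ep_prefill_runner = EPPrefillRunner(
    layer.hidden_size,
    layer.num_experts,
    layer.fd_config.parallel_config.splitwise_role,
    layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,  # new argument
    layer.ep_size,
    layer.ep_rank,
    layer.fd_config.model_config.redundant_experts_num,
)

# Dispatch and combine now share one buffer handle:
#   ep_prefill_runner.ep_engine.deepep_engine.get_dispatch_layout(...)
#   ep_prefill_runner.ep_engine.deepep_engine.dispatch(...)
#   ep_prefill_runner.ep_engine.deepep_engine.combine(...)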
fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py

Lines changed: 4 additions & 0 deletions
@@ -51,6 +51,7 @@ def init_ep(self, layer: nn.Layer) -> None:
                 layer.hidden_size,
                 layer.num_experts,
                 layer.fd_config.parallel_config.splitwise_role,
+                layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
                 layer.ep_size,
                 layer.ep_rank,
                 layer.fd_config.model_config.redundant_experts_num,
@@ -74,6 +75,7 @@ def init_ep(self, layer: nn.Layer) -> None:
                 layer.hidden_size,
                 layer.num_experts,
                 layer.fd_config.parallel_config.splitwise_role,
+                layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
                 layer.ep_size,
                 layer.ep_rank,
                 layer.fd_config.model_config.redundant_experts_num,
@@ -165,8 +167,10 @@ def apply(
         """
         if layer.ep_size > 1:
             if layer.fd_config.parallel_config.moe_phase.phase == "prefill":
+                self.ep_prefill_runner.clean_low_latency_buffer()
                 return self.apply_ep_prefill(layer, x, gate_out)
             else:
+                self.ep_decoder_runner.clean_low_latency_buffer()
                 return self.apply_ep_decode(layer, x, gate_out)
         else:
             return self.apply_tp(layer, x, gate_out)

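Taken together, the two files wire a single cleanup path for the shared buffer: the backend clears the low-latency buffer before each phase switch, so prefill and decode can reuse the same deep_ep.Buffer in mixed mode. A minimal sketch of that call chain, keeping only the method bodies shown in the diffs (class scaffolding and imports are elided):

class DeepEPEngine:
    def clean_low_latency_buffer(self):
        # Resets the shared low-latency buffer between phases.
        self.deepep_engine.clean_low_latency_buffer(
            self.num_max_dispatch_tokens_per_rank, self.hidden, self.num_experts
        )

class EPRunner:
    def clean_low_latency_buffer(self):
        # New passthrough added by this commit.
        self.ep_engine.clean_low_latency_buffer()

# fused_moe_backend_base.apply() calls the passthrough right before
# apply_ep_prefill()/apply_ep_decode(), clearing any state left over
# from the other phase.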