
Commit be98f6e

supports internode_ll_two_stage (#4143)
Parent: f75697c

File tree: 6 files changed (+144, -19 lines)

fastdeploy/config.py

Lines changed: 2 additions & 0 deletions
@@ -294,6 +294,8 @@ def __init__(
         self.engine_pid: Optional[int] = None
         # Do profile or not
         self.do_profile: bool = False
+        # Use internode_ll_two_stage or not
+        self.use_internode_ll_two_stage: bool = False
 
         self.max_num_batched_tokens: int = 2048
         # splitwise role

fastdeploy/engine/args_utils.py

Lines changed: 11 additions & 0 deletions
@@ -200,6 +200,11 @@ class EngineArgs:
     Flag to enable the custom all-reduce kernel.
     """
 
+    use_internode_ll_two_stage: bool = False
+    """
+    Flag to use the internode_ll_two_stage kernel.
+    """
+
     engine_worker_queue_port: str = "8002"
     """
     Port for worker queue communication.
@@ -629,6 +634,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             default=EngineArgs.disable_custom_all_reduce,
             help="Flag to disable custom all-reduce.",
         )
+        parallel_group.add_argument(
+            "--use-internode-ll-two-stage",
+            action="store_true",
+            default=EngineArgs.use_internode_ll_two_stage,
+            help="Flag to use the internode_ll_two_stage kernel.",
+        )
         parallel_group.add_argument(
             "--max-num-seqs",
             type=int,
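
For reference, a minimal sketch of how the new switch behaves, written with plain argparse rather than FastDeploy's FlexibleArgumentParser (assumed here to handle store_true flags the same way): the flag defaults to False and flips to True only when passed explicitly.

    import argparse

    # Stand-in parser; the real flag is registered by EngineArgs.add_cli_args().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--use-internode-ll-two-stage",
        action="store_true",
        default=False,  # mirrors EngineArgs.use_internode_ll_two_stage
        help="Flag to use the internode_ll_two_stage kernel.",
    )

    print(parser.parse_args([]).use_internode_ll_two_stage)  # False
    print(parser.parse_args(["--use-internode-ll-two-stage"]).use_internode_ll_two_stage)  # True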

fastdeploy/engine/engine.py

Lines changed: 1 addition & 0 deletions
@@ -483,6 +483,7 @@ def _start_worker_service(self):
             "dynamic_load_weight": self.cfg.load_config.dynamic_load_weight,
             "disable_any_whitespace": self.cfg.disable_any_whitespace,
             "disable_custom_all_reduce": self.cfg.parallel_config.disable_custom_all_reduce,
+            "use_internode_ll_two_stage": self.cfg.parallel_config.use_internode_ll_two_stage,
             "enable_logprob": self.cfg.model_config.enable_logprob,
             "lm_head_fp32": self.cfg.model_config.lm_head_fp32,
         }
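
A hypothetical illustration (not the engine's actual code) of the contract this entry establishes: a truthy boolean in the worker-args dict is expected to surface as a bare --use_internode_ll_two_stage switch on the worker command line, which worker_process.py (the last file in this diff) parses with action="store_true".

    # Hypothetical mapping from worker-args dict entries to CLI switches;
    # only the flag names come from this diff, the mapping logic is illustrative.
    worker_flags = {
        "use_internode_ll_two_stage": True,
        "enable_logprob": False,
    }
    argv = [f"--{name}" for name, enabled in worker_flags.items() if enabled]
    print(argv)  # ['--use_internode_ll_two_stage']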

fastdeploy/model_executor/layers/moe/ep.py

Lines changed: 124 additions & 19 deletions
@@ -64,6 +64,8 @@ def __init__(
         num_max_dispatch_tokens_per_rank: int,
         splitwise_role: str,
         moe_phase: MoEPhase,
+        use_internode_ll_two_stage: bool = False,
+        top_k: int = 8,
     ):
         self.group = group
         self.hidden_size = hidden_size
@@ -72,6 +74,8 @@ def __init__(
         self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
         self.splitwise_role = splitwise_role
         self.moe_phase = moe_phase
+        self.use_internode_ll_two_stage = use_internode_ll_two_stage
+        self.top_k = top_k
 
         self.deepep_buffer = None
         self.num_nvl_bytes = 0
@@ -95,12 +99,26 @@ def _compute_buffer_sizes(self, param_bytes: int = 2):
         )
 
         if self.splitwise_role == "mixed" or self.moe_phase.phase == "decode":
-            num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
-                self.num_max_dispatch_tokens_per_rank,
-                self.hidden_size,
-                self.ep_size,
-                self.num_experts,
-            )
+            if not self.use_internode_ll_two_stage:
+                num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
+                    self.num_max_dispatch_tokens_per_rank,
+                    self.hidden_size,
+                    self.ep_size,
+                    self.num_experts,
+                )
+            else:
+                num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint_two_stage(
+                    self.num_max_dispatch_tokens_per_rank, self.hidden_size, self.ep_size, self.num_experts, self.top_k
+                )
+                num_nvl_bytes = deep_ep.Buffer.get_low_latency_nvl_size_hint_two_stage(
+                    self.num_max_dispatch_tokens_per_rank,
+                    self.hidden_size,
+                    self.ep_size,
+                    self.num_experts,
+                    self.top_k,
+                    True,  # just supports dispatch_use_fp8 = True now!
+                )
+                self.num_nvl_bytes = max(self.num_nvl_bytes, num_nvl_bytes)
             self.num_rdma_bytes = max(self.num_rdma_bytes, num_rdma_bytes)
 
         logger.info(f"DeepEP num nvl bytes : {self.num_nvl_bytes}, num rdma bytes : {self.num_rdma_bytes}")
@@ -172,11 +190,21 @@ def get_buffer(self):
 
     def clean_low_latency_buffer(self):
         if self.deepep_buffer is not None:
-            self.deepep_buffer.clean_low_latency_buffer(
-                self.num_max_dispatch_tokens_per_rank,
-                self.hidden_size,
-                self.num_experts,
-            )
+            if not self.use_internode_ll_two_stage:
+                self.deepep_buffer.clean_low_latency_buffer(
+                    self.num_max_dispatch_tokens_per_rank,
+                    self.hidden_size,
+                    self.num_experts,
+                )
+            else:
+                self.deepep_buffer.clean_low_latency_two_stage_buffer(
+                    self.num_max_dispatch_tokens_per_rank,
+                    self.hidden_size,
+                    self.num_experts,
+                    self.top_k,
+                    self.ep_size,
+                    True,  # just supports dispatch_use_fp8 = True now!
+                )
 
     def barrier_all(self):
         if self.deepep_buffer is not None:
@@ -201,6 +229,8 @@ def __init__(
         moe_phase: MoEPhase,
         async_finish: bool = False,
         group=None,
+        use_internode_ll_two_stage: bool = False,
+        top_k: int = 8,
     ):
         if group is None:
             group = paddle.distributed.new_group(range(ep_size))
@@ -210,10 +240,10 @@ def __init__(
         self.hidden_size = hidden_size
         self.num_experts = num_experts
         self.num_local_experts = num_experts // ep_size
+        self.top_k = top_k
         self.async_finish = async_finish
-        from paddle.base.core import Config
 
-        self.ep_config = Config(24, 6, 256)
+        self.ep_config = None
 
         # Store phase and role for buffer management
         self._splitwise_role = splitwise_role
@@ -228,6 +258,8 @@ def __init__(
             num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
             splitwise_role=splitwise_role,
             moe_phase=moe_phase,
+            use_internode_ll_two_stage=use_internode_ll_two_stage,
+            top_k=self.top_k,
         )
         self.buffer.create_buffer()
 
@@ -274,6 +306,37 @@ def low_latency_dispatch(
 
         return packed_recv_x, recv_expert_count, handle, dispatch_hook
 
+    def low_latency_dispatch_two_stage(
+        self,
+        hidden_states: paddle.Tensor,
+        topk_idx: paddle.Tensor,
+        topk_weights: paddle.Tensor,
+        expertwise_scale,
+        use_fp8: bool = False,
+    ):
+        if self.deepep_engine is None:
+            raise RuntimeError("DeepEP buffer not initialized!")
+
+        (
+            packed_recv_x,
+            packed_recv_count,
+            _,
+            handle,
+            _,
+            dispatch_hook,
+        ) = self.deepep_engine.low_latency_dispatch_two_stage(
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            self.buffer.num_max_dispatch_tokens_per_rank,
+            self.num_experts,
+            use_fp8=use_fp8,
+            async_finish=False,
+            return_recv_hook=True,
+        )
+
+        return packed_recv_x, packed_recv_count, handle, dispatch_hook
+
     def low_latency_combine(
         self,
         hidden_states: paddle.Tensor,
@@ -300,6 +363,28 @@ def low_latency_combine(
         )
         return combined_hidden_states, combine_hook
 
+    def low_latency_combine_two_stage(
+        self,
+        hidden_states: paddle.Tensor,
+        topk_idx: paddle.Tensor,
+        topk_weights: paddle.Tensor,
+        dispatch_use_fp8: bool,
+        handle,
+    ):
+        if self.deepep_engine is None:
+            raise RuntimeError("DeepEP buffer not initialized!")
+
+        combined_hidden_states, _, combine_hook = self.deepep_engine.low_latency_combine_two_stage(
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            handle,
+            async_finish=False,
+            dispatch_use_fp8=dispatch_use_fp8,
+            return_recv_hook=True,
+        )
+        return combined_hidden_states, combine_hook
+
     def clean_low_latency_buffer(self):
         self.buffer.clean_low_latency_buffer()
 
@@ -324,10 +409,12 @@ def __init__(
         ep_rank: int = 0,
         redundant_experts_num: int = 0,
         ep_group=None,
+        use_internode_ll_two_stage: bool = False,
     ):
         self.top_k = top_k
         self.num_experts = num_experts
         self.redundant_experts_num = redundant_experts_num
+        self.use_internode_ll_two_stage = use_internode_ll_two_stage
         self.ep_engine = DeepEPEngine(
             num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
             hidden_size=hidden_size,
@@ -337,6 +424,8 @@ def __init__(
             splitwise_role=splitwise_role,
             moe_phase=moe_phase,
             group=ep_group,
+            use_internode_ll_two_stage=self.use_internode_ll_two_stage,
+            top_k=self.top_k,
         )
 
     def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
@@ -416,6 +505,7 @@ def __init__(
         redundant_experts_num: int = 0,
         moe_phase: MoEPhase = MoEPhase("prefill"),
         ep_group=None,
+        use_internode_ll_two_stage: bool = False,
     ):
         super().__init__(
             top_k,
@@ -428,6 +518,7 @@ def __init__(
             ep_rank=ep_rank,
             redundant_experts_num=redundant_experts_num,
             ep_group=ep_group,
+            use_internode_ll_two_stage=use_internode_ll_two_stage,
         )
 
     def dispatch(
@@ -502,6 +593,7 @@ def __init__(
         redundant_experts_num: int = 0,
         ep_group=None,
         moe_phase: MoEPhase = MoEPhase("decode"),
+        use_internode_ll_two_stage: bool = False,
     ):
         super().__init__(
             top_k,
@@ -514,6 +606,7 @@ def __init__(
             ep_rank=ep_rank,
             redundant_experts_num=redundant_experts_num,
             ep_group=ep_group,
+            use_internode_ll_two_stage=use_internode_ll_two_stage,
         )
 
     def dispatch(
@@ -527,18 +620,30 @@ def dispatch(
         expertwise_scale = kwargs.get("expertwise_scale", None)
         use_fp8 = kwargs.get("use_fp8", False)
 
-        recv_hidden_states, recv_expert_count, handle, dispatch_hook = self.ep_engine.low_latency_dispatch(
-            x, topk_idx, expertwise_scale, use_fp8
-        )
+        if not self.use_internode_ll_two_stage:
+            recv_hidden_states, recv_expert_count, handle, dispatch_hook = self.ep_engine.low_latency_dispatch(
+                x, topk_idx, expertwise_scale, use_fp8
+            )
+        else:
+            # just supports dispatch_use_fp8 = True now!
+            assert use_fp8 is True
+            recv_hidden_states, recv_expert_count, handle, dispatch_hook = (
+                self.ep_engine.low_latency_dispatch_two_stage(x, topk_idx, topk_weights, expertwise_scale, use_fp8)
+            )
         if dispatch_hook is not None:
             dispatch_hook()
 
         return recv_hidden_states, recv_expert_count, handle
 
     def combine(self, ffn_out, topk_idx, topk_weights, handle):
-        combined_hidden_states, combine_hook = self.ep_engine.low_latency_combine(
-            ffn_out, topk_idx, topk_weights, handle
-        )
+        if not self.use_internode_ll_two_stage:
+            combined_hidden_states, combine_hook = self.ep_engine.low_latency_combine(
+                ffn_out, topk_idx, topk_weights, handle
+            )
+        else:
+            combined_hidden_states, combine_hook = self.ep_engine.low_latency_combine_two_stage(
+                ffn_out, topk_idx, topk_weights, True, handle  # just supports dispatch_use_fp8 = True now!
+            )
         if combine_hook is not None:
            combine_hook()
 
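To make the new control flow explicit, here is a minimal, self-contained sketch (a hypothetical helper, no DeepEP or Paddle dependency) of the kernel-selection rule the decode-phase dispatch now follows; the assert mirrors the "just supports dispatch_use_fp8 = True now" restriction in ep.py.

    # Hypothetical helper (not part of the commit) summarizing the branch added
    # in ep.py: the single-stage low-latency kernel remains the default, and the
    # two-stage kernel is used only when the flag is set and FP8 dispatch is on.
    def select_dispatch_kernel(use_internode_ll_two_stage: bool, use_fp8: bool) -> str:
        if not use_internode_ll_two_stage:
            return "low_latency_dispatch"
        # The two-stage path currently supports dispatch_use_fp8 = True only.
        assert use_fp8 is True, "internode_ll_two_stage currently requires use_fp8=True"
        return "low_latency_dispatch_two_stage"

    print(select_dispatch_kernel(False, False))  # low_latency_dispatch
    print(select_dispatch_kernel(True, True))    # low_latency_dispatch_two_stage

The combine path makes the same choice, passing dispatch_use_fp8=True unconditionally to low_latency_combine_two_stage.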

fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py

Lines changed: 1 addition & 0 deletions
@@ -64,6 +64,7 @@ def init_ep(self, layer: nn.Layer) -> None:
             "ep_rank": layer.ep_rank,
             "redundant_experts_num": layer.fd_config.model_config.redundant_experts_num,
             "ep_group": layer.fd_config.parallel_config.ep_group,
+            "use_internode_ll_two_stage": layer.fd_config.parallel_config.use_internode_ll_two_stage,
         }
 
         config = layer.fd_config

fastdeploy/worker/worker_process.py

Lines changed: 5 additions & 0 deletions
@@ -506,6 +506,11 @@ def parse_args():
         action="store_true",
         help="enable chunked prefill",
     )
+    parser.add_argument(
+        "--use_internode_ll_two_stage",
+        action="store_true",
+        help="enable internode_ll_two_stage",
+    )
     parser.add_argument(
         "--speculative_config",
         type=json.loads,
