@@ -64,6 +64,8 @@ def __init__(
         num_max_dispatch_tokens_per_rank: int,
         splitwise_role: str,
         moe_phase: MoEPhase,
+        use_internode_ll_two_stage: bool = False,
+        top_k: int = 8,
     ):
         self.group = group
         self.hidden_size = hidden_size
@@ -72,6 +74,8 @@ def __init__(
         self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
         self.splitwise_role = splitwise_role
         self.moe_phase = moe_phase
+        self.use_internode_ll_two_stage = use_internode_ll_two_stage
+        self.top_k = top_k

         self.deepep_buffer = None
         self.num_nvl_bytes = 0
@@ -95,12 +99,26 @@ def _compute_buffer_sizes(self, param_bytes: int = 2):
         )

         if self.splitwise_role == "mixed" or self.moe_phase.phase == "decode":
-            num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
-                self.num_max_dispatch_tokens_per_rank,
-                self.hidden_size,
-                self.ep_size,
-                self.num_experts,
-            )
+            if not self.use_internode_ll_two_stage:
+                num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
+                    self.num_max_dispatch_tokens_per_rank,
+                    self.hidden_size,
+                    self.ep_size,
+                    self.num_experts,
+                )
+            else:
+                num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint_two_stage(
+                    self.num_max_dispatch_tokens_per_rank, self.hidden_size, self.ep_size, self.num_experts, self.top_k
+                )
+                num_nvl_bytes = deep_ep.Buffer.get_low_latency_nvl_size_hint_two_stage(
+                    self.num_max_dispatch_tokens_per_rank,
+                    self.hidden_size,
+                    self.ep_size,
+                    self.num_experts,
+                    self.top_k,
+                    True,  # only dispatch_use_fp8=True is supported for now
+                )
+                self.num_nvl_bytes = max(self.num_nvl_bytes, num_nvl_bytes)

             self.num_rdma_bytes = max(self.num_rdma_bytes, num_rdma_bytes)

         logger.info(f"DeepEP num nvl bytes: {self.num_nvl_bytes}, num rdma bytes: {self.num_rdma_bytes}")
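The two-stage internode low-latency path sizes two buffers instead of one: the RDMA hint gains a top_k argument, and an NVLink buffer is sized on top of it. As a reading aid, here is the same selection pulled out into a standalone sketch. The helper name is illustrative, it assumes a DeepEP build that exposes the *_two_stage size hints used in this hunk, and the trailing True mirrors the FP8-only restriction noted in the comment above.

import deep_ep

def compute_low_latency_buffer_sizes(
    num_max_dispatch_tokens_per_rank: int,
    hidden_size: int,
    ep_size: int,
    num_experts: int,
    top_k: int,
    use_internode_ll_two_stage: bool,
):
    # Mirrors the branch above: the single-stage path sizes only the RDMA buffer,
    # the two-stage path sizes both RDMA and NVL buffers and also needs top_k.
    num_nvl_bytes = 0
    if not use_internode_ll_two_stage:
        num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
            num_max_dispatch_tokens_per_rank, hidden_size, ep_size, num_experts
        )
    else:
        num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint_two_stage(
            num_max_dispatch_tokens_per_rank, hidden_size, ep_size, num_experts, top_k
        )
        num_nvl_bytes = deep_ep.Buffer.get_low_latency_nvl_size_hint_two_stage(
            num_max_dispatch_tokens_per_rank, hidden_size, ep_size, num_experts, top_k, True  # FP8 dispatch assumed
        )
    return num_rdma_bytes, num_nvl_bytes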
@@ -172,11 +190,21 @@ def get_buffer(self):

     def clean_low_latency_buffer(self):
         if self.deepep_buffer is not None:
-            self.deepep_buffer.clean_low_latency_buffer(
-                self.num_max_dispatch_tokens_per_rank,
-                self.hidden_size,
-                self.num_experts,
-            )
+            if not self.use_internode_ll_two_stage:
+                self.deepep_buffer.clean_low_latency_buffer(
+                    self.num_max_dispatch_tokens_per_rank,
+                    self.hidden_size,
+                    self.num_experts,
+                )
+            else:
+                self.deepep_buffer.clean_low_latency_two_stage_buffer(
+                    self.num_max_dispatch_tokens_per_rank,
+                    self.hidden_size,
+                    self.num_experts,
+                    self.top_k,
+                    self.ep_size,
+                    True,  # only dispatch_use_fp8=True is supported for now
+                )

     def barrier_all(self):
         if self.deepep_buffer is not None:
@@ -201,6 +229,8 @@ def __init__(
         moe_phase: MoEPhase,
         async_finish: bool = False,
         group=None,
+        use_internode_ll_two_stage: bool = False,
+        top_k: int = 8,
     ):
         if group is None:
             group = paddle.distributed.new_group(range(ep_size))
@@ -210,10 +240,10 @@ def __init__(
         self.hidden_size = hidden_size
         self.num_experts = num_experts
         self.num_local_experts = num_experts // ep_size
+        self.top_k = top_k
         self.async_finish = async_finish
-        from paddle.base.core import Config

-        self.ep_config = Config(24, 6, 256)
+        self.ep_config = None

         # Store phase and role for buffer management
         self._splitwise_role = splitwise_role
@@ -228,6 +258,8 @@ def __init__(
             num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
             splitwise_role=splitwise_role,
             moe_phase=moe_phase,
+            use_internode_ll_two_stage=use_internode_ll_two_stage,
+            top_k=self.top_k,
         )
         self.buffer.create_buffer()

@@ -274,6 +306,37 @@ def low_latency_dispatch(

         return packed_recv_x, recv_expert_count, handle, dispatch_hook

+    def low_latency_dispatch_two_stage(
+        self,
+        hidden_states: paddle.Tensor,
+        topk_idx: paddle.Tensor,
+        topk_weights: paddle.Tensor,
+        expertwise_scale,
+        use_fp8: bool = False,
+    ):
+        if self.deepep_engine is None:
+            raise RuntimeError("DeepEP buffer not initialized!")
+
+        (
+            packed_recv_x,
+            packed_recv_count,
+            _,
+            handle,
+            _,
+            dispatch_hook,
+        ) = self.deepep_engine.low_latency_dispatch_two_stage(
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            self.buffer.num_max_dispatch_tokens_per_rank,
+            self.num_experts,
+            use_fp8=use_fp8,
+            async_finish=False,
+            return_recv_hook=True,
+        )
+
+        return packed_recv_x, packed_recv_count, handle, dispatch_hook
+
     def low_latency_combine(
         self,
         hidden_states: paddle.Tensor,
@@ -300,6 +363,28 @@ def low_latency_combine(
         )
         return combined_hidden_states, combine_hook

+    def low_latency_combine_two_stage(
+        self,
+        hidden_states: paddle.Tensor,
+        topk_idx: paddle.Tensor,
+        topk_weights: paddle.Tensor,
+        dispatch_use_fp8: bool,
+        handle,
+    ):
+        if self.deepep_engine is None:
+            raise RuntimeError("DeepEP buffer not initialized!")
+
+        combined_hidden_states, _, combine_hook = self.deepep_engine.low_latency_combine_two_stage(
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            handle,
+            async_finish=False,
+            dispatch_use_fp8=dispatch_use_fp8,
+            return_recv_hook=True,
+        )
+        return combined_hidden_states, combine_hook
+
     def clean_low_latency_buffer(self):
         self.buffer.clean_low_latency_buffer()

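The two new wrappers are meant to be used as a pair in a decode step: dispatch, wait on the recv hook, run the local experts, then combine with the same handle. Below is a minimal sketch of that pairing, not code from this patch: the function name and the run_experts callable are placeholders, and it assumes an engine constructed with use_internode_ll_two_stage=True (only FP8 dispatch is supported on this path, matching the assert added later in dispatch()).

def two_stage_moe_round_trip(ep_engine, run_experts, hidden_states, topk_idx, topk_weights):
    # Dispatch tokens to their experts; with return_recv_hook=True the data is
    # only available after the returned hook has run.
    recv_x, recv_count, handle, dispatch_hook = ep_engine.low_latency_dispatch_two_stage(
        hidden_states, topk_idx, topk_weights, expertwise_scale=None, use_fp8=True
    )
    if dispatch_hook is not None:
        dispatch_hook()  # wait for the recv hook before consuming recv_x

    ffn_out = run_experts(recv_x, recv_count)  # placeholder for the local expert FFN

    # Send expert outputs back; dispatch_use_fp8 must match the dispatch above (True).
    combined, combine_hook = ep_engine.low_latency_combine_two_stage(
        ffn_out, topk_idx, topk_weights, True, handle
    )
    if combine_hook is not None:
        combine_hook()
    return combined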
@@ -324,10 +409,12 @@ def __init__(
         ep_rank: int = 0,
         redundant_experts_num: int = 0,
         ep_group=None,
+        use_internode_ll_two_stage: bool = False,
     ):
         self.top_k = top_k
         self.num_experts = num_experts
         self.redundant_experts_num = redundant_experts_num
+        self.use_internode_ll_two_stage = use_internode_ll_two_stage
         self.ep_engine = DeepEPEngine(
             num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
             hidden_size=hidden_size,
@@ -337,6 +424,8 @@ def __init__(
             splitwise_role=splitwise_role,
             moe_phase=moe_phase,
             group=ep_group,
+            use_internode_ll_two_stage=self.use_internode_ll_two_stage,
+            top_k=self.top_k,
         )

     def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
@@ -416,6 +505,7 @@ def __init__(
         redundant_experts_num: int = 0,
         moe_phase: MoEPhase = MoEPhase("prefill"),
         ep_group=None,
+        use_internode_ll_two_stage: bool = False,
     ):
         super().__init__(
             top_k,
@@ -428,6 +518,7 @@ def __init__(
             ep_rank=ep_rank,
             redundant_experts_num=redundant_experts_num,
             ep_group=ep_group,
+            use_internode_ll_two_stage=use_internode_ll_two_stage,
         )

     def dispatch(
@@ -502,6 +593,7 @@ def __init__(
         redundant_experts_num: int = 0,
         ep_group=None,
         moe_phase: MoEPhase = MoEPhase("decode"),
+        use_internode_ll_two_stage: bool = False,
     ):
         super().__init__(
             top_k,
@@ -514,6 +606,7 @@ def __init__(
             ep_rank=ep_rank,
             redundant_experts_num=redundant_experts_num,
             ep_group=ep_group,
+            use_internode_ll_two_stage=use_internode_ll_two_stage,
         )

     def dispatch(
@@ -527,18 +620,30 @@ def dispatch(
         expertwise_scale = kwargs.get("expertwise_scale", None)
         use_fp8 = kwargs.get("use_fp8", False)

-        recv_hidden_states, recv_expert_count, handle, dispatch_hook = self.ep_engine.low_latency_dispatch(
-            x, topk_idx, expertwise_scale, use_fp8
-        )
+        if not self.use_internode_ll_two_stage:
+            recv_hidden_states, recv_expert_count, handle, dispatch_hook = self.ep_engine.low_latency_dispatch(
+                x, topk_idx, expertwise_scale, use_fp8
+            )
+        else:
+            # only dispatch_use_fp8=True is supported for now
+            assert use_fp8 is True
+            recv_hidden_states, recv_expert_count, handle, dispatch_hook = (
+                self.ep_engine.low_latency_dispatch_two_stage(x, topk_idx, topk_weights, expertwise_scale, use_fp8)
+            )
         if dispatch_hook is not None:
             dispatch_hook()

         return recv_hidden_states, recv_expert_count, handle

     def combine(self, ffn_out, topk_idx, topk_weights, handle):
-        combined_hidden_states, combine_hook = self.ep_engine.low_latency_combine(
-            ffn_out, topk_idx, topk_weights, handle
-        )
+        if not self.use_internode_ll_two_stage:
+            combined_hidden_states, combine_hook = self.ep_engine.low_latency_combine(
+                ffn_out, topk_idx, topk_weights, handle
+            )
+        else:
+            combined_hidden_states, combine_hook = self.ep_engine.low_latency_combine_two_stage(
+                ffn_out, topk_idx, topk_weights, True, handle  # only dispatch_use_fp8=True is supported for now
+            )
         if combine_hook is not None:
             combine_hook()
