@@ -68,25 +68,20 @@ def __init__(
         self.num_local_experts = num_experts // ep_size
         self.async_finish = async_finish

-        self.prefill_deepep_engine = None
-        self.decode_deepep_engine = None
+        self.deepep_engine = None

         self.ep_config = Config(24, 6, 256)
         self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank

         # In mixed EP mode on a single node, we dynamically switch between
         # high throughput and low latency modes.
         if splitwise_role == "mixed":
-            # decode engine
-            logger.info("Initializing Low Latency Buffer")
-            self.get_low_latency_buffer()
-            # prefill engine
-            self.prefill_deepep_engine = deep_ep.Buffer(
+            self.deepep_engine = deep_ep.Buffer(
                 self.group,
-                int(5e8),
-                0,
-                low_latency_mode=False,
-                num_qps_per_rank=1,
+                int(2e9),
+                int(5e9),
+                low_latency_mode=True,
+                num_qps_per_rank=24,
             )
         # In disaggregated mode on mutiple nodes, we either use
         # high throughput mode or low latency mode.
@@ -95,7 +90,7 @@ def __init__(
             logger.info("Initializing Low Latency Buffer")
             self.get_low_latency_buffer()
         elif moe_phase.phase == "prefill":
-            self.prefill_deepep_engine = deep_ep.Buffer(
+            self.deepep_engine = deep_ep.Buffer(
                 self.group,
                 int(5e8),
                 0,
@@ -124,14 +119,14 @@ def get_low_latency_buffer(self):
         )
         # Allocate a buffer if not existed or not enough buffer size
         if (
-            self.decode_deepep_engine is None
-            or self.decode_deepep_engine.group != self.group
-            or not self.decode_deepep_engine.low_latency_mode
-            or self.decode_deepep_engine.num_rdma_bytes < num_rdma_bytes
+            self.deepep_engine is None
+            or self.deepep_engine.group != self.group
+            or not self.deepep_engine.low_latency_mode
+            or self.deepep_engine.num_rdma_bytes < num_rdma_bytes
         ):
             # NOTES: for best performance, the QP number **must** be equal to the number of the local experts
             assert self.num_experts % self.ep_size == 0
-            self.decode_deepep_engine = deep_ep.Buffer(
+            self.deepep_engine = deep_ep.Buffer(
                 self.group,
                 0,
                 num_rdma_bytes,
@@ -168,7 +163,7 @@ def low_latency_dispatch(
             handle,
             _,
             dispatch_hook,
-        ) = self.decode_deepep_engine.low_latency_dispatch(
+        ) = self.deepep_engine.low_latency_dispatch(
             hidden_states,
             topk_idx,
             expertwise_scale,
@@ -210,7 +205,7 @@ def low_latency_combine(
             num_experts,
         )

-        combined_hidden_states, _, combine_hook = self.decode_deepep_engine.low_latency_combine(
+        combined_hidden_states, _, combine_hook = self.deepep_engine.low_latency_combine(
            hidden_states,
            topk_idx,
            topk_weights,
@@ -224,19 +219,15 @@ def clean_low_latency_buffer(self):
         """
         clean_low_latency_buffer
         """
-        self.decode_deepep_engine.clean_low_latency_buffer(
+        self.deepep_engine.clean_low_latency_buffer(
             self.num_max_dispatch_tokens_per_rank, self.hidden, self.num_experts
         )

     def barrier_all(self):
         """
         barrier_all
         """
-        if self.prefill_deepep_engine is not None:
-            self.prefill_deepep_engine.barrier_all()
-
-        if self.decode_deepep_engine is not None:
-            self.decode_deepep_engine.barrier_all()
+        self.deepep_engine.barrier_all()


 class EPRunner:
@@ -316,6 +307,9 @@ def combine(self, *args, **kwargs):
         """
         raise NotImplementedError

+    def clean_low_latency_buffer(self):
+        self.ep_engine.clean_low_latency_buffer()
+

 class EPPrefillRunner(EPRunner):
     """
@@ -328,6 +322,7 @@ def __init__(
         hidden: int,
         num_experts: int,
         splitwise_role: str,
+        num_max_dispatch_tokens_per_rank: int,
         ep_size: int = 1,
         ep_rank: int = 0,
         redundant_experts_num: int = 0,
@@ -339,7 +334,7 @@ def __init__(
             num_experts,
             splitwise_role,
             moe_phase,
-            num_max_dispatch_tokens_per_rank=256,
+            num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
             ep_size=ep_size,
             ep_rank=ep_rank,
             redundant_experts_num=redundant_experts_num,
@@ -359,7 +354,7 @@ def dispatch(
             num_tokens_per_expert,
             is_token_in_rank,
             _,
-        ) = self.ep_engine.prefill_deepep_engine.get_dispatch_layout(topk_idx, self.num_experts)
+        ) = self.ep_engine.deepep_engine.get_dispatch_layout(topk_idx, self.num_experts)

         x_scale_tensor = kwargs.get("x_scale_tensor", None)
         dispatch_args = {
@@ -372,7 +367,7 @@ def dispatch(
             "topk_idx": topk_idx,
             "topk_weights": topk_weights,
         }
-        return self.ep_engine.prefill_deepep_engine.dispatch(**dispatch_args)
+        return self.ep_engine.deepep_engine.dispatch(**dispatch_args)

     def combine(
         self,
387
382
"async_finish" : self .ep_engine .async_finish ,
388
383
"topk_weights" : recv_topk_weights ,
389
384
}
390
- fused_moe_out , _ , _ = self .ep_engine .prefill_deepep_engine .combine (** combine_args )
385
+ fused_moe_out , _ , _ = self .ep_engine .deepep_engine .combine (** combine_args )
391
386
392
387
return fused_moe_out
393
388
394
389
395
390
class EPDecoderRunner (EPRunner ):
396
391
"""
397
- EPPrefillRunner
392
+ EPDecoderRunner
398
393
"""
399
394
400
395
def __init__ (
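
For reference, the sketch below is not part of the patch; it only restates the mixed-mode branch from the first hunk as a standalone snippet. It assumes a `deep_ep` install whose `Buffer` takes the group followed by NVLink and RDMA byte budgets (as the call in the diff suggests); `group` is a placeholder for an already-initialized communication group, and `build_unified_engine` is a hypothetical helper name, not something this change adds.

# Hypothetical usage sketch (not in the patch): how the single deepep_engine
# is built for mixed mode after this change.
import deep_ep


def build_unified_engine(group):
    # One buffer replaces the separate prefill/decode engines: low_latency_mode=True
    # keeps the low_latency_dispatch()/low_latency_combine() path used for decode,
    # while the non-zero first byte budget leaves the regular dispatch()/combine()
    # path available for prefill, so both runners can share self.deepep_engine.
    return deep_ep.Buffer(
        group,
        int(2e9),  # assumed NVLink buffer bytes (high-throughput prefill path)
        int(5e9),  # assumed RDMA buffer bytes (low-latency decode path)
        low_latency_mode=True,
        num_qps_per_rank=24,
    )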