@@ -546,7 +546,8 @@ def set_splitting_ops_for_v1(self):
             # full cudagraph outside the fx graph. This reduces some cpu
             # overhead when the runtime batch_size is not cudagraph captured.
             # see https://github.com/vllm-project/vllm/pull/20059 for details.
-            self.splitting_ops = self._attention_ops
+            # make a copy to avoid mutating the class-level list via reference.
+            self.splitting_ops = list(self._attention_ops)
         elif len(self.splitting_ops) == 0:
             logger.warning_once("Using piecewise compilation with empty "
                                 "splitting_ops.")
@@ -561,6 +562,18 @@ def set_splitting_ops_for_v1(self):
             self.cudagraph_mode = CUDAGraphMode.FULL
             self.splitting_ops = []
 
+        if envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput":
+            # exclude MoE dispatch/combine from capture by ensuring
+            # piecewise splitting includes them, so communication remains
+            # outside CUDA graphs while compute can still be graphed.
+            moe_ops = [
+                "vllm.moe_forward",
+                "vllm.moe_forward_shared",
+            ]
+            for op in moe_ops:
+                if op not in self.splitting_ops:
+                    self.splitting_ops.append(op)
+
     def splitting_ops_contain_attention(self) -> bool:
         return self.splitting_ops is not None and all(
             op in self.splitting_ops for op in self._attention_ops)
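
Taken together with the previous hunk, the resolution behaves roughly as follows. This is a simplified sketch as a free function; resolve_splitting_ops, _ATTENTION_OPS, and _MOE_OPS are stand-ins, not the actual vLLM API:

# With the deepep_high_throughput backend, the MoE dispatch/combine ops end
# up in splitting_ops exactly once, so piecewise compilation splits around
# them and they stay outside CUDA graph capture.
_ATTENTION_OPS = ["vllm.unified_attention"]            # illustrative
_MOE_OPS = ["vllm.moe_forward", "vllm.moe_forward_shared"]


def resolve_splitting_ops(splitting_ops, backend):
    ops = list(splitting_ops if splitting_ops is not None else _ATTENTION_OPS)
    if backend == "deepep_high_throughput":
        for op in _MOE_OPS:
            if op not in ops:                          # idempotent append
                ops.append(op)
    return ops


print(resolve_splitting_ops(None, "deepep_high_throughput"))
# ['vllm.unified_attention', 'vllm.moe_forward', 'vllm.moe_forward_shared']
print(resolve_splitting_ops(["vllm.moe_forward"], "deepep_high_throughput"))
# ['vllm.moe_forward', 'vllm.moe_forward_shared'] -- no duplicate added
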
@@ -183,16 +183,14 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
         compilation_config = vllm_config.compilation_config
         if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
                 and parallel_config.data_parallel_size > 1
-                and compilation_config.cudagraph_mode != CUDAGraphMode.NONE):
+                and compilation_config.cudagraph_mode
+                not in [CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE]):
             logger.info(
-                "Data Parallel: disabling cudagraphs since DP "
-                "with DeepEP high-throughput kernels are not CUDA Graph "
-                "compatible. The DeepEP low-latency kernels are CUDA Graph "
-                "compatible. Set the all_to_all backend to deepep_low_latency "
-                "to use those kernels instead.")
-            compilation_config.cudagraph_mode = CUDAGraphMode.NONE
-            if model_config is not None:
-                model_config.enforce_eager = True
+                "Data Parallel with DeepEP high-throughput: using PIECEWISE "
+                "CUDA graphs and excluding MoE ops from capture. Set "
+                "VLLM_ALL2ALL_BACKEND=deepep_low_latency if you need MoE "
+                "graphs captured as well.")
+            compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
 
     @classmethod
     def get_current_memory_usage(cls,
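
Net effect of this hunk: full-graph capture is downgraded to piecewise instead of disabling CUDA graphs and forcing eager mode. A hedged sketch of that decision, using a stand-in enum and function rather than the real check_and_update_config:

from enum import Enum, auto


class CUDAGraphMode(Enum):      # stand-in for vLLM's enum
    NONE = auto()
    PIECEWISE = auto()
    FULL = auto()


def adjust_cudagraph_mode(mode, backend, dp_size):
    """Return the cudagraph mode after the DeepEP high-throughput check."""
    if (backend == "deepep_high_throughput" and dp_size > 1
            and mode not in (CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE)):
        # Downgrade FULL capture to PIECEWISE instead of disabling graphs
        # entirely; MoE ops are kept out of capture via splitting_ops above.
        return CUDAGraphMode.PIECEWISE
    return mode


assert adjust_cudagraph_mode(CUDAGraphMode.FULL,
                             "deepep_high_throughput", 2) is CUDAGraphMode.PIECEWISE
assert adjust_cudagraph_mode(CUDAGraphMode.NONE,
                             "deepep_high_throughput", 2) is CUDAGraphMode.NONE
assert adjust_cudagraph_mode(CUDAGraphMode.FULL, "naive", 2) is CUDAGraphMode.FULL
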