@@ -80,7 +80,7 @@ def test_chunked_prefill(self, attn_backend):
     @parametrize_with_ids("attn_backend", ["TRTLLM", "FLASHINFER"])
     def test_bfloat16(self, attn_backend, torch_compile):
         torch_compile_config = TorchCompileConfig(
-            torch_compile_fullgraph=True) if torch_compile else None
+            enable_fullgraph=True) if torch_compile else None
         pytorch_config = dict(
             torch_compile_config=torch_compile_config,
             cuda_graph_padding_enabled=torch_compile,
@@ -109,7 +109,7 @@ def test_bfloat16_4gpus(self, tp_size, pp_size, attn_backend,
109109 "Issue: Unfusing flashinfer_fused_add_rmsnorm causes outputs to be "
110110 "discarded at graph breaks." )
111111 torch_compile_config = TorchCompileConfig (
112- torch_compile_fullgraph = True ) if torch_compile else None
112+ enable_fullgraph = True ) if torch_compile else None
113113 pytorch_config = dict (
114114 torch_compile_config = torch_compile_config ,
115115 cuda_graph_padding_enabled = torch_compile ,
@@ -136,7 +136,7 @@ def test_bfloat16_4gpus(self, tp_size, pp_size, attn_backend,
     def test_fp8(self, fp8kv, attn_backend, torch_compile):
         quant_config = QuantConfig(QuantAlgo.FP8)
         torch_compile_config = TorchCompileConfig(
-            torch_compile_fullgraph=True) if torch_compile else None
+            enable_fullgraph=True) if torch_compile else None
         pytorch_config = dict(
             torch_compile_config=torch_compile_config,
             cuda_graph_padding_enabled=torch_compile,
@@ -177,7 +177,7 @@ def test_fp8_4gpus(self, tp_size, pp_size, fp8kv, attn_backend,
177177 "discarded at graph breaks." )
178178 quant_config = QuantConfig (QuantAlgo .FP8 )
179179 torch_compile_config = TorchCompileConfig (
180- torch_compile_fullgraph = True ) if torch_compile else None
180+ enable_fullgraph = True ) if torch_compile else None
181181 pytorch_config = dict (
182182 torch_compile_config = torch_compile_config ,
183183 cuda_graph_padding_enabled = torch_compile ,
@@ -505,9 +505,8 @@ def test_bfloat16(self, mtp_nextn, attention_dp, cuda_graph,
             pytest.skip("https://nvbugs/5252559")
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
         torch_compile_config = TorchCompileConfig(
-            torch_compile_fullgraph=True,
-            torch_compile_piecewise_cuda_graph=cuda_graph
-        ) if torch_compile else None
+            enable_fullgraph=True,
+            enable_piecewise_cuda_graph=cuda_graph) if torch_compile else None
         pytorch_config = dict(
             disable_overlap_scheduler=not overlap_scheduler,
             use_cuda_graph=cuda_graph,
@@ -552,9 +551,8 @@ def test_bfloat16_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
             pytest.skip("PP with torch.compile is not supported yet.")
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
         torch_compile_config = TorchCompileConfig(
-            torch_compile_fullgraph=True,
-            torch_compile_piecewise_cuda_graph=cuda_graph
-        ) if torch_compile else None
+            enable_fullgraph=True,
+            enable_piecewise_cuda_graph=cuda_graph) if torch_compile else None
         pytorch_config = dict(
             disable_overlap_scheduler=not overlap_scheduler,
             use_cuda_graph=cuda_graph,
@@ -597,9 +595,8 @@ def test_fp8_block_scales(self, mtp_nextn, fp8kv, attention_dp, cuda_graph,
             pytest.skip("https://nvbugs/5252559")
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
         torch_compile_config = TorchCompileConfig(
-            torch_compile_fullgraph=True,
-            torch_compile_piecewise_cuda_graph=cuda_graph
-        ) if torch_compile else None
+            enable_fullgraph=True,
+            enable_piecewise_cuda_graph=cuda_graph) if torch_compile else None
         pytorch_config = dict(
             disable_overlap_scheduler=not overlap_scheduler,
             use_cuda_graph=cuda_graph,
@@ -719,9 +716,8 @@ def test_fp8_block_scales_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
             pytest.skip("PP with torch.compile is not supported yet.")
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
         torch_compile_config = TorchCompileConfig(
-            torch_compile_fullgraph=True,
-            torch_compile_piecewise_cuda_graph=cuda_graph
-        ) if torch_compile else None
+            enable_fullgraph=True,
+            enable_piecewise_cuda_graph=cuda_graph) if torch_compile else None
         pytorch_config = dict(
             disable_overlap_scheduler=not overlap_scheduler,
             use_cuda_graph=cuda_graph,
@@ -808,9 +804,8 @@ def test_nvfp4(self, fp8kv, attention_dp, cuda_graph, overlap_scheduler,
             pytest.skip("https://nvbugs/5252559")
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
         torch_compile_config = TorchCompileConfig(
-            torch_compile_fullgraph=True,
-            torch_compile_piecewise_cuda_graph=cuda_graph
-        ) if torch_compile else None
+            enable_fullgraph=True,
+            enable_piecewise_cuda_graph=cuda_graph) if torch_compile else None
         pytorch_config = dict(
             disable_overlap_scheduler=not overlap_scheduler,
             use_cuda_graph=cuda_graph,
@@ -866,9 +861,8 @@ def test_nvfp4_4gpus(self, fp8kv, attention_dp, cuda_graph,
             pytest.skip("https://nvbugs/5336321")
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
         torch_compile_config = TorchCompileConfig(
-            torch_compile_fullgraph=True,
-            torch_compile_piecewise_cuda_graph=cuda_graph
-        ) if torch_compile else None
+            enable_fullgraph=True,
+            enable_piecewise_cuda_graph=cuda_graph) if torch_compile else None
         pytorch_config = dict(
             disable_overlap_scheduler=not overlap_scheduler,
             use_cuda_graph=cuda_graph,
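
Note: every hunk above applies the same rename on TorchCompileConfig, from torch_compile_fullgraph and torch_compile_piecewise_cuda_graph to enable_fullgraph and enable_piecewise_cuda_graph. A minimal sketch of the updated construction pattern follows; the import path is an assumption not shown in this diff, and torch_compile / cuda_graph stand in for the test parametrization flags.

# Sketch of the renamed TorchCompileConfig fields adopted by this diff.
# Assumption: TorchCompileConfig is importable from tensorrt_llm.llmapi.
from tensorrt_llm.llmapi import TorchCompileConfig

torch_compile = True  # stands in for the @parametrize_with_ids flag
cuda_graph = True     # stands in for the cuda_graph test parameter

# Old spelling (removed by this diff):
#   TorchCompileConfig(torch_compile_fullgraph=True,
#                      torch_compile_piecewise_cuda_graph=cuda_graph)
# New spelling:
torch_compile_config = TorchCompileConfig(
    enable_fullgraph=True,
    enable_piecewise_cuda_graph=cuda_graph) if torch_compile else None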