
Commit 6a388b1

chore: remove torch_compile prefix for TorchCompileConfig field members (NVIDIA#5261)
Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
1 parent 2b23cd5 · commit 6a388b1
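
The commit renames every field of TorchCompileConfig, dropping the redundant torch_compile_ prefix while leaving types and defaults unchanged. Below is a minimal before/after sketch of the rename; it assumes TorchCompileConfig is imported from its defining module, tensorrt_llm.llmapi.llm_args (your import path may differ), and the defaults in the comments are taken from the diff below.

from tensorrt_llm.llmapi.llm_args import TorchCompileConfig

# Old field names (removed by this commit):
#   TorchCompileConfig(torch_compile_fullgraph=True,
#                      torch_compile_inductor_enabled=False,
#                      torch_compile_piecewise_cuda_graph=False,
#                      torch_compile_enable_userbuffers=True)

# New field names (defaults unchanged):
config = TorchCompileConfig(
    enable_fullgraph=True,              # default: True
    enable_inductor=False,              # default: False
    enable_piecewise_cuda_graph=False,  # default: False
    enable_userbuffers=True)            # default: True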

File tree

3 files changed: 28 additions, 35 deletions

examples/pytorch/quickstart_advanced.py

Lines changed: 4 additions & 3 deletions

@@ -200,9 +200,10 @@ def setup_llm(args):
         print_iter_log=args.print_iter_log,
         enable_iter_perf_stats=args.print_iter_log,
         torch_compile_config=TorchCompileConfig(
-            torch_compile_fullgraph=args.use_torch_compile,
-            torch_compile_inductor_enabled=args.use_torch_compile,
-            torch_compile_piecewise_cuda_graph=args.use_piecewise_cuda_graph)
+            enable_fullgraph=args.use_torch_compile,
+            enable_inductor=args.use_torch_compile,
+            enable_piecewise_cuda_graph= \
+            args.use_piecewise_cuda_graph)
         if args.use_torch_compile else None,
         moe_backend=args.moe_backend,
         enable_trtllm_sampler=args.enable_trtllm_sampler,

tensorrt_llm/llmapi/llm_args.py

Lines changed: 8 additions & 10 deletions

@@ -1604,18 +1604,18 @@ class TorchCompileConfig(BaseModel):
     """
     Configuration for torch.compile.
     """
-    torch_compile_fullgraph: bool = Field(
+    enable_fullgraph: bool = Field(
         default=True,
         description="Enable full graph compilation in torch.compile.")

-    torch_compile_inductor_enabled: bool = Field(
+    enable_inductor: bool = Field(
         default=False, description="Enable inductor backend in torch.compile.")

-    torch_compile_piecewise_cuda_graph: bool = Field(
+    enable_piecewise_cuda_graph: bool = Field(
         default=False,
         description="Enable piecewise CUDA graph in torch.compile.")

-    torch_compile_enable_userbuffers: bool = Field(
+    enable_userbuffers: bool = Field(
         default=True,
         description=
         "When torch compile is enabled, userbuffers is enabled by default.")

@@ -1794,17 +1794,15 @@ def get_pytorch_backend_config(self) -> "PyTorchConfig":
             enable_iter_req_stats=self.enable_iter_req_stats,
             print_iter_log=self.print_iter_log,
             torch_compile_enabled=bool(self.torch_compile_config is not None),
-            torch_compile_fullgraph=self.torch_compile_config.
-            torch_compile_fullgraph
+            torch_compile_fullgraph=self.torch_compile_config.enable_fullgraph
             if self.torch_compile_config is not None else True,
             torch_compile_inductor_enabled=self.torch_compile_config.
-            torch_compile_inductor_enabled
-            if self.torch_compile_config is not None else False,
+            enable_inductor if self.torch_compile_config is not None else False,
             torch_compile_piecewise_cuda_graph=self.torch_compile_config.
-            torch_compile_piecewise_cuda_graph
+            enable_piecewise_cuda_graph
             if self.torch_compile_config is not None else False,
             torch_compile_enable_userbuffers=self.torch_compile_config.
-            torch_compile_enable_userbuffers
+            enable_userbuffers
             if self.torch_compile_config is not None else True,
             autotuner_enabled=self.autotuner_enabled,
             enable_layerwise_nvtx_marker=self.enable_layerwise_nvtx_marker,
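
Note that the get_pytorch_backend_config mapping above keeps the flat torch_compile_* names on PyTorchConfig; only the TorchCompileConfig fields lose the prefix. The guard repeated per field means a missing config disables torch.compile and each flat field falls back to its class default. A standalone sketch of that guard, where resolve_torch_compile_flags is a hypothetical helper (not part of the codebase) shown only to illustrate the fallback semantics:

from typing import Optional

from tensorrt_llm.llmapi.llm_args import TorchCompileConfig

def resolve_torch_compile_flags(cfg: Optional[TorchCompileConfig]) -> dict:
    # Hypothetical helper: mirrors the per-field guard in
    # get_pytorch_backend_config. A None config disables torch.compile,
    # and every flat field falls back to its TorchCompileConfig default.
    return dict(
        torch_compile_enabled=cfg is not None,
        torch_compile_fullgraph=(
            cfg.enable_fullgraph if cfg is not None else True),
        torch_compile_inductor_enabled=(
            cfg.enable_inductor if cfg is not None else False),
        torch_compile_piecewise_cuda_graph=(
            cfg.enable_piecewise_cuda_graph if cfg is not None else False),
        torch_compile_enable_userbuffers=(
            cfg.enable_userbuffers if cfg is not None else True),
    )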

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 16 additions & 22 deletions

@@ -80,7 +80,7 @@ def test_chunked_prefill(self, attn_backend):
     @parametrize_with_ids("attn_backend", ["TRTLLM", "FLASHINFER"])
     def test_bfloat16(self, attn_backend, torch_compile):
         torch_compile_config = TorchCompileConfig(
-            torch_compile_fullgraph=True) if torch_compile else None
+            enable_fullgraph=True) if torch_compile else None
         pytorch_config = dict(
             torch_compile_config=torch_compile_config,
             cuda_graph_padding_enabled=torch_compile,

@@ -109,7 +109,7 @@ def test_bfloat16_4gpus(self, tp_size, pp_size, attn_backend,
                 "Issue: Unfusing flashinfer_fused_add_rmsnorm causes outputs to be "
                 "discarded at graph breaks.")
         torch_compile_config = TorchCompileConfig(
-            torch_compile_fullgraph=True) if torch_compile else None
+            enable_fullgraph=True) if torch_compile else None
         pytorch_config = dict(
             torch_compile_config=torch_compile_config,
             cuda_graph_padding_enabled=torch_compile,

@@ -136,7 +136,7 @@ def test_bfloat16_4gpus(self, tp_size, pp_size, attn_backend,
     def test_fp8(self, fp8kv, attn_backend, torch_compile):
         quant_config = QuantConfig(QuantAlgo.FP8)
         torch_compile_config = TorchCompileConfig(
-            torch_compile_fullgraph=True) if torch_compile else None
+            enable_fullgraph=True) if torch_compile else None
         pytorch_config = dict(
             torch_compile_config=torch_compile_config,
             cuda_graph_padding_enabled=torch_compile,

@@ -177,7 +177,7 @@ def test_fp8_4gpus(self, tp_size, pp_size, fp8kv, attn_backend,
                 "discarded at graph breaks.")
         quant_config = QuantConfig(QuantAlgo.FP8)
         torch_compile_config = TorchCompileConfig(
-            torch_compile_fullgraph=True) if torch_compile else None
+            enable_fullgraph=True) if torch_compile else None
         pytorch_config = dict(
             torch_compile_config=torch_compile_config,
             cuda_graph_padding_enabled=torch_compile,

@@ -505,9 +505,8 @@ def test_bfloat16(self, mtp_nextn, attention_dp, cuda_graph,
             pytest.skip("https://nvbugs/5252559")
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
         torch_compile_config = TorchCompileConfig(
-            torch_compile_fullgraph=True,
-            torch_compile_piecewise_cuda_graph=cuda_graph
-        ) if torch_compile else None
+            enable_fullgraph=True,
+            enable_piecewise_cuda_graph=cuda_graph) if torch_compile else None
         pytorch_config = dict(
             disable_overlap_scheduler=not overlap_scheduler,
             use_cuda_graph=cuda_graph,

@@ -552,9 +551,8 @@ def test_bfloat16_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
             pytest.skip("PP with torch.compile is not supported yet.")
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
         torch_compile_config = TorchCompileConfig(
-            torch_compile_fullgraph=True,
-            torch_compile_piecewise_cuda_graph=cuda_graph
-        ) if torch_compile else None
+            enable_fullgraph=True,
+            enable_piecewise_cuda_graph=cuda_graph) if torch_compile else None
         pytorch_config = dict(
             disable_overlap_scheduler=not overlap_scheduler,
             use_cuda_graph=cuda_graph,

@@ -597,9 +595,8 @@ def test_fp8_block_scales(self, mtp_nextn, fp8kv, attention_dp, cuda_graph,
             pytest.skip("https://nvbugs/5252559")
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
         torch_compile_config = TorchCompileConfig(
-            torch_compile_fullgraph=True,
-            torch_compile_piecewise_cuda_graph=cuda_graph
-        ) if torch_compile else None
+            enable_fullgraph=True,
+            enable_piecewise_cuda_graph=cuda_graph) if torch_compile else None
         pytorch_config = dict(
             disable_overlap_scheduler=not overlap_scheduler,
             use_cuda_graph=cuda_graph,

@@ -719,9 +716,8 @@ def test_fp8_block_scales_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
             pytest.skip("PP with torch.compile is not supported yet.")
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
         torch_compile_config = TorchCompileConfig(
-            torch_compile_fullgraph=True,
-            torch_compile_piecewise_cuda_graph=cuda_graph
-        ) if torch_compile else None
+            enable_fullgraph=True,
+            enable_piecewise_cuda_graph=cuda_graph) if torch_compile else None
         pytorch_config = dict(
             disable_overlap_scheduler=not overlap_scheduler,
             use_cuda_graph=cuda_graph,

@@ -808,9 +804,8 @@ def test_nvfp4(self, fp8kv, attention_dp, cuda_graph, overlap_scheduler,
             pytest.skip("https://nvbugs/5252559")
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
         torch_compile_config = TorchCompileConfig(
-            torch_compile_fullgraph=True,
-            torch_compile_piecewise_cuda_graph=cuda_graph
-        ) if torch_compile else None
+            enable_fullgraph=True,
+            enable_piecewise_cuda_graph=cuda_graph) if torch_compile else None
         pytorch_config = dict(
             disable_overlap_scheduler=not overlap_scheduler,
             use_cuda_graph=cuda_graph,

@@ -866,9 +861,8 @@ def test_nvfp4_4gpus(self, fp8kv, attention_dp, cuda_graph,
             pytest.skip("https://nvbugs/5336321")
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
         torch_compile_config = TorchCompileConfig(
-            torch_compile_fullgraph=True,
-            torch_compile_piecewise_cuda_graph=cuda_graph
-        ) if torch_compile else None
+            enable_fullgraph=True,
+            enable_piecewise_cuda_graph=cuda_graph) if torch_compile else None
         pytorch_config = dict(
             disable_overlap_scheduler=not overlap_scheduler,
             use_cuda_graph=cuda_graph,
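
All of the test changes follow one pattern: the config is constructed with the new field names only when the torch_compile parametrization is active, and None is passed otherwise. A condensed sketch of that pattern, under the assumption that torch_compile, cuda_graph, and overlap_scheduler stand in for the pytest parameters seen in the hunks above:

from tensorrt_llm.llmapi.llm_args import TorchCompileConfig

torch_compile = True      # stand-ins for the pytest parametrizations
cuda_graph = True
overlap_scheduler = True

# Build the config only when torch.compile is under test.
torch_compile_config = TorchCompileConfig(
    enable_fullgraph=True,
    enable_piecewise_cuda_graph=cuda_graph) if torch_compile else None
pytorch_config = dict(
    torch_compile_config=torch_compile_config,
    disable_overlap_scheduler=not overlap_scheduler,
    use_cuda_graph=cuda_graph)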
