Skip to content

Commit 496b419

Browse files
authored
[None][doc] Add doc for torch.compile & piecewise cuda graph (#8527)
Signed-off-by: yizhang-nv <187001205+yizhang-nv@users.noreply.github.com>
1 parent db99a93 commit 496b419

File tree

7 files changed

+393
-1
lines changed

7 files changed

+393
-1
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ repos:
6666
additional_dependencies:
6767
- tomli
6868
# add ignore words list
69-
args: ["-L", "Mor,ans,thirdparty", "--skip", "ATTRIBUTIONS-*.md", "--skip", "security_scanning/*"]
69+
args: ["-L", "Mor,ans,thirdparty", "--skip", "ATTRIBUTIONS-*.md,*.svg", "--skip", "security_scanning/*"]
7070
- repo: https://github.com/astral-sh/ruff-pre-commit
7171
rev: v0.9.4
7272
hooks:

docs/source/features/torch_compile_and_piecewise_cuda_graph.md

Lines changed: 363 additions & 0 deletions
Large diffs are not rendered by default.

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ Welcome to TensorRT LLM's Documentation!
7575
features/checkpoint-loading.md
7676
features/auto_deploy/auto-deploy.md
7777
features/ray-orchestrator.md
78+
features/torch_compile_and_piecewise_cuda_graph.md
7879

7980
.. toctree::
8081
:maxdepth: 2

docs/source/media/current_model_definition_ds.svg

Lines changed: 4 additions & 0 deletions
Loading

docs/source/media/custom_backend_overview.svg

Lines changed: 4 additions & 0 deletions
Loading

docs/source/media/piecewise_runner.svg

Lines changed: 4 additions & 0 deletions
Loading

tensorrt_llm/llmapi/llm_args.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2400,6 +2400,10 @@ def validate_torch_compile_max_num_streams(cls, v):
24002400
"torch_compile_config.max_num_streams must be >= 1")
24012401
return v
24022402

2403+
@staticmethod
2404+
def _generate_capture_num_tokens() -> List[int]:
2405+
return [2**i for i in range(8)] + [i for i in range(256, 3073, 256)]
2406+
24032407

24042408
class TorchLlmArgs(BaseLlmArgs):
24052409
# Just a dummy BuildConfig to allow code reuse with the TrtLlmArgs
@@ -2715,6 +2719,18 @@ def validate_cuda_graph_config(self) -> 'TorchLlmArgs':
27152719

27162720
return self
27172721

2722+
@model_validator(mode='after')
2723+
def validate_torch_compile_config(self) -> 'TorchLlmArgs':
2724+
if self.torch_compile_config is None:
2725+
return self
2726+
2727+
config = self.torch_compile_config
2728+
if config.enable_piecewise_cuda_graph:
2729+
if config.capture_num_tokens is None:
2730+
config.capture_num_tokens = TorchCompileConfig._generate_capture_num_tokens(
2731+
)
2732+
return self
2733+
27182734
@model_validator(mode='after')
27192735
def sync_quant_config_with_kv_cache_config_dtype(self) -> 'TorchLlmArgs':
27202736
if self.kv_cache_config is None:

0 commit comments

Comments
 (0)