[#10052][feat] AutoDeploy enable cudagraphs for flashinfer BatchDecode #10193

suyoggupta merged 15 commits into NVIDIA:main from
Conversation
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
📝 Walkthrough

This change introduces host-side preparation mechanisms for attention operations with CUDA-graph-aware decoding support. The flashinfer-python dependency is updated, interfaces are extended with registration/execution methods for host preparation, and multiple components are wired together to invoke these functions within the execution pipeline.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Executor as Executor
    participant SeqInfo as SequenceInfo
    participant KVCache as KV Cache Transform
    participant Planner as FlashInferPlanner
    participant Attention as FlashInferAttention
    Note over KVCache,Attention: Setup Phase (Graph Capture)
    KVCache->>SeqInfo: register_host_prepare_for_attention_forward(descriptor.host_prepare)
    SeqInfo->>SeqInfo: Store in _host_prepare_functions
    Note over Executor,Planner: Generate Phase
    Executor->>SeqInfo: run_host_prepare_for_attention_forward()
    activate SeqInfo
    SeqInfo->>Planner: host_prepare_for_forward(sequence_info)
    activate Planner
    alt num_prefill == 0 (Generate-only)
        Planner->>Planner: plan_generate_only()<br/>(precompute & cache wrapper)
    end
    Planner-->>SeqInfo: Return (host prep complete)
    deactivate Planner
    deactivate SeqInfo
    Executor->>Planner: plan(batch_info)<br/>(uses cached wrappers)
    Executor->>Attention: Execute forward<br/>(with CUDA-graph wrapper)
```
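The register/run flow in the diagram above can be sketched in plain Python. This is a minimal illustration only: the class and method names mirror the diagram, but the real implementations live in `attention_interface.py` and `flashinfer_attention.py` and differ in detail.

```python
# Sketch (not project code) of the host-prepare registration/execution pattern.
class SequenceInfo:
    def __init__(self):
        # A set deduplicates identical function references, so registering the
        # same callback once per attention node is harmless.
        self._host_prepare_functions = set()
        self.num_prefill = 0  # generate-only batch when zero

    def register_host_prepare_for_attention_forward(self, fn):
        self._host_prepare_functions.add(fn)

    def run_host_prepare_for_attention_forward(self):
        # Called by the executor after inputs are prepared, before forward.
        for fn in self._host_prepare_functions:
            fn(self)


class FlashInferPlanner:
    def __init__(self):
        self.planned_generate_only = 0

    def host_prepare_for_forward(self, sequence_info):
        # Generate-only batches can precompute and cache the decode plan.
        if sequence_info.num_prefill == 0:
            self.plan_generate_only()

    def plan_generate_only(self):
        self.planned_generate_only += 1


planner = FlashInferPlanner()
seq_info = SequenceInfo()
# Registering twice (e.g. once per attention node) still stores one entry,
# because equal bound methods hash identically.
seq_info.register_host_prepare_for_attention_forward(planner.host_prepare_for_forward)
seq_info.register_host_prepare_for_attention_forward(planner.host_prepare_for_forward)
seq_info.run_host_prepare_for_attention_forward()
print(len(seq_info._host_prepare_functions), planner.planned_generate_only)  # → 1 1
```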
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
Actionable comments posted: 4
🧹 Nitpick comments (2)
tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py (1)

229-233: Registration inside the loop is functionally correct but redundant. Since `host_prepare_for_forward` is a classmethod, the same function reference is registered for each attention node. The `set` deduplicates correctly, but the call could be moved outside the loop for clarity.

🔎 Suggested optimization

Move the registration outside the loop, since it is the same function for all attention nodes:

```diff
+        # Attention descriptor should register its host function with SequenceInfo.
+        # This function will be called before graph invocation.
+        cm.info.register_host_prepare_for_attention_forward(
+            attn_descriptor.host_prepare_for_forward
+        )
+
         # replace fused attention node with attention node that has kv cache
         num_cached_attn_replacements = 0
         for idx, attn_node in enumerate(source_attn_nodes):
             # ... existing code ...
             self._insert_cached_attn_node(
                 gm,
                 attn_node,
                 qkv,
                 meta_nodes_std,
                 meta_nodes_extra,
                 cache_in_nodes,
                 buffer_in_nodes,
                 constants,
             )
-            # Attention descriptor should register its host function with SequenceInfo.
-            # This function will be called before graph invocation.
-            cm.info.register_host_prepare_for_attention_forward(
-                attn_descriptor.host_prepare_for_forward
-            )
             num_cached_attn_replacements += 1
```

tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py (1)
124-147: Consider adding a fallback or warning if no matching wrapper is found. The method silently does nothing if no `PlanParams` with a matching `num_seq` is found. This could mask issues during debugging. Consider logging a warning or using a more direct lookup approach.

🔎 Alternative with direct lookup

If `num_seq` uniquely identifies the wrapper, consider restructuring `cached_cuda_graph_decode_wrappers` to use `num_seq` as a key for O(1) lookup, or add a warning if no match is found:

```diff
     def plan_generate_only(
         self,
         num_seq: int,
         cu_num_pages: torch.Tensor,
         cache_loc: torch.Tensor,
         last_page_len: torch.Tensor,
     ):
+        found = False
         for plan_params in self.cached_cuda_graph_decode_wrappers:
             if plan_params.num_seq == num_seq:
                 wrapper = self.cached_cuda_graph_decode_wrappers[plan_params]
                 fast_decode_plan(
                     wrapper,
                     cu_num_pages,
                     cache_loc,
                     last_page_len,
                     plan_params.n_heads,
                     plan_params.n_kv_heads,
                     plan_params.head_dim,
                     plan_params.page_size,
                     q_data_type=plan_params.q_dtype,
                     kv_data_type=plan_params.kv_dtype,
                     sm_scale=plan_params.sm_scale,
                 )
+                found = True
                 break
+        if not found:
+            ad_logger.debug(f"No cached wrapper found for {num_seq=}")
```
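The reviewer's O(1)-lookup alternative can be illustrated with a small stand-alone sketch. The `PlanParams` fields and the `WrapperCache` class here are hypothetical simplifications, not the real project classes:

```python
# Hypothetical sketch: key cached decode wrappers by num_seq instead of
# scanning all PlanParams entries, and surface cache misses instead of
# silently doing nothing.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple


@dataclass(frozen=True)
class PlanParams:  # illustrative fields only
    num_seq: int
    n_heads: int
    page_size: int


class WrapperCache:
    def __init__(self) -> None:
        # num_seq -> (PlanParams, wrapper), enabling direct lookup
        self._by_num_seq: Dict[int, Tuple[PlanParams, Any]] = {}

    def insert(self, params: PlanParams, wrapper: Any) -> None:
        self._by_num_seq[params.num_seq] = (params, wrapper)

    def plan_generate_only(self, num_seq: int) -> Optional[Any]:
        entry = self._by_num_seq.get(num_seq)
        if entry is None:
            # Warn on a miss instead of masking it.
            print(f"warning: no cached wrapper for num_seq={num_seq}")
            return None
        _params, wrapper = entry
        return wrapper


cache = WrapperCache()
cache.insert(PlanParams(num_seq=8, n_heads=32, page_size=16), "wrapper-8")
print(cache.plan_generate_only(8))  # → wrapper-8
print(cache.plan_generate_only(4))  # miss → warning, then None
```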
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)

- requirements.txt
- tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py
- tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
- tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Code developed for TensorRT-LLM should conform to Python 3.8+
- Indent Python code with 4 spaces. Do not use tabs
- Always maintain the namespace when importing in Python, even if only one class or function from a module is used
- Python files should use snake_case naming: `some_file.py`
- Python classes should use PascalCase naming: `class SomeClass`
- Python functions and methods should use snake_case naming: `def my_awesome_function():`
- Python local variables should use snake_case naming: `my_variable = ...`
- Python variable names that start with a number should be prefixed with 'k': `k_99th_percentile = ...`
- Python global variables should use upper snake_case with prefix 'G': `G_MY_GLOBAL = ...`
- Python constants should use upper snake_case naming: `MY_CONSTANT = ...`
- Avoid shadowing variables declared in an outer scope in Python
- Initialize all externally visible members of a Python class in the constructor
- For Python interfaces that may be used outside a file, prefer docstrings over comments
- Python comments should be reserved for code within a function, or interfaces that are local to a file
- Use Google style docstrings in Python for classes and functions, which can be parsed by Sphinx
- Python attributes and variables can be documented inline with type and description
- Avoid using reflection in Python when functionality can be easily achieved without reflection
- When using try-except blocks in Python, limit the except to the smallest set of errors possible
- When using try-except blocks in Python to handle multiple possible variable types (duck-typing), keep the body of the try as small as possible, using the else block for logic
Files:
- tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
- tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
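As a quick illustration of the naming rules listed above, here is a tiny snippet that follows them. Everything in it is invented for illustration, not project code:

```python
# Illustrative-only snippet following the guideline naming conventions.
MY_CONSTANT = 60  # constant: upper snake_case
G_REQUEST_COUNT = 0  # global: upper snake_case with 'G' prefix


class SomeClass:  # class: PascalCase
    """Example class whose externally visible members are set in the constructor."""

    def __init__(self, limit: int) -> None:
        self.limit = limit  # int: maximum number of requests


def my_awesome_function(values: list) -> float:  # function: snake_case
    """Return an approximate 99th-percentile value of ``values``."""
    sorted_values = sorted(values)  # local variable: snake_case
    # variable starting with a number gets a 'k' prefix:
    k_99th_percentile = sorted_values[int(0.99 * (len(sorted_values) - 1))]
    return k_99th_percentile


print(my_awesome_function([1, 2, 3, 4, 5]))  # → 4
```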
**/*.{cpp,h,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the year of its latest meaningful modification
Files:
- tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
- tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
🧠 Learnings (3)
📚 Learning: 2025-08-14T21:04:50.248Z
Learnt from: thorjohnsen
Repo: NVIDIA/TensorRT-LLM PR: 6910
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:0-0
Timestamp: 2025-08-14T21:04:50.248Z
Learning: In KV cache onboarding logic during prefill in cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp, when calculating which blocks fall within the attention window, use getTokensPerBlock() to advance token indices rather than block->getUniqueTokens().size(), because the calculation needs to consider the post-prefill state where blocks will be filled to capacity, not their current token count.
Applied to files:
- tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
- tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py
📚 Learning: 2025-08-21T09:41:49.347Z
Learnt from: eopXD
Repo: NVIDIA/TensorRT-LLM PR: 6768
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:2010-2045
Timestamp: 2025-08-21T09:41:49.347Z
Learning: In cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp, updateSequenceCacheBlockOffsets is specifically for updating bookkeeping when blocks are added during the context phase, not for refreshing offsets after detach operations. During detach operations, GenerationRequest::removeFrontBlock handles the necessary cache block bookkeeping internally.
Applied to files:
tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py
📚 Learning: 2025-08-15T06:46:53.813Z
Learnt from: eopXD
Repo: NVIDIA/TensorRT-LLM PR: 6767
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:0-0
Timestamp: 2025-08-15T06:46:53.813Z
Learning: In the TensorRT-LLM KV cache manager, SWA (Sliding Window Attention) combined with beam search is currently in a broken/non-functional state and is planned for future rework. During preparatory refactoring phases, code related to SWA+beam search may intentionally remain in a non-working state until the broader rework is completed.
Applied to files:
- tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py
🧬 Code graph analysis (3)
tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (1)
- tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py (1)
  - run_host_prepare_for_attention_forward (1070-1072)

tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py (3)
- tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py (10)
  - num_pages (620-621), num_pages (624-637), device (123-125), device (528-529), reset (821-827), cache_loc (600-601), is_generate (616-617), host_prepare_for_forward (1252-1258), SequenceInfo (330-1072), get_host_view (149-158)
- tensorrt_llm/_torch/attention_backend/flashinfer.py (1)
  - page_size (200-204)
- tensorrt_llm/models/modeling_utils.py (1)
  - kv_dtype (470-480)

tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py (1)
- tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py (1)
  - host_prepare_for_forward (470-482)
🪛 Ruff (0.14.10)
tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py
473-473: Unpacked variable num_prefill_tokens is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
1252-1252: Unused class method argument: sequence_info
(ARG003)
🔇 Additional comments (6)
tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py (2)
1065-1072: LGTM! The registration and execution pattern is well-designed. Using a `set` ensures each host prepare function is registered only once, preventing duplicate invocations.
1251-1258: No-op base implementation is intentional. The unused `sequence_info` argument is expected since this is a base class method providing a default no-op. Subclasses (like `FlashInferAttention`) will override and use the argument. The static analysis hint ARG003 is a false positive in this context.

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (1)
643-644: LGTM! The placement of `run_host_prepare_for_attention_forward()` is correct: after input preparation is complete and before the forward pass. This ensures host-side preparation (like FlashInfer planning for generate-only batches) happens with fully prepared input buffers.

tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py (3)
244-245: LGTM! Moving the planner reset after extracting batch metadata is appropriate. This ensures the reset happens at a consistent point in the metadata preparation flow.
74-97: LGTM! The `_init_decode_wrapper` method properly handles both CUDA graph and non-CUDA graph modes. Passing paged KV buffers during wrapper initialization for CUDA graph mode aligns with FlashInfer's requirements for graph capture.
178-194: LGTM! The CUDA graph wrapper caching logic during warm-up is well-structured. Caching per `PlanParams` allows for efficient reuse across multiple graph captures with different configurations.
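The per-`PlanParams` caching idea can be sketched with a frozen dataclass as a dictionary key: frozen dataclasses are hashable, so each distinct plan configuration maps to exactly one cached wrapper. The field names and the stand-in wrapper below are illustrative, not the real `PlanParams` definition:

```python
# Sketch: cache one decode wrapper per plan configuration.
from dataclasses import dataclass
from typing import Any, Dict


@dataclass(frozen=True)
class PlanParams:  # illustrative fields only
    num_seq: int
    n_heads: int
    n_kv_heads: int
    head_dim: int
    page_size: int


cached_wrappers: Dict[PlanParams, Any] = {}


def get_or_create_wrapper(params: PlanParams) -> Any:
    # During warm-up, each distinct configuration creates one wrapper;
    # later graph captures with equal params reuse the cached one.
    if params not in cached_wrappers:
        cached_wrappers[params] = object()  # stand-in for a decode wrapper
    return cached_wrappers[params]


p1 = PlanParams(num_seq=8, n_heads=32, n_kv_heads=8, head_dim=128, page_size=16)
p2 = PlanParams(num_seq=8, n_heads=32, n_kv_heads=8, head_dim=128, page_size=16)
w1 = get_or_create_wrapper(p1)
w2 = get_or_create_wrapper(p2)  # equal params hash the same -> cache hit
print(w1 is w2)  # → True
```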
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
/bot run

PR_Github #29546 [ run ] triggered by Bot. Commit:

Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>

/bot run

PR_Github #29548 [ run ] triggered by Bot. Commit:

PR_Github #29548 [ run ] completed with state
lucaslie left a comment:

Looks great. Just a few comments.
PR_Github #29662 [ run ] completed with state

/bot run

PR_Github #29675 [ run ] triggered by Bot. Commit:

@suyoggupta, if you don't have a ticket you can instead assign yourself to the PR and then add it to our board via the "Projects" tab. Adding those steps for you.

PR_Github #29675 [ run ] completed with state

/bot run

PR_Github #29705 [ run ] triggered by Bot. Commit:

PR_Github #29705 [ run ] completed with state

/bot run

PR_Github #29731 [ run ] triggered by Bot. Commit:

PR_Github #29731 [ run ] completed with state

/bot run

PR_Github #29753 [ run ] triggered by Bot. Commit:

PR_Github #29753 [ run ] completed with state

/bot run

PR_Github #29771 [ run ] triggered by Bot. Commit:

PR_Github #29771 [ run ] completed with state
…hDecode (NVIDIA#10193)
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Co-authored-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>

…hDecode (NVIDIA#10193)
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Co-authored-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Signed-off-by: Daniil Kulko <kulkodaniil@gmail.com>
Summary by CodeRabbit
Release Notes
Perf data for 8k/16k, concurrency 1 (using trtllm-bench):

Before:
[profiler trace screenshot]
- tpot: 4.4 ms
- 20% of the time went to flashinfer BatchDecode

After:
[profiler trace screenshot]
- tpot: 3.5 ms (-20%)
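As a quick sanity check on the reported numbers, the relative tpot reduction works out to roughly the stated 20%:

```python
# Verify the reported ~20% time-per-output-token improvement.
tpot_before_ms = 4.4
tpot_after_ms = 3.5
improvement = (tpot_before_ms - tpot_after_ms) / tpot_before_ms
print(f"{improvement:.1%}")  # → 20.5%
```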