
[#10052][feat] AutoDeploy enable cudagraphs for flashinfer BatchDecode#10193

Merged
suyoggupta merged 15 commits into NVIDIA:main from nv-auto-deploy:sg/fi-cg
Dec 24, 2025
Conversation


@suyoggupta suyoggupta commented Dec 22, 2025

Summary by CodeRabbit

Release Notes

  • New Features
    • Improved decoding performance with CUDA graph-aware execution and paged key-value cache optimization.


Perf data for 8k/16k at concurrency 1 (using trtllm-bench):

before:
tpot: 4.4 ms; about 20% of the time was spent in flashinfer BatchDecode
[Screenshot: profiler trace before the change]

after:
tpot: 3.5 ms (a ~20% reduction)
[Screenshot: profiler trace after the change]

nvchenghaoz and others added 4 commits December 19, 2025 13:14
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
@suyoggupta suyoggupta changed the title Sg/fi cg [#10052][feat] AutoDeploy enable cudagraphs for flashinfer BatchDecode Dec 23, 2025
@suyoggupta suyoggupta marked this pull request as ready for review December 23, 2025 05:08
@suyoggupta suyoggupta requested review from a team as code owners December 23, 2025 05:08
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>

coderabbitai bot commented Dec 23, 2025

📝 Walkthrough

This change introduces host-side preparation mechanisms for attention operations, with CUDA-graph-aware decoding support. The flashinfer-python dependency is updated, the attention interfaces are extended with registration and execution methods for host preparation, and the surrounding components are wired together to invoke these functions within the execution pipeline.

Changes

Dependency Update (requirements.txt):
Updated flashinfer-python from the 0.3.x (<0.4.0) range to 0.4.0-<0.5.3, tightening the constraint with an explicit upper bound.

Attention Interface (tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py):
Added a host-side preparation registry to SequenceInfo: a new _host_prepare_functions set plus register_host_prepare_for_attention_forward() and run_host_prepare_for_attention_forward() methods. Introduced a host_prepare_for_forward() classmethod on AttentionDescriptor.

FlashInfer Implementation (tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py):
Added CUDA-graph-aware decoding: new paged KV cache buffers (paged_kv_indptr_buffer, paged_kv_indices_buffer, paged_kv_last_page_len_buffer); renamed cached_decode_wrappers to cached_cuda_graph_decode_wrappers; extended _init_decode_wrapper() with CUDA graph parameters; updated init_workspace() to allocate the new buffers; added plan_generate_only() for generate-only batch precomputation; introduced host_prepare_for_forward() hooks on both _FlashInferPlanner and FlashInferAttention.

Integration Points (tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py, tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py):
The executor's input-building path now invokes run_host_prepare_for_attention_forward(); the KV cache transform registers the attention descriptor's host_prepare_for_forward() immediately after inserting cached attention nodes.
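The registration/execution flow described in the summary can be sketched in isolation. The following is a simplified stand-in, not the actual TensorRT-LLM code: the class and method names mirror the summary above, while the bodies are reduced to the bare pattern.

```python
from typing import Callable, Set


class SequenceInfo:
    """Simplified sketch of the host-prepare registry described above.

    The real class lives in attention_interface.py; this stand-in only
    models the registration/execution pattern, not the full interface.
    """

    def __init__(self) -> None:
        # A set, so registering the same callable twice is a no-op.
        self._host_prepare_functions: Set[Callable] = set()

    def register_host_prepare_for_attention_forward(self, fn: Callable) -> None:
        self._host_prepare_functions.add(fn)

    def run_host_prepare_for_attention_forward(self) -> None:
        # Called by the executor after inputs are built, before the forward pass.
        for fn in self._host_prepare_functions:
            fn(self)


calls = []


def host_prepare(seq_info):
    # A real backend would run its planner here (e.g. FlashInfer planning
    # for a generate-only batch); this hook just records the invocation.
    calls.append("prepared")


info = SequenceInfo()
info.register_host_prepare_for_attention_forward(host_prepare)
info.register_host_prepare_for_attention_forward(host_prepare)  # deduplicated by the set
info.run_host_prepare_for_attention_forward()
print(calls)  # ['prepared']: the hook runs exactly once per step
```

The set semantics are what make it safe for the KV cache transform to register the same hook repeatedly, as discussed in the review comments below.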

Sequence Diagram

sequenceDiagram
    participant Executor as Executor
    participant SeqInfo as SequenceInfo
    participant KVCache as KV Cache Transform
    participant Planner as FlashInferPlanner
    participant Attention as FlashInferAttention
    
    Note over KVCache,Attention: Setup Phase (Graph Capture)
    KVCache->>SeqInfo: register_host_prepare_for_attention_forward(descriptor.host_prepare)
    SeqInfo->>SeqInfo: Store in _host_prepare_functions
    
    Note over Executor,Planner: Generate Phase
    Executor->>SeqInfo: run_host_prepare_for_attention_forward()
    activate SeqInfo
    SeqInfo->>Planner: host_prepare_for_forward(sequence_info)
    activate Planner
    alt num_prefill == 0 (Generate-only)
        Planner->>Planner: plan_generate_only()<br/>(precompute & cache wrapper)
    end
    Planner-->>SeqInfo: Return (host prep complete)
    deactivate Planner
    deactivate SeqInfo
    
    Executor->>Planner: plan(batch_info)<br/>(uses cached wrappers)
    Executor->>Attention: Execute forward<br/>(with CUDA-graph wrapper)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 28.57%, which is below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
  • Description check ⚠️ Warning: The PR description is minimal and lacks required template sections; it provides only performance metrics without explaining the what, why, or test coverage, and does not address the PR checklist items. Resolution: complete the PR description using the provided template, adding sections for Description (changes and rationale) and Test Coverage (relevant tests), and confirm the PR Checklist items.
✅ Passed checks (1 passed)
  • Title check ✅ Passed: The title clearly identifies the main feature (enabling CUDA graphs for flashinfer BatchDecode within AutoDeploy), matching the PR's primary objective.

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

🧹 Nitpick comments (2)
tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py (1)

229-233: Registration inside loop is functionally correct but redundant.

Since host_prepare_for_forward is a classmethod, the same function reference is registered for each attention node. The set deduplicates correctly, but this could be moved outside the loop for clarity.

🔎 Suggested optimization

Move registration outside the loop since it's the same function for all attention nodes:

+        # Attention descriptor should register its host function with SequenceInfo.
+        # This function will be called before graph invocation.
+        cm.info.register_host_prepare_for_attention_forward(
+            attn_descriptor.host_prepare_for_forward
+        )
+
         # replace fused attention node with attention node that has kv cache
         num_cached_attn_replacements = 0
         for idx, attn_node in enumerate(source_attn_nodes):
             # ... existing code ...
             
             self._insert_cached_attn_node(
                 gm,
                 attn_node,
                 qkv,
                 meta_nodes_std,
                 meta_nodes_extra,
                 cache_in_nodes,
                 buffer_in_nodes,
                 constants,
             )
-            # Attention descriptor should register its host function with SequenceInfo.
-            # This function will be called before graph invocation.
-            cm.info.register_host_prepare_for_attention_forward(
-                attn_descriptor.host_prepare_for_forward
-            )
             num_cached_attn_replacements += 1
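The deduplication this nitpick relies on holds because accessing a classmethod on its class yields equal, equally hashed bound-method objects, so a set collapses repeated registrations. A minimal illustration, using illustrative names rather than the real classes:

```python
class AttentionDescriptor:
    @classmethod
    def host_prepare_for_forward(cls, sequence_info):
        """Base hook: no host-side preparation by default."""


registry = set()
for _ in range(5):  # registering inside a loop, as in the reviewed code
    registry.add(AttentionDescriptor.host_prepare_for_forward)

print(len(registry))  # 1: equal bound methods collapse in the set
```

So the in-loop registration is harmless; hoisting it out of the loop, as suggested, is purely a readability improvement.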
tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py (1)

124-147: Consider adding a fallback or warning if no matching wrapper is found.

The method silently does nothing if no PlanParams with matching num_seq is found. This could mask issues during debugging. Consider logging a warning or using a more direct lookup approach.

🔎 Alternative with direct lookup

If num_seq uniquely identifies the wrapper, consider restructuring cached_cuda_graph_decode_wrappers to use num_seq as a key for O(1) lookup, or add a warning if no match is found:

     def plan_generate_only(
         self,
         num_seq: int,
         cu_num_pages: torch.Tensor,
         cache_loc: torch.Tensor,
         last_page_len: torch.Tensor,
     ):
+        found = False
         for plan_params in self.cached_cuda_graph_decode_wrappers:
             if plan_params.num_seq == num_seq:
                 wrapper = self.cached_cuda_graph_decode_wrappers[plan_params]
                 fast_decode_plan(
                     wrapper,
                     cu_num_pages,
                     cache_loc,
                     last_page_len,
                     plan_params.n_heads,
                     plan_params.n_kv_heads,
                     plan_params.head_dim,
                     plan_params.page_size,
                     q_data_type=plan_params.q_dtype,
                     kv_data_type=plan_params.kv_dtype,
                     sm_scale=plan_params.sm_scale,
                 )
+                found = True
                 break
+        if not found:
+            ad_logger.debug(f"No cached wrapper found for {num_seq=}")
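The O(1) alternative mentioned above could be structured roughly as follows. PlanParams and the planner here are simplified stand-ins for the real types, and the actual fast_decode_plan call is only referenced in a comment:

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple


@dataclass(frozen=True)  # frozen makes instances hashable, usable as dict keys
class PlanParams:
    num_seq: int
    n_heads: int


class PlannerSketch:
    """Illustrative stand-in for a planner with a num_seq-keyed index."""

    def __init__(self) -> None:
        # Secondary index: num_seq -> (params, wrapper) for O(1) lookup.
        self._wrappers_by_num_seq: Dict[int, Tuple[PlanParams, Any]] = {}

    def cache_wrapper(self, params: PlanParams, wrapper: Any) -> None:
        self._wrappers_by_num_seq[params.num_seq] = (params, wrapper)

    def plan_generate_only(self, num_seq: int) -> Optional[Any]:
        entry = self._wrappers_by_num_seq.get(num_seq)
        if entry is None:
            # Surface the miss instead of silently doing nothing.
            print(f"no cached wrapper for num_seq={num_seq}")
            return None
        params, wrapper = entry
        # The real code would call fast_decode_plan(wrapper, ...) here.
        return wrapper


planner = PlannerSketch()
planner.cache_wrapper(PlanParams(num_seq=8, n_heads=32), "wrapper-8")
print(planner.plan_generate_only(8))  # wrapper-8
```

This assumes num_seq uniquely identifies a cached wrapper, which is the precondition the review comment itself flags; if several PlanParams can share a num_seq, the linear scan with a logged miss is the safer option.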
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7421224 and 10ebfcb.

📒 Files selected for processing (5)
  • requirements.txt
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py
  • tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Code developed for TensorRT-LLM should conform to Python 3.8+
Indent Python code with 4 spaces. Do not use tabs
Always maintain the namespace when importing in Python, even if only one class or function from a module is used
Python files should use snake_case naming: some_file.py
Python classes should use PascalCase naming: class SomeClass
Python functions and methods should use snake_case naming: def my_awesome_function():
Python local variables should use snake_case naming: my_variable = ...
Python variable names that start with a number should be prefixed with 'k': k_99th_percentile = ...
Python global variables should use upper snake_case with prefix 'G': G_MY_GLOBAL = ...
Python constants should use upper snake_case naming: MY_CONSTANT = ...
Avoid shadowing variables declared in an outer scope in Python
Initialize all externally visible members of a Python class in the constructor
For Python interfaces that may be used outside a file, prefer docstrings over comments
Python comments should be reserved for code within a function, or interfaces that are local to a file
Use Google style docstrings in Python for classes and functions, which can be parsed by Sphinx
Python attributes and variables can be documented inline with type and description
Avoid using reflection in Python when functionality can be easily achieved without reflection
When using try-except blocks in Python, limit the except to the smallest set of errors possible
When using try-except blocks in Python to handle multiple possible variable types (duck-typing), keep the body of the try as small as possible, using the else block for logic

Files:

  • tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
**/*.{cpp,h,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the year of its latest meaningful modification

Files:

  • tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
🧠 Learnings (3)
📚 Learning: 2025-08-14T21:04:50.248Z
Learnt from: thorjohnsen
Repo: NVIDIA/TensorRT-LLM PR: 6910
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:0-0
Timestamp: 2025-08-14T21:04:50.248Z
Learning: In KV cache onboarding logic during prefill in cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp, when calculating which blocks fall within the attention window, use getTokensPerBlock() to advance token indices rather than block->getUniqueTokens().size(), because the calculation needs to consider the post-prefill state where blocks will be filled to capacity, not their current token count.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py
📚 Learning: 2025-08-21T09:41:49.347Z
Learnt from: eopXD
Repo: NVIDIA/TensorRT-LLM PR: 6768
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:2010-2045
Timestamp: 2025-08-21T09:41:49.347Z
Learning: In cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp, updateSequenceCacheBlockOffsets is specifically for updating bookkeeping when blocks are added during the context phase, not for refreshing offsets after detach operations. During detach operations, GenerationRequest::removeFrontBlock handles the necessary cache block bookkeeping internally.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py
📚 Learning: 2025-08-15T06:46:53.813Z
Learnt from: eopXD
Repo: NVIDIA/TensorRT-LLM PR: 6767
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:0-0
Timestamp: 2025-08-15T06:46:53.813Z
Learning: In the TensorRT-LLM KV cache manager, SWA (Sliding Window Attention) combined with beam search is currently in a broken/non-functional state and is planned for future rework. During preparatory refactoring phases, code related to SWA+beam search may intentionally remain in a non-working state until the broader rework is completed.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py
🧬 Code graph analysis (3)
tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (1)
tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py (1)
  • run_host_prepare_for_attention_forward (1070-1072)
tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py (3)
tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py (10)
  • num_pages (620-621)
  • num_pages (624-637)
  • device (123-125)
  • device (528-529)
  • reset (821-827)
  • cache_loc (600-601)
  • is_generate (616-617)
  • host_prepare_for_forward (1252-1258)
  • SequenceInfo (330-1072)
  • get_host_view (149-158)
tensorrt_llm/_torch/attention_backend/flashinfer.py (1)
  • page_size (200-204)
tensorrt_llm/models/modeling_utils.py (1)
  • kv_dtype (470-480)
tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py (1)
tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py (1)
  • host_prepare_for_forward (470-482)
🪛 Ruff (0.14.10)
tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py

473-473: Unpacked variable num_prefill_tokens is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py

1252-1252: Unused class method argument: sequence_info

(ARG003)

🔇 Additional comments (6)
tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py (2)

1065-1072: LGTM!

The registration and execution pattern is well-designed. Using a set ensures each host prepare function is registered only once, preventing duplicate invocations.


1251-1258: No-op base implementation is intentional.

The unused sequence_info argument is expected since this is a base class method providing a default no-op. Subclasses (like FlashInferAttention) will override and use the argument. The static analysis hint ARG003 is a false positive in this context.
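In miniature, the no-op-base-plus-override pattern that triggers the false positive looks like this (class names are illustrative, not the real hierarchy):

```python
class AttentionDescriptor:
    @classmethod
    def host_prepare_for_forward(cls, sequence_info):
        """Default hook: no host-side work; intentionally ignores its argument."""
        # The unused `sequence_info` parameter is what trips Ruff's ARG003 here.


class FlashInferLikeBackend(AttentionDescriptor):
    prepared = []

    @classmethod
    def host_prepare_for_forward(cls, sequence_info):
        # An overriding backend actually consumes the argument.
        cls.prepared.append(sequence_info)


AttentionDescriptor.host_prepare_for_forward("seq-info")   # no-op by design
FlashInferLikeBackend.host_prepare_for_forward("seq-info")
print(FlashInferLikeBackend.prepared)  # ['seq-info']
```

Keeping the parameter in the base signature is what lets callers invoke the hook uniformly without knowing which subclass they hold.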

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (1)

643-644: LGTM!

The placement of run_host_prepare_for_attention_forward() is correct—after input preparation is complete and before the forward pass. This ensures host-side preparation (like FlashInfer planning for generate-only batches) happens with fully prepared input buffers.

tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py (3)

244-245: LGTM!

Moving the planner reset after extracting batch metadata is appropriate. This ensures the reset happens at a consistent point in the metadata preparation flow.


74-97: LGTM!

The _init_decode_wrapper method properly handles both CUDA graph and non-CUDA graph modes. Passing paged KV buffers during wrapper initialization for CUDA graph mode aligns with FlashInfer's requirements for graph capture.


178-194: LGTM!

The CUDA graph wrapper caching logic during warm-up is well-structured. Caching per PlanParams allows for efficient reuse across multiple graph captures with different configurations.

Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
@suyoggupta
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #29546 [ run ] triggered by Bot. Commit: 980bc2b

Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
@suyoggupta
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #29548 [ run ] triggered by Bot. Commit: ad94d8f

@tensorrt-cicd
Collaborator

PR_Github #29548 [ run ] completed with state FAILURE. Commit: ad94d8f
/LLM/main/L0_MergeRequest_PR pipeline #22721 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Member

@lucaslie lucaslie left a comment


looks great. Just a few comments

@tensorrt-cicd
Collaborator

PR_Github #29662 [ run ] completed with state DISABLED
CI server is currently disabled for scheduled maintenance. Estimated completion time: 12 PM PST on 12/23.

@suyoggupta
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #29675 [ run ] triggered by Bot. Commit: f5b0a52

@lucaslie
Member

@suyoggupta, if you don't have a ticket you can instead assign yourself to the PR and then add it to our board via the "Projects" tab. Adding those steps for you

@lucaslie lucaslie moved this from Backlog to In review in AutoDeploy Board Dec 24, 2025
@tensorrt-cicd
Collaborator

PR_Github #29675 [ run ] completed with state SUCCESS. Commit: f5b0a52
/LLM/main/L0_MergeRequest_PR pipeline #22792 completed with status: 'FAILURE'


@suyoggupta
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #29705 [ run ] triggered by Bot. Commit: f1c8b27

@tensorrt-cicd
Collaborator

PR_Github #29705 [ run ] completed with state FAILURE. Commit: f1c8b27
/LLM/main/L0_MergeRequest_PR pipeline #22822 completed with status: 'FAILURE'


@suyoggupta
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #29731 [ run ] triggered by Bot. Commit: f1c8b27

@tensorrt-cicd
Collaborator

PR_Github #29731 [ run ] completed with state FAILURE. Commit: f1c8b27
/LLM/main/L0_MergeRequest_PR pipeline #22844 completed with status: 'FAILURE'


@suyoggupta
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #29753 [ run ] triggered by Bot. Commit: f1c8b27

@tensorrt-cicd
Collaborator

PR_Github #29753 [ run ] completed with state FAILURE. Commit: f1c8b27
/LLM/main/L0_MergeRequest_PR pipeline #22863 completed with status: 'FAILURE'


@suyoggupta
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #29771 [ run ] triggered by Bot. Commit: b82fd92

@tensorrt-cicd
Collaborator

PR_Github #29771 [ run ] completed with state SUCCESS. Commit: b82fd92
/LLM/main/L0_MergeRequest_PR pipeline #22883 completed with status: 'SUCCESS'

Collaborator

@galagam galagam left a comment


LGTM

@suyoggupta suyoggupta merged commit e2891a6 into NVIDIA:main Dec 24, 2025
5 checks passed
@github-project-automation github-project-automation bot moved this from In review to Done in AutoDeploy Board Dec 24, 2025
yingguo-trt pushed a commit to yingguo-trt/TensorRT-LLM that referenced this pull request Dec 25, 2025
…hDecode (NVIDIA#10193)

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Co-authored-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
JunyiXu-nv pushed a commit to JunyiXu-nv/TensorRT-LLM that referenced this pull request Dec 30, 2025
…hDecode (NVIDIA#10193)

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Co-authored-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
videodanchik pushed a commit to videodanchik/TensorRT-LLM that referenced this pull request Jan 14, 2026
…hDecode (NVIDIA#10193)

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Co-authored-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Signed-off-by: Daniil Kulko <kulkodaniil@gmail.com>