Skip to content

[None][feat] Use XQA JIT impl by default and mitigate perf loss with sliding window#10335

Merged
pengbowang-nv merged 3 commits intoNVIDIA:mainfrom
pengbowang-nv:dev-remove-xqa-precompile-path
Jan 15, 2026
Merged

[None][feat] Use XQA JIT impl by default and mitigate perf loss with sliding window#10335
pengbowang-nv merged 3 commits intoNVIDIA:mainfrom
pengbowang-nv:dev-remove-xqa-precompile-path

Conversation

@pengbowang-nv
Copy link
Collaborator

@pengbowang-nv pengbowang-nv commented Dec 30, 2025

Summary by CodeRabbit

  • New Features

    • Expanded token-per-page configuration options to support 32 tokens in addition to existing values.
  • Improvements

    • Enhanced sliding window attention masking logic for improved accuracy in specific inference scenarios.
    • Streamlined kernel implementation selection for more consistent performance behavior.

✏️ Tip: You can customize this high-level summary in your review settings.

Description

This PR uses XQA JIT for all cases.

Current XQA has 2 execution path, one is to use legacy pre-compiled cubin and the other is JIT.
With most kernel has been moved to JIT after https://github.com/NVIDIA/TensorRT-LLM/pull/6078 , remaining kernel are only speculative kernel withfp/bf16 kvcache and 64 % (num_q_head / num_kv_heads) != 0. Some affected models include: Mixtral-8x22B-Instruct-v0.1, Qwen2.5-VL-7B-Instruct, Llama-3.2-3B-Instruct, ...

If fully converted to JIT, we can avoid maintaining 2 separate host invoke code, and remove hundreds of cubins.

This PR remove pre-compiled cubin path from normal execution but keep the code and knob for a release. Then we remove pre-compiled cubin completely after 1-2 releases.

Also, this PR tried to mitigate a perf issue of sliding window with speculative after #8383 . Before this PR, enabling sliding window but not using it at runtime (pass max window size with a size larger than KV lens) will suffer a 5% perf loss compared to disabling sliding window feature completely. Below are some perf data between pre-compiled version, jit version, and jit version with the fix.

Using Qwen3-8B, spec length = 3, ISL=4096, OSL=1024

Time (relative, lower is better) Pre-compile(Legacy Cubin) JIT (Before mitigation) JIT JIT Without Sliding Window
SM90 1 0.912 0.983 0.981
SM89 1 0.986 0.991 0.991

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@pengbowang-nv
Copy link
Collaborator Author

/bot run --add-multi-gpu-test

@tensorrt-cicd
Copy link
Collaborator

PR_Github #30138 [ run ] triggered by Bot. Commit: e58f22a

@tensorrt-cicd
Copy link
Collaborator

PR_Github #30138 [ run ] completed with state SUCCESS. Commit: e58f22a
/LLM/main/L0_MergeRequest_PR pipeline #23190 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

@pengbowang-nv
Copy link
Collaborator Author

/bot run --add-multi-gpu-test --disable-fail-fast

@tensorrt-cicd
Copy link
Collaborator

PR_Github #30170 [ run ] triggered by Bot. Commit: e58f22a

@tensorrt-cicd
Copy link
Collaborator

PR_Github #30170 [ run ] completed with state SUCCESS. Commit: e58f22a
/LLM/main/L0_MergeRequest_PR pipeline #23217 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

@pengbowang-nv
Copy link
Collaborator Author

/bot run --add-multi-gpu-test --disable-fail-fast

@tensorrt-cicd
Copy link
Collaborator

PR_Github #30266 [ run ] triggered by Bot. Commit: e58f22a

@tensorrt-cicd
Copy link
Collaborator

PR_Github #30266 [ run ] completed with state SUCCESS. Commit: e58f22a
/LLM/main/L0_MergeRequest_PR pipeline #23302 completed with status: 'SUCCESS'

@pengbowang-nv
Copy link
Collaborator Author

/bot run --add-multi-gpu-test

@pengbowang-nv pengbowang-nv force-pushed the dev-remove-xqa-precompile-path branch from dd2f2b4 to 3997de3 Compare January 14, 2026 04:52
@pengbowang-nv
Copy link
Collaborator Author

/bot run --add-multi-gpu-test

@tensorrt-cicd
Copy link
Collaborator

PR_Github #31889 [ run ] triggered by Bot. Commit: 3997de3

@tensorrt-cicd
Copy link
Collaborator

PR_Github #31889 [ run ] completed with state SUCCESS. Commit: 3997de3
/LLM/main/L0_MergeRequest_PR pipeline #24693 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Signed-off-by: Pengbo Wang <221450789+pengbowang-nv@users.noreply.github.com>
Signed-off-by: Pengbo Wang <221450789+pengbowang-nv@users.noreply.github.com>
Signed-off-by: Pengbo Wang <221450789+pengbowang-nv@users.noreply.github.com>
@pengbowang-nv pengbowang-nv force-pushed the dev-remove-xqa-precompile-path branch from 3997de3 to 5f55640 Compare January 14, 2026 10:01
@pengbowang-nv
Copy link
Collaborator Author

/bot run --add-multi-gpu-test

@pengbowang-nv pengbowang-nv marked this pull request as ready for review January 14, 2026 10:03
@pengbowang-nv pengbowang-nv changed the title [None][feat] WIP: Use XQA JIT impl by default [None][feat] Use XQA JIT impl by default and mitigate perf loss with sliding window Jan 14, 2026
@coderabbitai
Copy link
Contributor

coderabbitai bot commented Jan 14, 2026

📝 Walkthrough

Walkthrough

These changes modify CUDA kernel generation and masking logic for XQA attention kernels. Updates include explicit C++ namespace declaration changes, expanded TOKENS_PER_PAGE configuration support (adding 32-token variant), introduction of a new sliding window + SPEC_DEC masking function, and simplified implementation selection logic that removes multi-branch platform-specific handling.

Changes

Cohort / File(s) Summary
Code generation namespace and configuration
cpp/kernels/xqa/gen_cubins.py
Replaces TRTLLM_NAMESPACE macros with explicit namespace tensorrt_llm { ... } declarations. Expands TOKENS_PER_PAGE supported values from {0, 64, 128} to {0, 32, 64, 128} in both standard and spec_dec code generation paths.
Kernel masking logic
cpp/kernels/xqa/mha.cu
Introduces new applyMaskFromInputSlidingAndSpecDec() device function for sliding window + SPEC_DEC masking scenarios. Adds runtime flag rtIsReallySliding to conditionally route between two masking paths. Updates computation of nbSeqItersWithoutMask under combined conditions.
Implementation selection
cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.cpp
Simplifies implementation selection by removing multi-branch logic that checked for Hopper XQA, SM120 MLA, and Ampere XQA support. Now uniformly selects JIT implementation regardless of platform capabilities.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 1 | ❌ 2
❌ Failed checks (2 warnings)
Check name Status Explanation Resolution
Description check ⚠️ Warning Description covers the problem (current dual execution paths), solution (remove pre-compiled cubin path), affected models, and includes performance data comparing pre-compiled vs JIT versions. However, Test Coverage section is empty, which is a required section. Complete the Test Coverage section by listing specific tests that validate the JIT implementation changes and the sliding window mitigation.
Docstring Coverage ⚠️ Warning Docstring coverage is 20.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (1 passed)
Check name Status Explanation
Title check ✅ Passed Title clearly describes the main objective: switching XQA to use JIT implementation by default and mitigating a sliding window performance issue.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing touches
  • 📝 Generate docstrings

🧹 Recent nitpick comments
cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.cpp (2)

76-90: LGTM - Simplified implementation selection to JIT by default.

The change correctly makes JIT the default implementation while preserving the environment variable override (XQA_ENABLE_JIT) for fallback to precompiled when needed. This aligns with the PR objective of reducing maintenance of the precompiled cubin path.

Note that mPrecompiledImpl is still created in the constructor (line 54) to support the environment variable fallback path. Consider adding a comment explaining this retention is intentional for the transition period, as mentioned in the PR objectives ("retaining the pre-compiled code and a knob for at least one release").


1-15: Update copyright year to reflect the latest modification.

The copyright header shows 2020-2023, but this file has meaningful modifications. Per coding guidelines, the copyright header should reflect the year of the latest meaningful modification.

cpp/kernels/xqa/gen_cubins.py (2)

92-103: Add closing namespace comment for generated code.

Per coding guidelines, closing braces of namespaces should have a comment indicating the namespace being closed. The generated C++ code at line 102 should include a comment like } // namespace tensorrt_llm.

Suggested fix
 cpp_file_suffex_text = R"""
 // clang-format on
 } // namespace kernels
-}
+} // namespace tensorrt_llm
 """

1-16: Consider updating copyright year.

The copyright header shows 2022-2025, and since meaningful modifications are being made, consider updating to reflect 2026.


📜 Recent review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e7882d5 and 5f55640.

📒 Files selected for processing (3)
  • cpp/kernels/xqa/gen_cubins.py
  • cpp/kernels/xqa/mha.cu
  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.cpp
🧰 Additional context used
📓 Path-based instructions (5)
**/*.{cpp,cc,cxx,h,hpp,hxx,cu,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.{cpp,cc,cxx,h,hpp,hxx,cu,cuh}: Closing braces of namespaces should have a comment saying the namespace it closes (e.g., } // namespace foo)
Prefer const or constexpr variables over #defines whenever possible
A variable that is not modified after its initialization should be declared as const
For naming of constants in C++, use uppercase snakecase with prefix 'k' (e.g., kDIGIT_NUM)
Except for 0, nullptr, true, and false, all other literals should only be used for variable initialization and not in comparisons or expressions
Use Allman indentation style for brace notation in C++ code
Put the semicolon for an empty for or while loop in a new line
The statement forming the body of a switch, while, do..while, or for statement must be a compound statement (use brace-delimited statements)
If and else statements should always be followed by brace-delimited statements, even if empty or a single statement
C++ filenames should use camelCase with first letter lowercase (e.g., thisIsAFilename.cpp)
All types (including class names) in C++ should use PascalCase with uppercase first letter (e.g., FooBarClass)
Local variables, methods, and namespaces in C++ should use camelCase with first letter lowercase (e.g., localFooBar)
Non-magic-number global variables that are non-static and not defined in anonymous namespace should use camelCase prefixed with 'g' (e.g., gDontUseGlobalFoos)
Non-magic-number global variables that are static or defined in an anonymous namespace should use camelCase prefixed with 's' (e.g., sMutableStaticGlobal)
Locally visible static variables should use camelCase with 's' as the first letter (e.g., static std::once_flag sFlag;)
Public, private, and protected class member variables should use camelCase prefixed with 'm' (e.g., mNbFooValues)
Do not use Hungarian notation in C++ except for 'apps hungarian' (e.g., 'nb' to indicate count: mNbLayers)
If a constructor parameter name conflicts with a public me...

Files:

  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.cpp
  • cpp/kernels/xqa/mha.cu
**/*.{cpp,cc,cxx,cu}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.{cpp,cc,cxx,cu}: Use smart pointers for allocating objects on the heap in C++
Prefer unique_ptr for single resource ownership and shared_ptr for shared resource ownership in C++. Use weak_ptr only in exceptional cases
In C++ function calls where parameters are not obvious, use inline C comments to document the parameter (e.g., doSomeOperation(/* checkForErrors = */ false);)
Use the least forceful cast necessary in C++, or no cast if possible
Casting a pointer to void* in C++ should be implicit (except if removing const)
Casting in C++ should not remove any const or volatile qualification from the type of a pointer or reference
Do not use C-style casts (other than void casts) and functional notation casts (other than explicit constructor calls) in C++
Casting from void* to T* in C++ should be done with static_cast, not reinterpret_cast
Use reinterpret_cast in C++ as a last resort, where const_cast and static_cast won't work
Avoid dynamic_cast in C++
Do not use assignment operator in C++ subexpressions (e.g., x = y = z or if (x = y))
When practical, a C++ switch statement controlled by an enum should have a case for each enum value and not have a default clause
C++ switch statements should be well structured as structured multi-way branches, not as 'glorified gotos'
In C++ switch statements, prohibit fall-through except from one case label to another. Each case clause must be terminated with a break or throw
Do not end a C++ case clause with return; use break or throw instead
If a C++ switch clause is a compound statement, put the break inside the braces
Do not use C library functions in C++ whenever possible. Use C++ alternatives like brace initialization or std::fill_n() instead of memset()

Files:

  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.cpp
  • cpp/kernels/xqa/mha.cu
**/*.{h,hpp,hxx,cpp,cc,cxx,cu,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

All C++ class templates, function templates, class template member functions, and class template static members must be instantiated at least once

Files:

  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.cpp
  • cpp/kernels/xqa/mha.cu
**/*.{cpp,cc,cxx,h,hpp,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

All TensorRT-LLM source files (.cpp, .h, .cu, .py, and other source files) should contain an NVIDIA copyright header with the year of latest meaningful modification

Files:

  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.cpp
  • cpp/kernels/xqa/mha.cu
  • cpp/kernels/xqa/gen_cubins.py
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+
Indent Python code with 4 spaces. Do not use tabs
Always maintain the namespace when importing Python modules, even if only one class or function from a module is used
Python filenames should use snake_case (e.g., some_file.py)
Python classes should use PascalCase (e.g., class SomeClass)
Python functions and methods should use snake_case (e.g., def my_awesome_function():)
Python local variables should use snake_case, with prefix k for variable names that start with a number (e.g., k_99th_percentile)
Python global variables should use upper snake_case with prefix G (e.g., G_MY_GLOBAL)
Python constants should use upper snake_case (e.g., MY_CONSTANT)
Avoid shadowing variables declared in an outer scope in Python
Initialize all externally visible members of a Python class in the constructor
For Python interfaces that may be used outside a file, prefer docstrings over comments
Use comments in Python for code within a function, or interfaces that are local to a file
Use Google-style docstrings for Python classes and functions, which can be parsed by Sphinx
Python attributes and variables can be documented inline with the format """<type>: Description"""
Avoid using reflection in Python when functionality can be easily achieved without reflection
When using try-except blocks in Python, limit the except clause to the smallest set of errors possible
When using try-except blocks in Python to handle multiple possible variable types (duck-typing), keep the body of the try as small as possible and use the else block for the main logic

Files:

  • cpp/kernels/xqa/gen_cubins.py
🧠 Learnings (15)
📓 Common learnings
Learnt from: thorjohnsen
Repo: NVIDIA/TensorRT-LLM PR: 6910
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:0-0
Timestamp: 2025-08-14T21:04:50.248Z
Learning: In KV cache onboarding logic during prefill in cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp, when calculating which blocks fall within the attention window, use getTokensPerBlock() to advance token indices rather than block->getUniqueTokens().size(), because the calculation needs to consider the post-prefill state where blocks will be filled to capacity, not their current token count.
Learnt from: eopXD
Repo: NVIDIA/TensorRT-LLM PR: 6767
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:0-0
Timestamp: 2025-08-15T06:46:54.897Z
Learning: In cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp addToken function, newly allocated blocks are unshared by design. The beam search path in addToken (when sequence.getNumTokens() > windowSize) is currently broken/non-functional with SWA, so the block allocation doesn't follow a shared-then-unshared pattern.
Learnt from: sklevtsov-nvidia
Repo: NVIDIA/TensorRT-LLM PR: 3294
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu:118-127
Timestamp: 2025-08-09T20:57:04.084Z
Learning: In the CUTLASS MoE finalize fusion implementation (cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu), when setting `fused_finalize_epilogue.stride_final_output` with shape `(hidden_size, num_output_tokens, 1)`, the `num_rows_in_final_output` should be set to `num_output_tokens` (not `hidden_size`) because of a swap+transpose operation that maps rows of the output tensor to `hidden_size` and columns to `num_output_tokens`.
📚 Learning: 2025-08-15T06:46:54.897Z
Learnt from: eopXD
Repo: NVIDIA/TensorRT-LLM PR: 6767
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:0-0
Timestamp: 2025-08-15T06:46:54.897Z
Learning: In cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp addToken function, newly allocated blocks are unshared by design. The beam search path in addToken (when sequence.getNumTokens() > windowSize) is currently broken/non-functional with SWA, so the block allocation doesn't follow a shared-then-unshared pattern.

Applied to files:

  • cpp/kernels/xqa/mha.cu
📚 Learning: 2025-08-20T06:56:02.889Z
Learnt from: eopXD
Repo: NVIDIA/TensorRT-LLM PR: 6768
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:577-579
Timestamp: 2025-08-20T06:56:02.889Z
Learning: In cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp, maxSequenceLength is now enforced as a non-optional argument in the BlockManager constructor, so concerns about std::nullopt defaulting to 0 are not applicable. When windowSize > maxSequenceLength, a warning should be added instead of handling optional parameter cases.

Applied to files:

  • cpp/kernels/xqa/mha.cu
📚 Learning: 2025-09-23T15:01:00.070Z
Learnt from: nv-lschneider
Repo: NVIDIA/TensorRT-LLM PR: 7910
File: cpp/tensorrt_llm/kernels/nccl_device/config.cu:15-17
Timestamp: 2025-09-23T15:01:00.070Z
Learning: In TensorRT-LLM NCCL device kernels, the <sstream> header is not needed as an explicit include in config.cu because it's provided transitively through other headers. Local compilation testing confirms this works without the explicit include.

Applied to files:

  • cpp/kernels/xqa/gen_cubins.py
📚 Learning: 2025-08-08T05:06:31.596Z
Learnt from: sklevtsov-nvidia
Repo: NVIDIA/TensorRT-LLM PR: 3294
File: cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp:36-36
Timestamp: 2025-08-08T05:06:31.596Z
Learning: CUTLASS extension files (under cpp/tensorrt_llm/cutlass_extensions/) follow CUTLASS coding style conventions, including using `#pragma` once instead of TRTLLM_ prefixed header guards, even though they are .hpp files.

Applied to files:

  • cpp/kernels/xqa/gen_cubins.py
📚 Learning: 2025-08-14T21:04:50.248Z
Learnt from: thorjohnsen
Repo: NVIDIA/TensorRT-LLM PR: 6910
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:0-0
Timestamp: 2025-08-14T21:04:50.248Z
Learning: In KV cache onboarding logic during prefill in cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp, when calculating which blocks fall within the attention window, use getTokensPerBlock() to advance token indices rather than block->getUniqueTokens().size(), because the calculation needs to consider the post-prefill state where blocks will be filled to capacity, not their current token count.

Applied to files:

  • cpp/kernels/xqa/gen_cubins.py
📚 Learning: 2025-09-23T15:01:00.070Z
Learnt from: nv-lschneider
Repo: NVIDIA/TensorRT-LLM PR: 7910
File: cpp/tensorrt_llm/kernels/nccl_device/config.cu:15-17
Timestamp: 2025-09-23T15:01:00.070Z
Learning: In TensorRT-LLM NCCL device kernels (cpp/tensorrt_llm/kernels/nccl_device/config.cu), std::ostringstream is used but <sstream> doesn't need to be explicitly included because it's provided transitively through other headers like tensorrt_llm/common/cudaUtils.h or config.h. Local compilation testing confirms this works without the explicit include.

Applied to files:

  • cpp/kernels/xqa/gen_cubins.py
📚 Learning: 2025-08-06T13:58:07.506Z
Learnt from: galagam
Repo: NVIDIA/TensorRT-LLM PR: 6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.

Applied to files:

  • cpp/kernels/xqa/gen_cubins.py
📚 Learning: 2025-09-23T15:13:48.819Z
Learnt from: nv-lschneider
Repo: NVIDIA/TensorRT-LLM PR: 7910
File: cpp/tensorrt_llm/kernels/nccl_device/multimem.h:20-30
Timestamp: 2025-09-23T15:13:48.819Z
Learning: TRT-LLM targets modern CUDA toolkits that support FP8 datatypes, so cuda_fp8.h can be included unconditionally without version guards in TRT-LLM code.

Applied to files:

  • cpp/kernels/xqa/gen_cubins.py
📚 Learning: 2025-09-23T14:58:05.372Z
Learnt from: nv-lschneider
Repo: NVIDIA/TensorRT-LLM PR: 7910
File: cpp/tensorrt_llm/kernels/nccl_device/config.cu:42-49
Timestamp: 2025-09-23T14:58:05.372Z
Learning: In TensorRT-LLM NCCL device kernels (cpp/tensorrt_llm/kernels/nccl_device/), the token partitioning intentionally uses ceil-like distribution (same token_per_rank for all ranks) to ensure all ranks launch the same number of blocks. This is required for optimal NCCL device API barrier performance, even though it may launch extra blocks for non-existent tokens on later ranks. Runtime bounds checking in the kernel (blockID validation) handles the overshoot cases.

Applied to files:

  • cpp/kernels/xqa/gen_cubins.py
📚 Learning: 2026-01-06T03:07:15.754Z
Learnt from: CR
Repo: NVIDIA/TensorRT-LLM PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2026-01-06T03:07:15.754Z
Learning: Applies to **/*.{h,hpp,hxx} : Use a preprocessor guard in C++ header files with the format `TRTLLM_<FILENAME>_H`, where the filename is in uppercase with no underscores, no prefix underscores, and no trailing underscores

Applied to files:

  • cpp/kernels/xqa/gen_cubins.py
📚 Learning: 2025-08-22T01:54:35.850Z
Learnt from: djns99
Repo: NVIDIA/TensorRT-LLM PR: 7104
File: cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h:999-1000
Timestamp: 2025-08-22T01:54:35.850Z
Learning: The `internal_cutlass_kernels` directory in TensorRT-LLM is a mirror of an internal NVIDIA repository and maintains its own implementation and API that may diverge from the public `cutlass_kernels` version. API inconsistencies between these two directories are intentional and by design, not bugs to be fixed.

Applied to files:

  • cpp/kernels/xqa/gen_cubins.py
📚 Learning: 2025-09-09T09:40:45.658Z
Learnt from: fredricz-20070104
Repo: NVIDIA/TensorRT-LLM PR: 7645
File: tests/integration/test_lists/qa/llm_function_core.txt:648-648
Timestamp: 2025-09-09T09:40:45.658Z
Learning: In TensorRT-LLM test lists, it's common and intentional for the same test to appear in multiple test list files when they serve different purposes (e.g., llm_function_core.txt for comprehensive core functionality testing and llm_function_core_sanity.txt for quick sanity checks). This duplication allows tests to be run in different testing contexts.

Applied to files:

  • cpp/kernels/xqa/gen_cubins.py
📚 Learning: 2025-08-20T07:43:36.447Z
Learnt from: ChristinaZ
Repo: NVIDIA/TensorRT-LLM PR: 7068
File: cpp/tensorrt_llm/kernels/moeTopKFuncs.cuh:169-172
Timestamp: 2025-08-20T07:43:36.447Z
Learning: In TensorRT-LLM MOE kernels, when processing up to 128 experts across 32 threads, each thread handles at most 4 experts (N < 5 constraint), where N represents candidates per thread rather than total system capacity.

Applied to files:

  • cpp/kernels/xqa/gen_cubins.py
📚 Learning: 2026-01-06T03:07:15.754Z
Learnt from: CR
Repo: NVIDIA/TensorRT-LLM PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2026-01-06T03:07:15.754Z
Learning: Applies to **/*.{cpp,cc,cxx,h,hpp,hxx,cu,cuh,py} : All TensorRT-LLM source files (.cpp, .h, .cu, .py, and other source files) should contain an NVIDIA copyright header with the year of latest meaningful modification

Applied to files:

  • cpp/kernels/xqa/gen_cubins.py
🧬 Code graph analysis (1)
cpp/kernels/xqa/mha.cu (1)
cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.cpp (2)
  • divUp (63-66)
  • divUp (63-63)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (5)
cpp/kernels/xqa/mha.cu (4)

469-550: LGTM - New sliding window + spec dec masking function.

The new applyMaskFromInputSlidingAndSpecDec function correctly handles the combined masking logic for sliding window with speculative decoding. The early return optimization on line 488-491 when no masking is needed is a good performance optimization. The logic properly combines both the sliding window mask (begMask) and the speculative decoding mask (maskFlag).


553-604: LGTM - Simplified masking function for non-sliding window cases.

The refactored applyMaskFromInput function correctly handles speculative decoding masking without the sliding window complexity. This separation improves code clarity and allows the compiler to optimize each path independently.


1700-1724: Performance optimization for sliding window behavior.

The introduction of rtIsReallySliding is a good optimization that distinguishes between nominal sliding window configuration and actual sliding behavior at runtime. When the sequence length is shorter than the sliding window, the simpler masking path is used.

The nbSeqItersWithoutMask calculation at lines 1722-1723 correctly falls back to the simpler formula when not really sliding, which aligns with the performance mitigation mentioned in the PR objectives.


2007-2022: LGTM - Runtime path selection for masking.

The conditional logic correctly selects between the two masking functions based on rtIsReallySliding. This runtime branching enables the performance optimization when sliding window is configured but not actively constraining the attention window.

cpp/kernels/xqa/gen_cubins.py (1)

442-447: LGTM - Extended TOKENS_PER_PAGE options for spec_dec.

The expansion of TOKENS_PER_PAGE options from [0, 64, 128] to [0, 32, 64, 128] for speculative decoding configurations adds support for 32-token page sizes, aligning with the non-spec_dec configuration which already includes this option.

✏️ Tip: You can disable this entire section by setting review_details to false in your review settings.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@tensorrt-cicd
Copy link
Collaborator

PR_Github #31960 [ run ] triggered by Bot. Commit: 5f55640

@tensorrt-cicd
Copy link
Collaborator

PR_Github #31960 [ run ] completed with state SUCCESS. Commit: 5f55640
/LLM/main/L0_MergeRequest_PR pipeline #24757 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Copy link
Collaborator

@jhaotingc jhaotingc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks for resolving perf issue when SLIDING_WINDOW=1!

@pengbowang-nv
Copy link
Collaborator Author

/bot run --add-multi-gpu-test --disable-fail-fast

@tensorrt-cicd
Copy link
Collaborator

PR_Github #32052 [ run ] triggered by Bot. Commit: 5f55640

@tensorrt-cicd
Copy link
Collaborator

PR_Github #32052 [ run ] completed with state SUCCESS. Commit: 5f55640
/LLM/main/L0_MergeRequest_PR pipeline #24841 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

@pengbowang-nv
Copy link
Collaborator Author

/bot run --add-multi-gpu-test --disable-fail-fast

@tensorrt-cicd
Copy link
Collaborator

PR_Github #32077 [ run ] triggered by Bot. Commit: 5f55640

@tensorrt-cicd
Copy link
Collaborator

PR_Github #32077 [ run ] completed with state SUCCESS. Commit: 5f55640
/LLM/main/L0_MergeRequest_PR pipeline #24861 completed with status: 'SUCCESS'

@pengbowang-nv pengbowang-nv merged commit 683515b into NVIDIA:main Jan 15, 2026
10 of 11 checks passed
greg-kwasniewski1 pushed a commit to nv-auto-deploy/TensorRT-LLM that referenced this pull request Jan 15, 2026
…sliding window (NVIDIA#10335)

Signed-off-by: Pengbo Wang <221450789+pengbowang-nv@users.noreply.github.com>
@pengbowang-nv pengbowang-nv deleted the dev-remove-xqa-precompile-path branch January 23, 2026 10:18
pengbowang-nv added a commit to pengbowang-nv/TensorRT-LLM that referenced this pull request Jan 23, 2026
…sliding window (NVIDIA#10335)

Signed-off-by: Pengbo Wang <221450789+pengbowang-nv@users.noreply.github.com>
pengbowang-nv added a commit that referenced this pull request Jan 26, 2026
…gate perf loss with sliding window (#10954)

Signed-off-by: Pengbo Wang <221450789+pengbowang-nv@users.noreply.github.com>
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Feb 5, 2026
…d mitigate perf loss with sliding window (NVIDIA#10954)

Signed-off-by: Pengbo Wang <221450789+pengbowang-nv@users.noreply.github.com>
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Feb 5, 2026
…d mitigate perf loss with sliding window (NVIDIA#10954)

Signed-off-by: Pengbo Wang <221450789+pengbowang-nv@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Feb 6, 2026
…d mitigate perf loss with sliding window (NVIDIA#10954)

Signed-off-by: Pengbo Wang <221450789+pengbowang-nv@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Feb 7, 2026
…d mitigate perf loss with sliding window (NVIDIA#10954)

Signed-off-by: Pengbo Wang <221450789+pengbowang-nv@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Feb 8, 2026
…d mitigate perf loss with sliding window (NVIDIA#10954)

Signed-off-by: Pengbo Wang <221450789+pengbowang-nv@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Feb 9, 2026
…d mitigate perf loss with sliding window (NVIDIA#10954)

Signed-off-by: Pengbo Wang <221450789+pengbowang-nv@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Feb 9, 2026
…d mitigate perf loss with sliding window (NVIDIA#10954)

Signed-off-by: Pengbo Wang <221450789+pengbowang-nv@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants