

@shivghai
Contributor

@shivghai shivghai commented Dec 12, 2025

Description

Gemma3 has local and global attention layers. Currently, RoPE is applied incorrectly for the local layers; this PR fixes that. The issue is not apparent with shorter sequences, but once a sequence is longer than the local attention window (more than 1024 tokens for a 1024-token local window), the outputs become slightly garbled. I've already applied this fix to my own version of TensorRT-LLM, and Gemma3 now works as expected.
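As a rough illustration of what "incorrectly applied" can mean here, the sketch below computes RoPE rotation angles for a local layer twice: once with local parameters only, and once with the global layers' linear scaling leaking into the local layer. The helper names are hypothetical, and the values (head dimension 256, local base 10,000, global linear factor 8) are assumptions based on the Gemma 3 technical report, not taken from the TensorRT-LLM code or this PR's diff.

```python
def rope_inv_freq(head_dim: int, base: float) -> list[float]:
    # Standard RoPE inverse frequencies: base^(-2i/d) for i in [0, d/2).
    return [base ** (-2.0 * i / head_dim) for i in range(head_dim // 2)]


def rope_angles(position: int, head_dim: int, base: float,
                linear_factor: float = 1.0) -> list[float]:
    # Linear RoPE scaling (position interpolation) divides the position
    # by the scaling factor before multiplying by each inverse frequency.
    scaled_pos = position / linear_factor
    return [scaled_pos * f for f in rope_inv_freq(head_dim, base)]


HEAD_DIM = 256          # assumed head dimension, for illustration only
LOCAL_BASE = 10_000.0   # assumed RoPE base for local (sliding-window) layers
GLOBAL_FACTOR = 8.0     # linear scaling factor meant for global layers only

pos = 2048
# Correct: a local layer uses its own base and no scaling.
correct = rope_angles(pos, HEAD_DIM, LOCAL_BASE)
# Hypothetical bug: the global layers' scaling is also applied to a local layer.
buggy = rope_angles(pos, HEAD_DIM, LOCAL_BASE, linear_factor=GLOBAL_FACTOR)

print(f"highest-frequency angle: correct={correct[0]:.1f} rad, "
      f"buggy={buggy[0]:.1f} rad")
```

In the leaked case every angle is 8× smaller, so under these assumptions the local layer sees compressed position information that it was not trained with.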

Test Coverage

Added tests/unittest/others/test_layer.py::TestLayer::test_gemma3_local_attention_rope_scaling

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is given. If the Git commit ID has changed, this option is always ignored. By DEFAULT, the bot reuses build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. This ensures that all builds and tests run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages that don't match the specified backends. Only [pytorch, cpp, tensorrt, triton] are supported. Examples: "pytorch, cpp" (does not run test stages with the tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force-run the multi-GPU tests in addition to the L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous, since a lack of user care and validation can break top-of-tree.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action also kills all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous, since a lack of user care and validation can break top-of-tree.

@shivghai shivghai changed the title [Gemma3] Fix RoPE for local attention for Gemma3 [Gemma3, Engine] Fix RoPE for local attention for Gemma3 Dec 12, 2025
@shivghai shivghai changed the title [Gemma3, Engine] Fix RoPE for local attention for Gemma3 [Gemma3] Fix RoPE for local attention for Gemma3 Dec 12, 2025
@coderabbitai
Contributor

coderabbitai bot commented Dec 12, 2025

📝 Walkthrough


Adds local RoPE parameter support to the attention layer by introducing an is_local parameter to the inner register_rope_params function, so that local and global rotary embedding scaling can be configured separately. This lets RoPE parameters be registered correctly for local attention.

Changes

Cohort: Local RoPE Parameter Handling
File(s): tensorrt_llm/layers/attention.py
Summary: Added an is_local parameter to the register_rope_params inner function within create_attention_const_params. The local scaling trio (local_scale, local_scale_type, local_scaling) is now computed conditionally based on locality, and the RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin call now receives the local scaling parameters. For non-homogeneous attention, register_rope_params is invoked with is_local=True to register a secondary set of RoPE parameters. In compute_cross_kv, rotary_embedding_scale_type and rotary_embedding_scale are conditionalized on locality, and in the gpt_attention call path the rotary-related arguments select the local variants when is_local is true.
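The conditional selection summarized above can be sketched independently of TensorRT-LLM. RopeParams, select_rope_params, and the config keys below are hypothetical and are not the actual register_rope_params or RopeEmbeddingUtils API; the values assume Gemma 3's larger models (global base 1,000,000 with linear ×8 scaling, local base 10,000 without scaling).

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RopeParams:
    base: float
    scale_type: str  # e.g. "none" or "linear"
    scale: float


def select_rope_params(cfg: dict, is_local: bool) -> RopeParams:
    """Pick the RoPE parameter set for one attention layer (hypothetical sketch)."""
    if is_local:
        # Local (sliding-window) layers use their own base and their own
        # scaling settings, defaulting to no scaling. They must NOT fall
        # back to the global scale/scale_type.
        return RopeParams(
            base=cfg["rope_local_base"],
            scale_type=cfg.get("local_scale_type", "none"),
            scale=cfg.get("local_scale", 1.0),
        )
    # Global layers use the long-context scaling from the config.
    return RopeParams(
        base=cfg["rope_base"],
        scale_type=cfg.get("scale_type", "none"),
        scale=cfg.get("scale", 1.0),
    )


# Assumed Gemma 3-like values (4B/12B/27B).
gemma3_cfg = {
    "rope_base": 1_000_000.0,
    "rope_local_base": 10_000.0,
    "scale_type": "linear",
    "scale": 8.0,
}

print(select_rope_params(gemma3_cfg, is_local=False))
print(select_rope_params(gemma3_cfg, is_local=True))
```

The design point is that the local defaults are independent of the global scaling: a local layer only picks up scaling if it is configured for it explicitly.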

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

  • Rope parameter registration logic: Verify correctness of local scaling computation and parameter passing across conditional branches
  • Cross-KV computation: Ensure rotary_embedding_scale_type and rotary_embedding_scale are correctly conditionalized for local attention
  • GPT attention call path: Confirm proper selection of rotary variants (rotary_inv_freq, rotary_cos_sin) based on is_local flag and non-homogeneous attention handling

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 0.00%, which is below the required threshold of 80.00%. Resolution: you can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Title check (✅ Passed): The title directly addresses the main change: fixing RoPE for local attention in Gemma3 models, which aligns with the changeset's focus on introducing is_local parameter handling and conditional scaling for local vs. global attention.
  • Description check (✅ Passed): The PR description clearly explains the issue (incorrect RoPE application in Gemma3 local attention), the solution, and includes test coverage information.

@shivghai shivghai force-pushed the sg/fix-attention-for-gemma3-local-attention branch from 1750f85 to 09fcb61 Compare December 12, 2025 17:51
@svc-trtllm-gh-bot svc-trtllm-gh-bot added the Community want to contribute PRs initiated from Community label Dec 12, 2025
@shivghai shivghai force-pushed the sg/fix-attention-for-gemma3-local-attention branch from 2ae4f49 to 8256ade Compare December 12, 2025 18:41
@karljang karljang changed the title [Gemma3] Fix RoPE for local attention for Gemma3 [None][fix] [Gemma3] Fix RoPE for local attention for Gemma3 Dec 17, 2025
@karljang
Collaborator

@shivghai,
Thank you for your contribution. Before we proceed, could you please address the check failures mentioned above?

  • PR Checks: This requires you to check the checkbox in the issue description.
  • Pre-commit Check: This is related to coding-style.
    • You can run this check locally. Here’s a quote from the log:

    If you encounter this message during CI, reproduce it locally using: pre-commit run --all-files.
    To run pre-commit as part of your git workflow, use pre-commit install.

@shivghai
Contributor Author

Hi @karljang, updated!

@karljang
Collaborator

@shivghai, but the pre-commit check failure has not been resolved 😄
https://github.com/NVIDIA/TensorRT-LLM/actions/runs/20176504689/job/58346727806?pr=9961

@shivghai
Contributor Author

Hi @karljang, it looks like the pre-commit check passes for me now (it also did earlier, for some reason). I can't see it auto-triggering in CI, but it should be fine now. I'll wait to see whether it fails.

@shivghai
Contributor Author

@karljang looks like we are set on the pre-commit side!

@karljang karljang requested review from brb-nv and kaiyux December 18, 2025 21:18
@karljang
Collaborator

/bot run

@tensorrt-cicd
Collaborator

PR_Github #29026 [ run ] triggered by Bot. Commit: 72bf5b4

@brb-nv
Collaborator

brb-nv commented Dec 18, 2025

Thank you for the contribution, @shivghai! We made this change in the PyTorch code path but missed it in the TRT flow. I understand it's needed for the larger VLM models.

Hope we can get this in soon. Great work!

@tensorrt-cicd
Collaborator

PR_Github #29026 [ run ] completed with state SUCCESS. Commit: 72bf5b4
/LLM/main/L0_MergeRequest_PR pipeline #22247 completed with status: 'SUCCESS'

@shivghai
Contributor Author

@brb-nv Happy to contribute! This isn't needed exclusively for the VLMs, but for all Gemma3 variants:

5:1 interleaving of local/global layers. We alternate between a local sliding window self-attention (Beltagy et al., 2020) and global self-attention (Luong et al., 2015), with a pattern of 5 local layers for every global layer, starting with a local layer as the first layer of the model.
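The 5:1 pattern quoted above (five local layers per global layer, starting with a local layer) reduces to a simple index rule. is_global_layer and layer_types are hypothetical helpers, and the sketch assumes the pattern simply repeats across the whole layer stack.

```python
def is_global_layer(layer_idx: int, pattern_len: int = 6) -> bool:
    # With a 5:1 local:global pattern that starts on a local layer,
    # every 6th layer (0-based indices 5, 11, 17, ...) is global.
    return (layer_idx + 1) % pattern_len == 0


def layer_types(num_layers: int) -> list[str]:
    # Label each layer so the correct RoPE parameter set can be chosen per layer.
    return ["global" if is_global_layer(i) else "local" for i in range(num_layers)]


types = layer_types(12)
print(types)
print(f"{types.count('local')} local / {types.count('global')} global")
```

Under this rule, only the global layers should receive the long-context RoPE scaling; every other layer takes the local path, which is why a bug confined to local layers affects most of the model.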

I'm surprised no one caught this sooner. To be fair, though, most of the time it might not be very noticeable: the sequence would have to be very long, the model would have to attend heavily to the middle of the sequence, and the output would have to clearly depend on that middle part.

@brb-nv
Collaborator

brb-nv commented Dec 19, 2025

this isn't needed exclusively for the VLMs, but for all Gemma3 variants

I meant it wasn't needed for the text-only 1B model because even the global layers don't use scaling there. Please correct me if I'm wrong.

Also, our focus has lately shifted to the PyTorch workflow (please give it a shot! :D), so all the fixes, like this one, have been made there:
#5857

Other stuff that's good to be aware of for Gemma3 users:
#5564 - The sliding window size isn't left-inclusive
#5976 - Custom masking for the VLMs

Please consider using the PyTorch flow if you can.

@shivghai
Contributor Author

From the Gemma 3 technical report:

then scale the 4B, 12B, and 27B models up to 128K tokens at the end of pre-training while rescaling RoPE (Chen et al., 2023). We find a scaling factor of 8 to work well in practice

You're right, looks like the 1B model is the only one which doesn't require rope scaling!

Anyway - thank you for all the contributions here, and we will make the switch to the torch workflow eventually :)

Collaborator

@brb-nv brb-nv left a comment


Approving with the request to address newer comments for completeness. Let's run CI right after and get this merged!

@brb-nv
Collaborator

brb-nv commented Dec 22, 2025

/bot run --disable-fail-fast

@brb-nv
Collaborator

brb-nv commented Dec 24, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #29670 [ run ] triggered by Bot. Commit: 9d57d51

@tensorrt-cicd
Collaborator

PR_Github #29670 [ run ] completed with state SUCCESS. Commit: 9d57d51
/LLM/main/L0_MergeRequest_PR pipeline #22788 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

@shivghai
Contributor Author

@brb-nv it looks like it failed again due to unrelated issues (HF-related).

@brb-nv
Collaborator

brb-nv commented Dec 25, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #29853 [ run ] triggered by Bot. Commit: 9d57d51

@tensorrt-cicd
Collaborator

PR_Github #29853 [ run ] completed with state SUCCESS. Commit: 9d57d51
/LLM/main/L0_MergeRequest_PR pipeline #22955 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Signed-off-by: Shiv Ghai <[email protected]>
@brb-nv brb-nv force-pushed the sg/fix-attention-for-gemma3-local-attention branch from 9d57d51 to 025b7c7 Compare December 25, 2025 02:40
@brb-nv
Collaborator

brb-nv commented Dec 25, 2025

Quite a few unrelated failures, mostly HF-related. Rebased and retrying.

@brb-nv
Collaborator

brb-nv commented Dec 25, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #29870 [ run ] triggered by Bot. Commit: 025b7c7

@tensorrt-cicd
Collaborator

PR_Github #29870 [ run ] completed with state SUCCESS. Commit: 025b7c7
/LLM/main/L0_MergeRequest_PR pipeline #22971 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

@brb-nv
Collaborator

brb-nv commented Dec 26, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #29990 [ run ] triggered by Bot. Commit: 025b7c7

@tensorrt-cicd
Collaborator

PR_Github #29990 [ run ] completed with state SUCCESS. Commit: 025b7c7
/LLM/main/L0_MergeRequest_PR pipeline #23071 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

@shivghai
Contributor Author

More failures; they look unrelated, judging from the snippet I can see.

@brb-nv
Collaborator

brb-nv commented Dec 27, 2025

/bot run --disable-fail-fast

@brb-nv
Collaborator

brb-nv commented Dec 27, 2025

Waiving the problematic test here: #10311

@tensorrt-cicd
Collaborator

PR_Github #30032 [ run ] triggered by Bot. Commit: 025b7c7

@tensorrt-cicd
Collaborator

PR_Github #30032 [ run ] completed with state SUCCESS. Commit: 025b7c7
/LLM/main/L0_MergeRequest_PR pipeline #23110 completed with status: 'SUCCESS'

@brb-nv brb-nv removed the request for review from kaiyux December 27, 2025 19:50
@brb-nv brb-nv merged commit ee07a7c into NVIDIA:main Dec 27, 2025
2 checks passed
@brb-nv
Collaborator

brb-nv commented Dec 27, 2025

Thank you for your contribution, @shivghai. Cheers!

@shivghai shivghai deleted the sg/fix-attention-for-gemma3-local-attention branch December 27, 2025 20:36

Labels

Community want to contribute PRs initiated from Community


5 participants