[#9587][fix] AutoDeploy: Support Gemma3 VLM (#10096)
bmarimuthu-nv wants to merge 24 commits into NVIDIA:main
Conversation
Force-pushed 8282fc4 to 7abd5bc

/bot run
@coderabbitai summary

✅ Actions performed: Summary regeneration triggered.
📝 Walkthrough

This PR introduces comprehensive Vision-Language Model (VLM) support to AutoDeploy, enabling Gemma3 and similar multimodal models to export and run efficiently. Changes include custom attention mask generation operators, FlashInfer integration for VLM masking, export-time transformations for mask metadata tagging, runtime VLM mask preparation, and extensive supporting infrastructure including utilities, patches, and test coverage.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60–90 minutes
/bot run
tensorrt_llm/_torch/auto_deploy/transform/library/sync_tied_weights.py
@TransformRegistry.register("sync_tied_weights")
class SyncTiedWeights(BaseTransform):
    """Sync tied weights that cross the export boundary.
we should also check if we can just do this with load hooks during/before export.
I looked into the load hooks, and they feel like a better fit for this than a separate transform. So I reverted the transform and made the export add a post-load hook to sync tied weights - 7d1c5c8
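A minimal sketch of that load-hook approach (the toy module `TinyLM` and the hook wiring here are illustrative, not the PR's actual code): a post-load hook re-ties `lm_head` to the embedding after every `load_state_dict`, even when the checkpoint's copies have diverged.

```python
import torch
import torch.nn as nn

class TinyLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(8, 4)
        self.lm_head = nn.Linear(4, 8, bias=False)
        # re-tie after every state_dict load, even if checkpoint weights diverge
        self.register_load_state_dict_post_hook(self._tie)

    @staticmethod
    def _tie(module, incompatible_keys):
        # point lm_head at the embedding's tensor so both share storage
        module.lm_head.weight = module.embedding.weight

m = TinyLM()
sd = {k: v.clone() for k, v in m.state_dict().items()}
sd["lm_head.weight"] = torch.randn(8, 4)  # simulate a non-identical checkpoint
m.load_state_dict(sd)
```

After loading, the hook has run, so the two weights share the same storage regardless of what the checkpoint contained.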
tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py
    # return a list of tensors
    return self.cache_seq_interface.info.unnest_sequences(logits)

def _prepare_vlm_kwargs(self, kwargs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
I think we need to find a way to make this part of the export pipeline without requiring eager hacks. Here is my broad suggestion which relies on model patching or re-writing the model definition:
- We have a reference op `torch.ops.auto_deploy.create_attention_mask` that takes a set of arguments, is called inside the model (either by rewriting or patching the forward function or some masking utils), and can return the correct mask. It can be called multiple times to create different masks. The different masks are then given as arguments to `torch.ops.auto_deploy.torch_attention`.
- When you export the model you now have multiple instances of these mask ops available. Since it is part of the export graph, each layer and invocation of `torch_attention` already has the right mask (sliding, full, etc...)
- During swapping from `torch_attention` to the backend attention (e.g. `flashinfer_attention`) you swap the generic `torch.ops.auto_deploy.create_attention_mask` for the backend-specific mask creation tensor. Or you keep the generic one if the generic mask creation is compatible with the attention backend.
I think this design of shifting more into the export + model patching stage should avoid a lot of the hard-coded heuristics like here throughout the code base.
Let me know what you think and happy to discuss this further
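This suggestion could be sketched roughly as follows (illustrative only, with simplified signatures; the real op would live under `torch.ops.auto_deploy` as a registered custom op and carry more arguments). A generic mask function is traced into the graph, and a transform pass later swaps its call sites for a backend-specific function:

```python
import torch
import torch.fx as fx

def create_attention_mask(seq_len: int, sliding_window: int) -> torch.Tensor:
    """Generic boolean mask: causal, optionally limited to a sliding window."""
    i = torch.arange(seq_len).unsqueeze(1)  # query positions
    j = torch.arange(seq_len).unsqueeze(0)  # key positions
    mask = j <= i  # causal
    if sliding_window > 0:
        mask = mask & ((i - j) < sliding_window)
    return mask

def swap_mask_op(gm: fx.GraphModule, backend_mask_fn) -> fx.GraphModule:
    """Replace every generic mask call in an exported graph with a backend one."""
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target is create_attention_mask:
            node.target = backend_mask_fn
    gm.recompile()
    return gm
```

Because the mask op is part of the exported graph, the swap pass sees one call site per layer with its own arguments (sliding vs. full), so no per-layer heuristics are needed at transform time.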
govind-ramnarayan left a comment:

Thanks! Mostly looks good to me. Leaving some comments; mainly it would be good to revisit some of these tests for usefulness.
Also, it would be good to clarify (and document where it makes sense) why we have a custom mask during prefill and not generation, because in the future there may be multi-token sequences that are not prefill, e.g. with speculative decoding: https://github.com/NVIDIA/TensorRT-LLM/pull/10096/changes#diff-3ea5c563f6bdbaf80e42e4281b753c2a69e59c6dc9b43595b263b352c3f6cca3R280
@@ -252,16 +276,23 @@ def flashinfer_mha_with_cache(
    n_heads = q.shape[1]
    n_kv_heads = k.shape[1]

    is_generate = s == 1
    # Custom mask only applies during prefill, not generation
Q: Just curious, is this a fundamental aspect of prefill, or is it directly applicable to any request with s > 1. The reason I ask is because of speculative decoding - where decode sequences can have multiple tokens. I don't think you need to investigate this, but clarifying the reason that the custom mask only applies during prefill (maybe in your own notes or a Google doc if it is too long for a comment) might be useful in the future so we don't need to ask this when doing (speculative decoding) x VLMs.
Sure, the primary reason we apply a custom mask is that in VLMs the presence of image or other modal tokens might need a non-causal mask (like bidirectional masking). This custom mask is needed only for context/prefill. Once prefill is done, during generation we are doing causal token generation, and hence we fall back to the attention backend's causal mask generation (for ImageTextToText-type models).
As for supporting spec decoding (s > 1), maybe we can look into having an explicit param to denote the phase (prefill/decode) instead of relying on s == 1. But as long as we are running text generation models, the decode phase with s > 1 should still use the causal mask from the attention backend instead of the custom mask. So we should be fine with an explicit param when doing (speculative decoding) x VLMs.
    modeling_gemma3.Gemma3Model.forward
)

def _revert_patch(self):
Nit: Maybe assert that this function covers all the keys in self.original_values. I can see a potential issue where we add more stuff to _apply_patch() in the future but do not properly update _revert_patch().
I think maybe if this type of pattern becomes common for other models, we could try to turn this into a utility that just iterates over self.original_values and figures out the values to update from the keys; but this seems overly complicated for now.
Makes sense! That's a good idea and seems valid for all export-only patches at least 👍
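The reviewer's suggested utility could be sketched like this (hypothetical; `AttrPatcher` is illustrative, not the TRT-LLM patch API): every patched attribute is recorded in one dict, and revert simply iterates it, so the revert can never drift out of sync with the apply.

```python
class AttrPatcher:
    """Record originals on patch; revert by iterating the record."""

    def __init__(self):
        self.original_values = {}  # (owner, attr_name) -> original value

    def patch(self, owner, attr: str, new_value) -> None:
        # keep only the first-seen original if the same attr is patched twice
        if (owner, attr) not in self.original_values:
            self.original_values[(owner, attr)] = getattr(owner, attr)
        setattr(owner, attr, new_value)

    def revert(self) -> None:
        for (owner, attr), value in self.original_values.items():
            setattr(owner, attr, value)
        self.original_values.clear()
```

With this shape, `_apply_patch()` growing a new entry automatically grows `_revert_patch()` too, addressing the drift concern above.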
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_attention_vlm.py
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Force-pushed df120f0 to a1d5ee0
Summary by Author
Background:
The Gemma3 model hierarchy is `Gemma3ForConditionalGeneration` -> `Gemma3Model` -> `Gemma3TextModel`. We export only `Gemma3TextModel`.

Challenges:

- Mask preparation happens in `Gemma3Model`: the mask tensor for a given layer idx/type (sliding window or full attention) is prepared there and passed directly to `Gemma3TextModel`.
- `lm_head` is in `Gemma3ForCausalGeneration`, while the input `embedding` is in `Gemma3TextModel`. But weight tying is enabled between `lm_head` and the embedding, and we export only `Gemma3TextModel` and load weights. Sadly, in the checkpoint, the `lm_head` weight and `embedding` weight are not identical (for some reason), so generation is bad if we don't enforce the weight tying after loading the weights.

Solution V2
For supporting Gemma3 VLM, the following changes are made in this branch:

1.a We achieved this by:
1.a.1 adding a placeholder op in attention_interface during tracing
1.a.2 creating a per-model, per-backend custom mask generation hook that can replace the placeholder mask-gen op during the KVCache transformation pass
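The per-model, per-backend hook lookup in 1.a.2 could be sketched as a small registry (hypothetical names; the actual class is called `CustomMaskGeneratorRegistry` in this PR, but its exact API is not shown here):

```python
from typing import Callable, Dict, Tuple

class MaskGenRegistry:
    """Maps (model, backend) to a mask-generation hook."""
    _gens: Dict[Tuple[str, str], Callable] = {}

    @classmethod
    def register(cls, model: str, backend: str) -> Callable:
        def deco(fn: Callable) -> Callable:
            cls._gens[(model, backend)] = fn
            return fn
        return deco

    @classmethod
    def get(cls, model: str, backend: str) -> Callable:
        return cls._gens[(model, backend)]

@MaskGenRegistry.register("gemma3", "flashinfer")
def gemma3_flashinfer_mask_gen(token_type_ids):
    # placeholder body: a real hook would emit the backend's mask tensor
    return [t == 1 for t in token_type_ids]
```

New model/backend pairs then only need a decorated function; the transformation pass looks up the hook and swaps it in for the placeholder op.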
Additional infra to support custom attention masking:
Summary:
- Export time: capture metadata in markers
- Transform time: replace markers with backend-specific computation
- Runtime: execute the mask op with dynamic inputs

Infra updates:
- `AdditionalGraphInput` - a generic way to add inputs post-export
- `CustomMaskGeneratorRegistry` - extensible for new models/backends
- `_ad_` prefix convention - a generic bypass mechanism

Solutions
Problem 1: Custom Mask preparation and supplying
Solution:
During export:
We create a custom mask generation op that creates the boolean mask and provides it to the flashinfer backend.
- `token_type_ids` is added as a graph input.
- `Gemma3Model.forward` has `token_type_ids` as an explicit parameter, so it gets consumed there and doesn't flow through `**lm_kwargs` to `language_model`.
- We patch `Gemma3TextModel.__call__` to inject `token_type_ids` BEFORE any pre-hooks run, specifically the args-capturing pre-hook in `export_to_gm` in AutoDeploy. The flow: the patched `__call__` runs → injects `token_type_ids` into `kwargs`.
- A `flashinfer_gemma3_mask_gen` op is inserted into the graph, taking `token_type_ids` as input, with a `window_left` parameter for sliding attention layers.

During inference:
- `Gemma3Model.forward`'s signature has `token_type_ids` as an explicit parameter, so it gets consumed there and doesn't flow through `**lm_kwargs` to `language_model`.
- We `register_forward_pre_hook` for the GraphModule. The hook reads the current `token_type_ids` from `engine.cache_seq_interface.info._extra_args` (which gets populated during `_prepare_inputs()`).

Problem 2: Respecting weight tying between the exported graph layer and the eager layer (in parent module/outside the graph)
Solution (OLD):
Added a new transform `sync_tied_weights` that runs after weights are loaded (stage: post_load_fusion) and:

Detects cross-boundary tied weights by:
* Reading _tied_weights_keys from the model
* Using get_input_embeddings() / get_output_embeddings() to find the actual pair
* Checking which weights are inside GraphModules (exported) vs outside
Syncs the weights by making the non-exported weight (lm_head.weight) point to the exported weight's tensor
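The old transform's detect-and-sync step could look roughly like this (a hedged sketch under simplifying assumptions; `ToyWrapper` stands in for the real model, with the embedding "inside" the exported graph and `lm_head` "outside" it):

```python
import torch.nn as nn

class ToyWrapper(nn.Module):
    """Illustrative stand-in exposing HF-style tied-weight accessors."""
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(8, 4)          # inside the exported graph
        self.lm_head = nn.Linear(4, 8, bias=False)   # outside, eager

    def get_input_embeddings(self):
        return self.embedding

    def get_output_embeddings(self):
        return self.lm_head

def sync_tied_weights(model) -> None:
    emb = model.get_input_embeddings()
    head = model.get_output_embeddings()
    # make the non-exported weight point at the exported tensor (no copy)
    head.weight = emb.weight

m = ToyWrapper()
sync_tied_weights(m)
```

Re-pointing (rather than copying) means any later in-place update to the exported tensor is automatically visible to `lm_head` as well.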
Tests
Fixes #9587
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
/bot [-h] {run, kill, skip, reuse-pipeline} ...

Provide a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message. See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

- --reuse-test (optional)pipeline-id (OPTIONAL): Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option is always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
- --disable-reuse-test (OPTIONAL): Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensures that all builds and tests run regardless of previous successes.
- --disable-fail-fast (OPTIONAL): Disable fail-fast on build/test/infra failures.
- --skip-test (OPTIONAL): Skip all test stages, but still run build stages, package stages and sanity check stages. Note: does NOT update GitHub check status.
- --stage-list "A10-PyTorch-1, xxx" (OPTIONAL): Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: does NOT update GitHub check status.
- --gpu-type "A30, H100_PCIe" (OPTIONAL): Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: does NOT update GitHub check status.
- --test-backend "pytorch, cpp" (OPTIONAL): Skip test stages which don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with the tensorrt or triton backend). Note: does NOT update GitHub pipeline status.
- --only-multi-gpu-test (OPTIONAL): Only run the multi-GPU tests. Note: does NOT update GitHub check status.
- --disable-multi-gpu-test (OPTIONAL): Disable the multi-GPU tests. Note: does NOT update GitHub check status.
- --add-multi-gpu-test (OPTIONAL): Force run the multi-GPU tests in addition to running the L0 pre-merge pipeline.
- --post-merge (OPTIONAL): Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
- --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL): Run the ordinary L0 pre-merge pipeline plus the specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
- --detailed-log (OPTIONAL): Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
- --debug (OPTIONAL): Experimental feature. Enable access to the CI container for debugging purposes. Note: specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md and the scripts/test_to_stage_mapping.py helper.

kill

kill: Kill all running builds associated with the pull request.

skip

skip --comment COMMENT: Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline: Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.