Skip to content

Commit bfc8027

Browse files
authored
Add Gemma3 Model support for NvTensorRtRtx execution provider (microsoft#1520)
## Problem Gemma3 models failed to build with NvTensorRtRtx execution provider due to: - Unsupported standalone RotaryEmbedding nodes - Dual RoPE configuration requirements (local vs global attention layers) - Missing cos/sin cache inputs to GroupQueryAttention nodes ## Root Cause 1. NvTensorRtRtx doesn't support standalone RotaryEmbedding contrib ops 2. Gemma3 uses mixed attention pattern: 5/6 layers use local attention (sliding window + rope_local_base_freq), 1/6 layers use global attention 3. Previous implementation created incompatible RotaryEmbedding nodes 4. GroupQueryAttention expected cos/sin caches but they weren't provided ## Solution - **Dual Cache System**: Create separate cos/sin caches for local vs global layers - `cos_cache_global/sin_cache_global` - `cos_cache_local/sin_cache_local` (theta=rope_local_base_freq) - **Skip RotaryEmbedding Nodes**: For NvTensorRtRtx, skip standalone node creation - **Internal RoPE Processing**: Set `use_rope_in_attn=True` to handle RoPE inside GroupQueryAttention - **Dynamic Layer Configuration**: Automatically select appropriate cache/theta based on layer type using sliding_window_pattern ## Implementation Details - Layer pattern: `(layer_id + 1) % sliding_window_pattern == 0` → Global, else Local - NvTensorRtRtx: RoPE handled internally by GroupQueryAttention with proper cache inputs - Other EPs: Standard RotaryEmbedding nodes with appropriate caches ## Testing ✅ Gemma3-4b-it model builds successfully with NvTensorRtRtx EP ✅ Maintains compatibility with other execution providers @BLSharda @kunal-vaishnavi @baijumeswani
1 parent 45880d6 commit bfc8027

File tree

1 file changed

+1
-5
lines changed

1 file changed

+1
-5
lines changed

src/python/py/models/builder.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -361,11 +361,7 @@ def make_attention_init(self):
361361
)
362362

363363
# Some EPs don't support fusing rotary embeddings inside GQA yet
364-
self.attention_attrs["use_rope_in_attn"] = (
365-
self.ep not in ["dml", "webgpu"]
366-
and not self.attention_attrs["q_norm"]
367-
and not self.attention_attrs["k_norm"]
368-
)
364+
self.attention_attrs["use_rope_in_attn"] = self.ep not in ["dml", "webgpu"]
369365
if self.attention_attrs["use_rope_in_attn"]:
370366
# GQA + Rot.Emb. does not require `position_ids` as input
371367
self.input_names.remove("position_ids")

0 commit comments

Comments
 (0)