
Commit ee07a7c

[None][fix] [Gemma3] Fix RoPE for local attention for Gemma3 (#9961)
Signed-off-by: Shiv Ghai <[email protected]>
1 parent: 1865020

File tree: 2 files changed, +127 −8 lines

tensorrt_llm/layers/attention.py

Lines changed: 21 additions & 8 deletions
@@ -702,12 +702,16 @@ def create_attention_const_params(model_cls, config):
                     is_buffer=True))
         else:

-            def register_rope_params(rotary_base, names_to_register):
+            def register_rope_params(rotary_base, rotary_embedding_scale,
+                                     rotary_embedding_scale_type,
+                                     rotary_embedding_scaling,
+                                     names_to_register):
                 # Rotary const weights.
                 embed_positions = RopeEmbeddingUtils.create_sinusoidal_positions(
                     max_position_embeddings,
                     rotary_embedding_dim,
                 )
+
                 rotary_inv_freq, embed_positions_for_gpt_attention = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
                     max_position_embeddings, rotary_embedding_dim, rotary_base,
                     rotary_embedding_scale, rotary_embedding_scale_type,
@@ -724,18 +728,25 @@ def register_rope_params(rotary_base, names_to_register):
                         dtype='float32',
                         is_buffer=True))

-            register_rope_params(rotary_base=rotary_embedding_base,
-                                 names_to_register=[
-                                     'embed_positions', 'rotary_inv_freq',
-                                     'embed_positions_for_gpt_attention'
-                                 ])
+            register_rope_params(
+                rotary_base=rotary_embedding_base,
+                rotary_embedding_scale=rotary_embedding_scale,
+                rotary_embedding_scale_type=rotary_embedding_scale_type,
+                rotary_embedding_scaling=rotary_embedding_scaling,
+                names_to_register=[
+                    'embed_positions', 'rotary_inv_freq',
+                    'embed_positions_for_gpt_attention'
+                ])

             # For models with non-homegeneous attention layers requiring a second set of rope params. e.g. Gemma3.
             rotary_embedding_base_local = getattr(config,
                                                   'rope_local_base_freq', None)
             if rotary_embedding_base_local is not None:
                 register_rope_params(
                     rotary_base=rotary_embedding_base_local,
+                    rotary_embedding_scale=1.0,
+                    rotary_embedding_scale_type=RotaryScalingType.none,
+                    rotary_embedding_scaling=None,
                     names_to_register=[
                         'embed_positions_local', 'rotary_inv_freq_local',
                         'embed_positions_for_gpt_attention_local'
@@ -1141,10 +1152,12 @@ def compute_cross_kv(encoder_output):
                 rotary_embedding_dim=self.rotary_embedding_dim,
                 rotary_embedding_base=self.rotary_embedding_base
                 if not self.is_local else self.rotary_embedding_base_local,
-                rotary_embedding_scale_type=self.rotary_embedding_scale_type,
+                rotary_embedding_scale_type=self.rotary_embedding_scale_type
+                if not self.is_local else RotaryScalingType.none,
                 rotary_embedding_short_m_scale=attention_params.short_mscale,
                 rotary_embedding_long_m_scale=attention_params.long_mscale,
-                rotary_embedding_scale=self.rotary_embedding_scale,
+                rotary_embedding_scale=self.rotary_embedding_scale
+                if not self.is_local else 1.0,
                 rotary_embedding_max_positions=self.max_position_embeddings,
                 rotary_embedding_original_max_positions=self.
                 original_max_position_embeddings,
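
For readers skimming the diff: linear RoPE scaling divides the inverse frequencies by the scaling factor, and the removed lines above show that local (sliding-window) layers previously reused the global layers' scale and scale type. The standalone sketch below only illustrates that arithmetic; it is not the RopeEmbeddingUtils code, and the base/factor values are assumptions taken from Gemma3 27B's published config (global rope_theta 1e6, rope_local_base_freq 1e4, linear factor 8).

import numpy as np

def rope_inv_freq(base, dim, scale=1.0):
    # scale = 1/factor for linear scaling; scale = 1.0 means no scaling.
    return scale / (base ** (np.arange(0, dim, 2, dtype=np.float64) / dim))

dim = 128  # head_size * rotary_pct
global_inv_freq = rope_inv_freq(1_000_000.0, dim, scale=1.0 / 8.0)  # global layers: linear factor 8
local_inv_freq = rope_inv_freq(10_000.0, dim, scale=1.0)            # local layers: unscaled

# Before the fix, local layers effectively got rope_inv_freq(10_000.0, dim, 1.0 / 8.0),
# i.e. rotary phases stretched 8x relative to what the sliding-window layers expect.
assert not np.allclose(global_inv_freq, local_inv_freq)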

tests/unittest/others/test_layer.py

Lines changed: 106 additions & 0 deletions
@@ -2115,6 +2115,112 @@ def fuse_rg_lru(recurrent_layer):
                                 atol=atol,
                                 rtol=rtol)

+    def test_gemma3_local_attention_rope_scaling(self):
+        """
+        Test that local attention layers in Gemma3 do NOT apply rope scaling,
+        even when the config has rope_scaling defined.
+
+        This is important for Gemma3 which uses different RoPE parameters for
+        local (sliding window) attention vs global attention layers. The fix
+        ensures that local attention layers get scale=1.0 and scale_type=none,
+        while global layers get the configured scaling.
+        """
+        from tensorrt_llm.functional import PositionEmbeddingType
+        from tensorrt_llm.layers.attention import Attention
+
+        # Create a mock config similar to Gemma3 27B with rope_scaling
+        class MockGemma3Config:
+            hidden_size = 5376
+            num_attention_heads = 32
+            head_size = 128
+            max_position_embeddings = 32768
+            position_embedding_type = PositionEmbeddingType.rope_gpt_neox
+            # Use small rotary base values to avoid numerical instability in tests.
+            # Large bases (e.g. 1000000) get exponentiated, causing potential flakiness
+            # when comparing floating point results.
+            rotary_base = 100.0
+            rotary_scaling = {"factor": 8.0, "rope_type": "linear"}
+            rotary_pct = 1.0
+            # Local attention uses a different base frequency
+            rope_local_base_freq = 10.0
+
+        # Create a mock model class to receive registered parameters
+        class MockModelCls:
+            position_embedding_type = PositionEmbeddingType.rope_gpt_neox
+
+            @classmethod
+            def register_parameter(cls, name, param):
+                setattr(cls, name, param)
+
+        config = MockGemma3Config()
+
+        # Call the method that creates attention const params
+        Attention.create_attention_const_params(MockModelCls, config)
+
+        # Verify that global rope parameters are registered
+        self.assertTrue(hasattr(MockModelCls, 'embed_positions'),
+                        "Global embed_positions should be registered")
+        self.assertTrue(hasattr(MockModelCls, 'rotary_inv_freq'),
+                        "Global rotary_inv_freq should be registered")
+        self.assertTrue(
+            hasattr(MockModelCls, 'embed_positions_for_gpt_attention'),
+            "Global embed_positions_for_gpt_attention should be registered")
+
+        # Verify that local rope parameters are registered (since rope_local_base_freq is set)
+        self.assertTrue(hasattr(MockModelCls, 'embed_positions_local'),
+                        "Local embed_positions should be registered")
+        self.assertTrue(hasattr(MockModelCls, 'rotary_inv_freq_local'),
+                        "Local rotary_inv_freq should be registered")
+        self.assertTrue(
+            hasattr(MockModelCls, 'embed_positions_for_gpt_attention_local'),
+            "Local embed_positions_for_gpt_attention should be registered")
+
+        # Get the parameter values
+        global_inv_freq = MockModelCls.rotary_inv_freq.raw_value
+        local_inv_freq = MockModelCls.rotary_inv_freq_local.raw_value
+        global_cos_sin = MockModelCls.embed_positions_for_gpt_attention.raw_value
+        local_cos_sin = MockModelCls.embed_positions_for_gpt_attention_local.raw_value
+
+        # The global and local inv_freq should be different because:
+        # 1. Global uses rope_scaling with factor=8.0 (linear scaling applies 1/8 to inv_freq)
+        # 2. Local uses scale=1.0 (no scaling)
+        self.assertFalse(
+            np.allclose(global_inv_freq, local_inv_freq),
+            "Global and local rotary_inv_freq should be different "
+            "(global has scaling, local does not)")
+
+        # The cos/sin embeddings should also be different
+        self.assertFalse(
+            np.allclose(global_cos_sin, local_cos_sin),
+            "Global and local embed_positions_for_gpt_attention should be different "
+            "(global has scaling, local does not)")
+
+        # Additional verification: Check that local inv_freq matches unscaled calculation
+        # For local attention with scale=1.0 and base=10:
+        # inv_freq = 1.0 / (10 ** (arange(0, dim, 2) / dim))
+        dim = config.head_size  # rotary_embedding_dim = head_size * rotary_pct = 128
+        expected_local_inv_freq = 1.0 / (config.rope_local_base_freq
+                                         **(np.arange(0, dim, 2) / dim))
+
+        np.testing.assert_allclose(
+            local_inv_freq,
+            expected_local_inv_freq,
+            rtol=1e-5,
+            err_msg="Local rotary_inv_freq should be computed WITHOUT scaling")
+
+        # For global attention with linear scaling (factor=8.0):
+        # scale = 1.0 / 8.0 = 0.125
+        # inv_freq = 0.125 / (100 ** (arange(0, dim, 2) / dim))
+        expected_global_inv_freq = (1.0 / 8.0) / (config.rotary_base**
+                                                  (np.arange(0, dim, 2) / dim))
+
+        np.testing.assert_allclose(
+            global_inv_freq,
+            expected_global_inv_freq,
+            rtol=1e-5,
+            err_msg=
+            "Global rotary_inv_freq should be computed WITH linear scaling")
+

 if __name__ == '__main__':
     unittest.main()
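
The expected frequencies asserted above can be reproduced outside the test harness. The short sketch below mirrors the mock config in the test (local base 10 unscaled, global base 100 with linear factor 8); it is only an illustration of the assertions, not part of the test suite or a real Gemma3 config.

import numpy as np

dim = 128  # head_size * rotary_pct in the mock config
exponents = np.arange(0, dim, 2) / dim

expected_local = 1.0 / (10.0 ** exponents)            # local: base 10, scale 1.0
expected_global = (1.0 / 8.0) / (100.0 ** exponents)  # global: base 100, linear factor 8 -> scale 0.125

print(expected_local[0], expected_global[0])  # 1.0 0.125 at index 0

Assuming the repository's usual pytest-based workflow, the new case can be run in isolation with:
pytest tests/unittest/others/test_layer.py -k test_gemma3_local_attention_rope_scaling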
