Commit 2ae4f49

update
1 parent 09fcb61 commit 2ae4f49

1 file changed: +108 -0 lines changed

tests/unittest/others/test_layer.py

Lines changed: 108 additions & 0 deletions
@@ -2116,5 +2116,113 @@ def fuse_rg_lru(recurrent_layer):
                                        rtol=rtol)
 
 
+    def test_gemma3_local_attention_rope_scaling(self):
+        """
+        Test that local attention layers in Gemma3 do NOT apply rope scaling,
+        even when the config has rope_scaling defined.
+
+        This is important for Gemma3 which uses different RoPE parameters for
+        local (sliding window) attention vs global attention layers. The fix
+        ensures that local attention layers get scale=1.0 and scale_type=none,
+        while global layers get the configured scaling.
+        """
+        from tensorrt_llm.functional import (PositionEmbeddingType,
+                                             RotaryScalingType)
+        from tensorrt_llm.layers.attention import Attention
+
+        # Create a mock config similar to Gemma3 27B with rope_scaling
+        class MockGemma3Config:
+            hidden_size = 5376
+            num_attention_heads = 32
+            head_size = 128
+            max_position_embeddings = 32768
+            position_embedding_type = PositionEmbeddingType.rope_gpt_neox
+            rotary_base = 1000000.0
+            rotary_scaling = {
+                "factor": 8.0,
+                "rope_type": "linear"
+            }
+            rotary_pct = 1.0
+            # Local attention uses a different base frequency
+            rope_local_base_freq = 10000.0
+
+        # Create a mock model class to receive registered parameters
+        class MockModelCls:
+            position_embedding_type = PositionEmbeddingType.rope_gpt_neox
+
+            @classmethod
+            def register_parameter(cls, name, param):
+                setattr(cls, name, param)
+
+        config = MockGemma3Config()
+
+        # Call the method that creates attention const params
+        Attention.create_attention_const_params(MockModelCls, config)
+
+        # Verify that global rope parameters are registered
+        self.assertTrue(hasattr(MockModelCls, 'embed_positions'),
+                        "Global embed_positions should be registered")
+        self.assertTrue(hasattr(MockModelCls, 'rotary_inv_freq'),
+                        "Global rotary_inv_freq should be registered")
+        self.assertTrue(
+            hasattr(MockModelCls, 'embed_positions_for_gpt_attention'),
+            "Global embed_positions_for_gpt_attention should be registered")
+
+        # Verify that local rope parameters are registered (since rope_local_base_freq is set)
+        self.assertTrue(hasattr(MockModelCls, 'embed_positions_local'),
+                        "Local embed_positions should be registered")
+        self.assertTrue(hasattr(MockModelCls, 'rotary_inv_freq_local'),
+                        "Local rotary_inv_freq should be registered")
+        self.assertTrue(
+            hasattr(MockModelCls, 'embed_positions_for_gpt_attention_local'),
+            "Local embed_positions_for_gpt_attention should be registered")
+
+        # Get the parameter values
+        global_inv_freq = MockModelCls.rotary_inv_freq.raw_value
+        local_inv_freq = MockModelCls.rotary_inv_freq_local.raw_value
+        global_cos_sin = MockModelCls.embed_positions_for_gpt_attention.raw_value
+        local_cos_sin = MockModelCls.embed_positions_for_gpt_attention_local.raw_value
+
+        # The global and local inv_freq should be different because:
+        # 1. Global uses rope_scaling with factor=8.0 (linear scaling applies 1/8 to inv_freq)
+        # 2. Local uses scale=1.0 (no scaling)
+        # Also they use different base frequencies (1000000 vs 10000)
+        self.assertFalse(
+            np.allclose(global_inv_freq, local_inv_freq),
+            "Global and local rotary_inv_freq should be different "
+            "(global has scaling, local does not)")
+
+        # The cos/sin embeddings should also be different
+        self.assertFalse(
+            np.allclose(global_cos_sin, local_cos_sin),
+            "Global and local embed_positions_for_gpt_attention should be different "
+            "(global has scaling, local does not)")
+
+        # Additional verification: check that local inv_freq matches the unscaled calculation.
+        # For local attention with scale=1.0 and base=10000:
+        #     inv_freq = 1.0 / (10000 ** (arange(0, dim, 2) / dim))
+        dim = config.head_size  # rotary_embedding_dim = head_size * rotary_pct = 128
+        expected_local_inv_freq = 1.0 / (config.rope_local_base_freq**(
+            np.arange(0, dim, 2) / dim))
+
+        np.testing.assert_allclose(
+            local_inv_freq,
+            expected_local_inv_freq,
+            rtol=1e-5,
+            err_msg="Local rotary_inv_freq should be computed WITHOUT scaling")
+
+        # For global attention with linear scaling (factor=8.0):
+        #     scale = 1.0 / 8.0 = 0.125
+        #     inv_freq = 0.125 / (1000000 ** (arange(0, dim, 2) / dim))
+        expected_global_inv_freq = (1.0 / 8.0) / (config.rotary_base**(
+            np.arange(0, dim, 2) / dim))
+
+        np.testing.assert_allclose(
+            global_inv_freq,
+            expected_global_inv_freq,
+            rtol=1e-5,
+            err_msg="Global rotary_inv_freq should be computed WITH linear scaling")
+
+
 if __name__ == '__main__':
     unittest.main()
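
Note: the relationship the new assertions rely on (linear RoPE scaling divides the inverse frequencies by the scaling factor, while a scale of 1.0 leaves them untouched) can be sanity-checked in isolation. The snippet below is a minimal standalone sketch using plain NumPy; the helper name rope_inv_freq is illustrative only and is not part of the TensorRT-LLM API.

import numpy as np


def rope_inv_freq(base, dim, scale=1.0):
    # Illustrative RoPE inverse-frequency computation (hypothetical helper,
    # not a TensorRT-LLM function). scale=1.0 matches the unscaled local
    # attention case; scale=1/factor matches linear rope_scaling.
    return scale / (base**(np.arange(0, dim, 2, dtype=np.float64) / dim))


dim = 128  # head_size * rotary_pct in the test config above
local = rope_inv_freq(10000.0, dim)  # local layers: base=1e4, no scaling
global_unscaled = rope_inv_freq(1000000.0, dim)  # global base, no scaling
global_scaled = rope_inv_freq(1000000.0, dim, scale=1.0 / 8.0)  # factor=8.0

# Linear scaling is exactly a division of inv_freq by the factor.
np.testing.assert_allclose(global_scaled, global_unscaled / 8.0, rtol=1e-12)

# Local and global tables differ in both base frequency and scaling, which is
# what the unit test asserts via np.allclose(...) being False.
assert not np.allclose(local, global_scaled)

Because the file keeps its if __name__ == '__main__' guard, the new test can also be run by executing tests/unittest/others/test_layer.py directly with Python.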
