@@ -2116,5 +2116,113 @@ def fuse_rg_lru(recurrent_layer):
                                    rtol=rtol)


+    def test_gemma3_local_attention_rope_scaling(self):
+        """
+        Test that local attention layers in Gemma3 do NOT apply rope scaling,
+        even when the config has rope_scaling defined.
+
+        This is important for Gemma3, which uses different RoPE parameters for
+        local (sliding window) attention vs global attention layers. The fix
+        ensures that local attention layers get scale=1.0 and scale_type=none,
+        while global layers get the configured scaling.
+        """
+        from tensorrt_llm.functional import (PositionEmbeddingType,
+                                             RotaryScalingType)
+        from tensorrt_llm.layers.attention import Attention
+
+        # Create a mock config similar to Gemma3 27B with rope_scaling
+        class MockGemma3Config:
+            hidden_size = 5376
+            num_attention_heads = 32
+            head_size = 128
+            max_position_embeddings = 32768
+            position_embedding_type = PositionEmbeddingType.rope_gpt_neox
+            rotary_base = 1000000.0
+            rotary_scaling = {
+                "factor": 8.0,
+                "rope_type": "linear"
+            }
+            rotary_pct = 1.0
+            # Local attention uses a different base frequency
+            rope_local_base_freq = 10000.0
+
+        # Create a mock model class to receive registered parameters
+        class MockModelCls:
+            position_embedding_type = PositionEmbeddingType.rope_gpt_neox
+
+            @classmethod
+            def register_parameter(cls, name, param):
+                setattr(cls, name, param)
+
+        config = MockGemma3Config()
+
+        # Call the method that creates attention const params
+        Attention.create_attention_const_params(MockModelCls, config)
+
+        # Verify that global rope parameters are registered
+        self.assertTrue(hasattr(MockModelCls, 'embed_positions'),
+                        "Global embed_positions should be registered")
+        self.assertTrue(hasattr(MockModelCls, 'rotary_inv_freq'),
+                        "Global rotary_inv_freq should be registered")
+        self.assertTrue(
+            hasattr(MockModelCls, 'embed_positions_for_gpt_attention'),
+            "Global embed_positions_for_gpt_attention should be registered")
+
+        # Verify that local rope parameters are registered (since rope_local_base_freq is set)
+        self.assertTrue(hasattr(MockModelCls, 'embed_positions_local'),
+                        "Local embed_positions should be registered")
+        self.assertTrue(hasattr(MockModelCls, 'rotary_inv_freq_local'),
+                        "Local rotary_inv_freq should be registered")
+        self.assertTrue(
+            hasattr(MockModelCls, 'embed_positions_for_gpt_attention_local'),
+            "Local embed_positions_for_gpt_attention should be registered")
+
+        # Get the parameter values
+        global_inv_freq = MockModelCls.rotary_inv_freq.raw_value
+        local_inv_freq = MockModelCls.rotary_inv_freq_local.raw_value
+        global_cos_sin = MockModelCls.embed_positions_for_gpt_attention.raw_value
+        local_cos_sin = MockModelCls.embed_positions_for_gpt_attention_local.raw_value
+
+        # The global and local inv_freq should be different because:
+        # 1. Global uses rope_scaling with factor=8.0 (linear scaling applies 1/8 to inv_freq)
+        # 2. Local uses scale=1.0 (no scaling)
+        # Also they use different base frequencies (1000000 vs 10000)
+        self.assertFalse(
+            np.allclose(global_inv_freq, local_inv_freq),
+            "Global and local rotary_inv_freq should be different "
+            "(global has scaling, local does not)")
+
+        # The cos/sin embeddings should also be different
+        self.assertFalse(
+            np.allclose(global_cos_sin, local_cos_sin),
+            "Global and local embed_positions_for_gpt_attention should be different "
+            "(global has scaling, local does not)")
+
+        # Additional verification: check that local inv_freq matches the unscaled calculation.
+        # For local attention with scale=1.0 and base=10000:
+        #     inv_freq = 1.0 / (10000 ** (arange(0, dim, 2) / dim))
+        dim = config.head_size  # rotary_embedding_dim = head_size * rotary_pct = 128
+        expected_local_inv_freq = 1.0 / (config.rope_local_base_freq**(
+            np.arange(0, dim, 2) / dim))
+
+        np.testing.assert_allclose(
+            local_inv_freq,
+            expected_local_inv_freq,
+            rtol=1e-5,
+            err_msg="Local rotary_inv_freq should be computed WITHOUT scaling")
+
+        # For global attention with linear scaling (factor=8.0):
+        #     scale = 1.0 / 8.0 = 0.125
+        #     inv_freq = 0.125 / (1000000 ** (arange(0, dim, 2) / dim))
+        expected_global_inv_freq = (1.0 / 8.0) / (config.rotary_base**(
+            np.arange(0, dim, 2) / dim))
+
+        np.testing.assert_allclose(
+            global_inv_freq,
+            expected_global_inv_freq,
+            rtol=1e-5,
+            err_msg="Global rotary_inv_freq should be computed WITH linear scaling")
+
+
 if __name__ == '__main__':
     unittest.main()
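Side note (not part of the diff above): the inv_freq arithmetic that the test asserts can be reproduced with plain NumPy. The sketch below is a standalone illustration using the values assumed by the mock config (head_size=128, rotary_base=1e6, linear factor 8.0, rope_local_base_freq=1e4); it does not call into TensorRT-LLM, and the helper name rope_inv_freq is made up for the example.

import numpy as np

dim = 128  # rotary_embedding_dim = head_size * rotary_pct

def rope_inv_freq(base, scale=1.0):
    # inv_freq_i = scale / base**(2i / dim); linear (position interpolation)
    # scaling folds the 1/factor term directly into inv_freq.
    return scale / (base**(np.arange(0, dim, 2) / dim))

global_inv_freq = rope_inv_freq(1000000.0, scale=1.0 / 8.0)  # global: factor=8.0
local_inv_freq = rope_inv_freq(10000.0)                      # local: scale=1.0

# Different scale and different base frequency, so the two sets differ.
assert not np.allclose(global_inv_freq, local_inv_freq)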