@@ -2115,6 +2115,112 @@ def fuse_rg_lru(recurrent_layer):
             atol=atol,
             rtol=rtol)
 
+    def test_gemma3_local_attention_rope_scaling(self):
+        """
+        Test that local attention layers in Gemma3 do NOT apply rope scaling,
+        even when the config has rope_scaling defined.
+
+        This is important for Gemma3, which uses different RoPE parameters for
+        local (sliding-window) attention vs. global attention layers. The fix
+        ensures that local attention layers get scale=1.0 and scale_type=none,
+        while global layers get the configured scaling.
+        """
+        from tensorrt_llm.functional import PositionEmbeddingType
+        from tensorrt_llm.layers.attention import Attention
+
+        # Create a mock config similar to Gemma3 27B with rope_scaling
+        class MockGemma3Config:
+            hidden_size = 5376
+            num_attention_heads = 32
+            head_size = 128
+            max_position_embeddings = 32768
+            position_embedding_type = PositionEmbeddingType.rope_gpt_neox
+            # Use small rotary base values to avoid numerical instability in tests.
+            # Large bases (e.g. 1000000) are raised to fractional powers, which can
+            # make floating-point comparisons flaky.
+            rotary_base = 100.0
+            rotary_scaling = {"factor": 8.0, "rope_type": "linear"}
+            rotary_pct = 1.0
+            # Local attention uses a different base frequency
+            rope_local_base_freq = 10.0
+
+        # Create a mock model class to receive registered parameters
+        class MockModelCls:
+            position_embedding_type = PositionEmbeddingType.rope_gpt_neox
+
+            @classmethod
+            def register_parameter(cls, name, param):
+                setattr(cls, name, param)
+
+        config = MockGemma3Config()
+
+        # Call the method that creates attention const params
+        Attention.create_attention_const_params(MockModelCls, config)
+
+        # Verify that global rope parameters are registered
+        self.assertTrue(hasattr(MockModelCls, 'embed_positions'),
+                        "Global embed_positions should be registered")
+        self.assertTrue(hasattr(MockModelCls, 'rotary_inv_freq'),
+                        "Global rotary_inv_freq should be registered")
+        self.assertTrue(
+            hasattr(MockModelCls, 'embed_positions_for_gpt_attention'),
+            "Global embed_positions_for_gpt_attention should be registered")
+
+        # Verify that local rope parameters are registered (since rope_local_base_freq is set)
+        self.assertTrue(hasattr(MockModelCls, 'embed_positions_local'),
+                        "Local embed_positions should be registered")
+        self.assertTrue(hasattr(MockModelCls, 'rotary_inv_freq_local'),
+                        "Local rotary_inv_freq should be registered")
+        self.assertTrue(
+            hasattr(MockModelCls, 'embed_positions_for_gpt_attention_local'),
+            "Local embed_positions_for_gpt_attention should be registered")
+
+        # Get the parameter values
+        global_inv_freq = MockModelCls.rotary_inv_freq.raw_value
+        local_inv_freq = MockModelCls.rotary_inv_freq_local.raw_value
+        global_cos_sin = MockModelCls.embed_positions_for_gpt_attention.raw_value
+        local_cos_sin = MockModelCls.embed_positions_for_gpt_attention_local.raw_value
+
+        # The global and local inv_freq should be different because:
+        # 1. Global uses rope_scaling with factor=8.0 (linear scaling applies 1/8 to inv_freq)
+        # 2. Local uses scale=1.0 (no scaling)
+        self.assertFalse(
+            np.allclose(global_inv_freq, local_inv_freq),
+            "Global and local rotary_inv_freq should be different "
+            "(global has scaling, local does not)")
+
+        # The cos/sin embeddings should also be different
+        self.assertFalse(
+            np.allclose(global_cos_sin, local_cos_sin),
+            "Global and local embed_positions_for_gpt_attention should be different "
+            "(global has scaling, local does not)")
+
+        # Additional verification: check that local inv_freq matches the unscaled calculation.
+        # For local attention with scale=1.0 and base=10:
+        #   inv_freq = 1.0 / (10 ** (arange(0, dim, 2) / dim))
+        dim = config.head_size  # rotary_embedding_dim = head_size * rotary_pct = 128
+        expected_local_inv_freq = 1.0 / (config.rope_local_base_freq
+                                         **(np.arange(0, dim, 2) / dim))
+
+        np.testing.assert_allclose(
+            local_inv_freq,
+            expected_local_inv_freq,
+            rtol=1e-5,
+            err_msg="Local rotary_inv_freq should be computed WITHOUT scaling")
+
+        # For global attention with linear scaling (factor=8.0):
+        #   scale = 1.0 / 8.0 = 0.125
+        #   inv_freq = 0.125 / (100 ** (arange(0, dim, 2) / dim))
+        expected_global_inv_freq = (1.0 / 8.0) / (config.rotary_base**
+                                                  (np.arange(0, dim, 2) / dim))
+
+        np.testing.assert_allclose(
+            global_inv_freq,
+            expected_global_inv_freq,
+            rtol=1e-5,
+            err_msg=
+            "Global rotary_inv_freq should be computed WITH linear scaling")
+
 
 if __name__ == '__main__':
     unittest.main()
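
The arithmetic asserted above can be reproduced outside TensorRT-LLM: with "linear" RoPE scaling, stretching positions by the factor is equivalent to dividing the inverse frequencies by it, while unscaled (local) layers keep scale=1.0. The standalone NumPy sketch below mirrors the two expected arrays computed in the test; the helper name rope_inv_freq is made up for illustration and is not part of the TensorRT-LLM API.

    import numpy as np


    def rope_inv_freq(dim, base, linear_factor=1.0):
        # Standard RoPE inverse frequencies: 1 / base^(2i/dim) for i = 0, 2, ..., dim-2.
        inv_freq = 1.0 / (base**(np.arange(0, dim, 2) / dim))
        # "linear" scaling stretches positions by the factor, i.e. divides inv_freq by it.
        return inv_freq / linear_factor


    dim = 128  # head_size * rotary_pct in the mock config

    # Local (sliding-window) layers: base=10, no scaling applied.
    local_inv_freq = rope_inv_freq(dim, base=10.0)

    # Global layers: base=100 with linear scaling factor 8.0.
    global_inv_freq = rope_inv_freq(dim, base=100.0, linear_factor=8.0)

    # The two sets of frequencies must differ, exactly as the test asserts.
    assert not np.allclose(local_inv_freq, global_inv_freq)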