@@ -2115,7 +2115,6 @@ def fuse_rg_lru(recurrent_layer):
             atol=atol,
             rtol=rtol)
 
-
     def test_gemma3_local_attention_rope_scaling(self):
         """
         Test that local attention layers in Gemma3 do NOT apply rope scaling,
@@ -2126,8 +2125,7 @@ def test_gemma3_local_attention_rope_scaling(self):
         ensures that local attention layers get scale=1.0 and scale_type=none,
         while global layers get the configured scaling.
         """
-        from tensorrt_llm.functional import (PositionEmbeddingType,
-                                             RotaryScalingType)
+        from tensorrt_llm.functional import PositionEmbeddingType
         from tensorrt_llm.layers.attention import Attention
 
         # Create a mock config similar to Gemma3 27B with rope_scaling
@@ -2138,10 +2136,7 @@ class MockGemma3Config:
             max_position_embeddings = 32768
             position_embedding_type = PositionEmbeddingType.rope_gpt_neox
             rotary_base = 1000000.0
-            rotary_scaling = {
-                "factor": 8.0,
-                "rope_type": "linear"
-            }
+            rotary_scaling = {"factor": 8.0, "rope_type": "linear"}
             rotary_pct = 1.0
             # Local attention uses a different base frequency
             rope_local_base_freq = 10000.0
@@ -2202,8 +2197,8 @@ def register_parameter(cls, name, param):
         # For local attention with scale=1.0 and base=10000:
         # inv_freq = 1.0 / (10000 ** (arange(0, dim, 2) / dim))
         dim = config.head_size  # rotary_embedding_dim = head_size * rotary_pct = 128
-        expected_local_inv_freq = 1.0 / (config.rope_local_base_freq ** (
-            np.arange(0, dim, 2) / dim))
+        expected_local_inv_freq = 1.0 / (config.rope_local_base_freq
+                                         ** (np.arange(0, dim, 2) / dim))
 
         np.testing.assert_allclose(
             local_inv_freq,
@@ -2214,14 +2209,15 @@ def register_parameter(cls, name, param):
         # For global attention with linear scaling (factor=8.0):
         # scale = 1.0 / 8.0 = 0.125
         # inv_freq = 0.125 / (1000000 ** (arange(0, dim, 2) / dim))
-        expected_global_inv_freq = (1.0 / 8.0) / (config.rotary_base ** (
-            np.arange(0, dim, 2) / dim))
+        expected_global_inv_freq = (1.0 / 8.0) / (config.rotary_base **
+                                                  (np.arange(0, dim, 2) / dim))
 
         np.testing.assert_allclose(
             global_inv_freq,
             expected_global_inv_freq,
             rtol=1e-5,
-            err_msg="Global rotary_inv_freq should be computed WITH linear scaling")
+            err_msg=
+            "Global rotary_inv_freq should be computed WITH linear scaling")
 
 
 if __name__ == '__main__':
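For reference, the two inverse-frequency arrays that the assertions above compare can be reproduced standalone. This is a minimal sketch using only NumPy and the constants from MockGemma3Config in the diff (dim = 128 follows the comment about head_size * rotary_pct); it is an illustration, not part of the committed test.

import numpy as np

# Values taken from MockGemma3Config in the diff above.
head_size = 128                 # rotary_embedding_dim = head_size * rotary_pct
rotary_pct = 1.0
rope_local_base_freq = 10000.0  # base used by local attention layers
rotary_base = 1000000.0         # base used by global attention layers
linear_factor = 8.0             # rope_scaling["factor"]

dim = int(head_size * rotary_pct)
exponents = np.arange(0, dim, 2) / dim

# Local attention: scale=1.0, so no scaling term is applied.
expected_local_inv_freq = 1.0 / (rope_local_base_freq**exponents)

# Global attention: linear scaling divides the frequencies by the factor.
expected_global_inv_freq = (1.0 / linear_factor) / (rotary_base**exponents)

print(expected_local_inv_freq[:3], expected_global_inv_freq[:3])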