
Commit dcfb80a

shivghai authored and brb-nv committed
lint/fmt using pre-commit
Signed-off-by: Shiv Ghai <[email protected]>
1 parent c8b18cf commit dcfb80a

File tree

2 files changed: +16 −16 lines changed


tensorrt_llm/layers/attention.py

Lines changed: 8 additions & 4 deletions
@@ -702,7 +702,9 @@ def create_attention_const_params(model_cls, config):
                         is_buffer=True))
         else:
 
-            def register_rope_params(rotary_base, names_to_register, is_local=False):
+            def register_rope_params(rotary_base,
+                                     names_to_register,
+                                     is_local=False):
                 # Rotary const weights.
                 embed_positions = RopeEmbeddingUtils.create_sinusoidal_positions(
                     max_position_embeddings,
@@ -1146,10 +1148,12 @@ def compute_cross_kv(encoder_output):
                 rotary_embedding_dim=self.rotary_embedding_dim,
                 rotary_embedding_base=self.rotary_embedding_base
                 if not self.is_local else self.rotary_embedding_base_local,
-                rotary_embedding_scale_type=self.rotary_embedding_scale_type if not self.is_local else RotaryScalingType.none,
+                rotary_embedding_scale_type=self.rotary_embedding_scale_type
+                if not self.is_local else RotaryScalingType.none,
                 rotary_embedding_short_m_scale=attention_params.short_mscale,
                 rotary_embedding_long_m_scale=attention_params.long_mscale,
-                rotary_embedding_scale=self.rotary_embedding_scale if not self.is_local else 1.0,
+                rotary_embedding_scale=self.rotary_embedding_scale
+                if not self.is_local else 1.0,
                 rotary_embedding_max_positions=self.max_position_embeddings,
                 rotary_embedding_original_max_positions=self.
                 original_max_position_embeddings,
@@ -2797,4 +2801,4 @@ def forward(self,
                              attention_mask=attention_mask,
                              max_input_length=max_input_length,
                              *args,
-                              **kwargs)
+                             **kwargs)
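
The reformatted conditionals above carry the substantive logic: local attention layers bypass rope scaling (scale 1.0, scaling type none) and use a separate base frequency, while global layers keep the configured scaling. A minimal, self-contained sketch of that selection, not the library code; select_rope_params is a hypothetical helper and the defaults echo the mock test config below:

def select_rope_params(is_local,
                       rotary_embedding_base=1000000.0,
                       rotary_embedding_base_local=10000.0,
                       rotary_embedding_scale=0.125,
                       rotary_embedding_scale_type="linear"):
    # Local layers fall back to the local base frequency and an unscaled rope;
    # global layers keep the configured base, scale, and scale type.
    base = rotary_embedding_base if not is_local else rotary_embedding_base_local
    scale = rotary_embedding_scale if not is_local else 1.0
    scale_type = rotary_embedding_scale_type if not is_local else "none"
    return base, scale, scale_type

print(select_rope_params(is_local=True))   # (10000.0, 1.0, 'none')
print(select_rope_params(is_local=False))  # (1000000.0, 0.125, 'linear')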

tests/unittest/others/test_layer.py

Lines changed: 8 additions & 12 deletions
@@ -2115,7 +2115,6 @@ def fuse_rg_lru(recurrent_layer):
                                        atol=atol,
                                        rtol=rtol)
 
-
     def test_gemma3_local_attention_rope_scaling(self):
         """
         Test that local attention layers in Gemma3 do NOT apply rope scaling,
@@ -2126,8 +2125,7 @@ def test_gemma3_local_attention_rope_scaling(self):
         ensures that local attention layers get scale=1.0 and scale_type=none,
         while global layers get the configured scaling.
         """
-        from tensorrt_llm.functional import (PositionEmbeddingType,
-                                             RotaryScalingType)
+        from tensorrt_llm.functional import PositionEmbeddingType
         from tensorrt_llm.layers.attention import Attention
 
         # Create a mock config similar to Gemma3 27B with rope_scaling
@@ -2138,10 +2136,7 @@ class MockGemma3Config:
             max_position_embeddings = 32768
             position_embedding_type = PositionEmbeddingType.rope_gpt_neox
             rotary_base = 1000000.0
-            rotary_scaling = {
-                "factor": 8.0,
-                "rope_type": "linear"
-            }
+            rotary_scaling = {"factor": 8.0, "rope_type": "linear"}
             rotary_pct = 1.0
             # Local attention uses a different base frequency
             rope_local_base_freq = 10000.0
@@ -2202,8 +2197,8 @@ def register_parameter(cls, name, param):
         # For local attention with scale=1.0 and base=10000:
         # inv_freq = 1.0 / (10000 ** (arange(0, dim, 2) / dim))
         dim = config.head_size  # rotary_embedding_dim = head_size * rotary_pct = 128
-        expected_local_inv_freq = 1.0 / (config.rope_local_base_freq**(
-            np.arange(0, dim, 2) / dim))
+        expected_local_inv_freq = 1.0 / (config.rope_local_base_freq
+                                         **(np.arange(0, dim, 2) / dim))
 
         np.testing.assert_allclose(
             local_inv_freq,
@@ -2214,14 +2209,15 @@ def register_parameter(cls, name, param):
         # For global attention with linear scaling (factor=8.0):
         # scale = 1.0 / 8.0 = 0.125
         # inv_freq = 0.125 / (1000000 ** (arange(0, dim, 2) / dim))
-        expected_global_inv_freq = (1.0 / 8.0) / (config.rotary_base**(
-            np.arange(0, dim, 2) / dim))
+        expected_global_inv_freq = (1.0 / 8.0) / (config.rotary_base**
+                                                  (np.arange(0, dim, 2) / dim))
 
         np.testing.assert_allclose(
             global_inv_freq,
             expected_global_inv_freq,
             rtol=1e-5,
-            err_msg="Global rotary_inv_freq should be computed WITH linear scaling")
+            err_msg=
+            "Global rotary_inv_freq should be computed WITH linear scaling")
 
 
 if __name__ == '__main__':
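
For reference, the expectations this test asserts can be reproduced standalone. A sketch assuming the MockGemma3Config constants above (dim = head_size * rotary_pct = 128); it is not the test code itself:

import numpy as np

dim = 128                       # head_size * rotary_pct from the mock config
rope_local_base_freq = 10000.0  # base used by local attention layers
rotary_base = 1000000.0         # base used by global attention layers
factor = 8.0                    # linear rope scaling factor from rotary_scaling

# Local attention: no scaling, so scale stays 1.0.
local_inv_freq = 1.0 / (rope_local_base_freq**(np.arange(0, dim, 2) / dim))

# Global attention: linear scaling divides every frequency by the factor.
global_inv_freq = (1.0 / factor) / (rotary_base**(np.arange(0, dim, 2) / dim))

print(local_inv_freq[0], global_inv_freq[0])  # 1.0 vs. 0.125 at the first index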
