29 changes: 21 additions & 8 deletions tensorrt_llm/layers/attention.py
@@ -702,12 +702,16 @@ def create_attention_const_params(model_cls, config):
is_buffer=True))
else:

- def register_rope_params(rotary_base, names_to_register):
+ def register_rope_params(rotary_base, rotary_embedding_scale,
+                          rotary_embedding_scale_type,
+                          rotary_embedding_scaling,
+                          names_to_register):
# Rotary const weights.
embed_positions = RopeEmbeddingUtils.create_sinusoidal_positions(
max_position_embeddings,
rotary_embedding_dim,
)

rotary_inv_freq, embed_positions_for_gpt_attention = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
max_position_embeddings, rotary_embedding_dim, rotary_base,
rotary_embedding_scale, rotary_embedding_scale_type,
@@ -724,18 +728,25 @@ def register_rope_params(rotary_base, names_to_register):
dtype='float32',
is_buffer=True))

- register_rope_params(rotary_base=rotary_embedding_base,
-                      names_to_register=[
-                          'embed_positions', 'rotary_inv_freq',
-                          'embed_positions_for_gpt_attention'
-                      ])
+ register_rope_params(
+     rotary_base=rotary_embedding_base,
+     rotary_embedding_scale=rotary_embedding_scale,
+     rotary_embedding_scale_type=rotary_embedding_scale_type,
+     rotary_embedding_scaling=rotary_embedding_scaling,
+     names_to_register=[
+         'embed_positions', 'rotary_inv_freq',
+         'embed_positions_for_gpt_attention'
+     ])

# For models with non-homogeneous attention layers requiring a second set of rope params, e.g. Gemma3.
rotary_embedding_base_local = getattr(config,
'rope_local_base_freq', None)
if rotary_embedding_base_local is not None:
register_rope_params(
rotary_base=rotary_embedding_base_local,
+     rotary_embedding_scale=1.0,
+     rotary_embedding_scale_type=RotaryScalingType.none,
+     rotary_embedding_scaling=None,
names_to_register=[
'embed_positions_local', 'rotary_inv_freq_local',
'embed_positions_for_gpt_attention_local'
@@ -1141,10 +1152,12 @@ def compute_cross_kv(encoder_output):
rotary_embedding_dim=self.rotary_embedding_dim,
rotary_embedding_base=self.rotary_embedding_base
if not self.is_local else self.rotary_embedding_base_local,
- rotary_embedding_scale_type=self.rotary_embedding_scale_type,
+ rotary_embedding_scale_type=self.rotary_embedding_scale_type
+ if not self.is_local else RotaryScalingType.none,
rotary_embedding_short_m_scale=attention_params.short_mscale,
rotary_embedding_long_m_scale=attention_params.long_mscale,
- rotary_embedding_scale=self.rotary_embedding_scale,
+ rotary_embedding_scale=self.rotary_embedding_scale
+ if not self.is_local else 1.0,
rotary_embedding_max_positions=self.max_position_embeddings,
rotary_embedding_original_max_positions=self.
original_max_position_embeddings,
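For reference, the quantity this change controls is the rotary inverse-frequency table. Below is a minimal NumPy sketch of the formulas the new test asserts; the real tables come from RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin, and the helper name here is illustrative only.

import numpy as np

def rope_inv_freq(dim, base, linear_factor=None):
    # Base RoPE inverse frequencies: 1 / base**(2i / dim).
    inv_freq = 1.0 / (base ** (np.arange(0, dim, 2) / dim))
    # Linear rope scaling divides the frequencies by the factor;
    # local (sliding-window) layers skip this step after the fix.
    if linear_factor is not None:
        inv_freq = inv_freq / linear_factor
    return inv_freq

# Global layers: configured base with linear scaling, e.g. factor=8.0.
global_inv_freq = rope_inv_freq(dim=128, base=100.0, linear_factor=8.0)
# Local layers: rope_local_base_freq with scale fixed to 1.0 (no scaling).
local_inv_freq = rope_inv_freq(dim=128, base=10.0)

assert not np.allclose(global_inv_freq, local_inv_freq)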
106 changes: 106 additions & 0 deletions tests/unittest/others/test_layer.py
@@ -2115,6 +2115,112 @@ def fuse_rg_lru(recurrent_layer):
atol=atol,
rtol=rtol)

def test_gemma3_local_attention_rope_scaling(self):
"""
Test that local attention layers in Gemma3 do NOT apply rope scaling,
even when the config has rope_scaling defined.

This is important for Gemma3, which uses different RoPE parameters for
local (sliding window) attention vs global attention layers. The fix
ensures that local attention layers get scale=1.0 and scale_type=none,
while global layers get the configured scaling.
"""
from tensorrt_llm.functional import PositionEmbeddingType
from tensorrt_llm.layers.attention import Attention

# Create a mock config similar to Gemma3 27B with rope_scaling
class MockGemma3Config:
hidden_size = 5376
num_attention_heads = 32
head_size = 128
max_position_embeddings = 32768
position_embedding_type = PositionEmbeddingType.rope_gpt_neox
# Use small rotary base values to avoid numerical instability in tests.
# Large bases (e.g. 1000000) get exponentiated, causing potential flakiness
# when comparing floating point results.
rotary_base = 100.0
rotary_scaling = {"factor": 8.0, "rope_type": "linear"}
rotary_pct = 1.0
# Local attention uses a different base frequency
rope_local_base_freq = 10.0

# Create a mock model class to receive registered parameters
class MockModelCls:
position_embedding_type = PositionEmbeddingType.rope_gpt_neox

@classmethod
def register_parameter(cls, name, param):
setattr(cls, name, param)

config = MockGemma3Config()

# Call the method that creates attention const params
Attention.create_attention_const_params(MockModelCls, config)

# Verify that global rope parameters are registered
self.assertTrue(hasattr(MockModelCls, 'embed_positions'),
"Global embed_positions should be registered")
self.assertTrue(hasattr(MockModelCls, 'rotary_inv_freq'),
"Global rotary_inv_freq should be registered")
self.assertTrue(
hasattr(MockModelCls, 'embed_positions_for_gpt_attention'),
"Global embed_positions_for_gpt_attention should be registered")

# Verify that local rope parameters are registered (since rope_local_base_freq is set)
self.assertTrue(hasattr(MockModelCls, 'embed_positions_local'),
"Local embed_positions should be registered")
self.assertTrue(hasattr(MockModelCls, 'rotary_inv_freq_local'),
"Local rotary_inv_freq should be registered")
self.assertTrue(
hasattr(MockModelCls, 'embed_positions_for_gpt_attention_local'),
"Local embed_positions_for_gpt_attention should be registered")

# Get the parameter values
global_inv_freq = MockModelCls.rotary_inv_freq.raw_value
local_inv_freq = MockModelCls.rotary_inv_freq_local.raw_value
global_cos_sin = MockModelCls.embed_positions_for_gpt_attention.raw_value
local_cos_sin = MockModelCls.embed_positions_for_gpt_attention_local.raw_value

# The global and local inv_freq should be different because:
# 1. Global uses rope_scaling with factor=8.0 (linear scaling applies 1/8 to inv_freq)
# 2. Local uses scale=1.0 (no scaling)
self.assertFalse(
np.allclose(global_inv_freq, local_inv_freq),
"Global and local rotary_inv_freq should be different "
"(global has scaling, local does not)")

# The cos/sin embeddings should also be different
self.assertFalse(
np.allclose(global_cos_sin, local_cos_sin),
"Global and local embed_positions_for_gpt_attention should be different "
"(global has scaling, local does not)")

# Additional verification: Check that local inv_freq matches unscaled calculation
# For local attention with scale=1.0 and base=10:
# inv_freq = 1.0 / (10 ** (arange(0, dim, 2) / dim))
dim = config.head_size # rotary_embedding_dim = head_size * rotary_pct = 128
expected_local_inv_freq = 1.0 / (config.rope_local_base_freq
**(np.arange(0, dim, 2) / dim))

np.testing.assert_allclose(
local_inv_freq,
expected_local_inv_freq,
rtol=1e-5,
err_msg="Local rotary_inv_freq should be computed WITHOUT scaling")

# For global attention with linear scaling (factor=8.0):
# scale = 1.0 / 8.0 = 0.125
# inv_freq = 0.125 / (100 ** (arange(0, dim, 2) / dim))
expected_global_inv_freq = (1.0 / 8.0) / (config.rotary_base**
(np.arange(0, dim, 2) / dim))

np.testing.assert_allclose(
global_inv_freq,
expected_global_inv_freq,
rtol=1e-5,
err_msg=
"Global rotary_inv_freq should be computed WITH linear scaling")


if __name__ == '__main__':
unittest.main()
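As a side note on what "linear" rope scaling means in the test above: dividing the inverse frequencies by the factor is equivalent to evaluating the original frequencies at interpolated positions t / factor. A minimal sketch of that identity, assuming the same inv_freq formula the test asserts:

import numpy as np

dim, base, factor = 128, 100.0, 8.0
inv_freq = 1.0 / (base ** (np.arange(0, dim, 2) / dim))
positions = np.arange(0, 32, dtype=np.float64)

# Scaling the inverse frequencies by 1/factor ...
angles_scaled_freq = np.outer(positions, inv_freq / factor)
# ... yields the same rotation angles as evaluating the original
# frequencies at interpolated positions t / factor.
angles_scaled_pos = np.outer(positions / factor, inv_freq)

assert np.allclose(angles_scaled_freq, angles_scaled_pos)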