
Commit 89f0f2d

[Gemma3] Fix RoPE for local attention for Gemma3
Signed-off-by: Shiv Ghai <[email protected]>
1 parent 246a877 commit 89f0f2d

File tree

1 file changed: +12 −7 lines changed


tensorrt_llm/layers/attention.py

Lines changed: 12 additions & 7 deletions
@@ -702,16 +702,20 @@ def create_attention_const_params(model_cls, config):
                     is_buffer=True))
         else:

-            def register_rope_params(rotary_base, names_to_register):
+            def register_rope_params(rotary_base, names_to_register, is_local=False):
                 # Rotary const weights.
                 embed_positions = RopeEmbeddingUtils.create_sinusoidal_positions(
                     max_position_embeddings,
                     rotary_embedding_dim,
                 )
+                # For local attention, use no scaling (consistent with forward pass)
+                local_scale = 1.0 if is_local else rotary_embedding_scale
+                local_scale_type = RotaryScalingType.none if is_local else rotary_embedding_scale_type
+                local_scaling = None if is_local else rotary_embedding_scaling
+
                 rotary_inv_freq, embed_positions_for_gpt_attention = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
                     max_position_embeddings, rotary_embedding_dim, rotary_base,
-                    rotary_embedding_scale, rotary_embedding_scale_type,
-                    rotary_embedding_scaling)
+                    local_scale, local_scale_type, local_scaling)
                 model_cls.register_parameter(
                     names_to_register[0],
                     Parameter(embed_positions, dtype='float32', is_buffer=True))
@@ -739,7 +743,8 @@ def register_rope_params(rotary_base, names_to_register):
                 names_to_register=[
                     'embed_positions_local', 'rotary_inv_freq_local',
                     'embed_positions_for_gpt_attention_local'
-                ])
+                ],
+                is_local=True)

     @staticmethod
     def fill_attention_params(model_cls, attention_params):
@@ -1141,10 +1146,10 @@ def compute_cross_kv(encoder_output):
                 rotary_embedding_dim=self.rotary_embedding_dim,
                 rotary_embedding_base=self.rotary_embedding_base
                 if not self.is_local else self.rotary_embedding_base_local,
-                rotary_embedding_scale_type=self.rotary_embedding_scale_type,
+                rotary_embedding_scale_type=self.rotary_embedding_scale_type if not self.is_local else RotaryScalingType.none,
                 rotary_embedding_short_m_scale=attention_params.short_mscale,
                 rotary_embedding_long_m_scale=attention_params.long_mscale,
-                rotary_embedding_scale=self.rotary_embedding_scale,
+                rotary_embedding_scale=self.rotary_embedding_scale if not self.is_local else 1.0,
                 rotary_embedding_max_positions=self.max_position_embeddings,
                 rotary_embedding_original_max_positions=self.
                 original_max_position_embeddings,
@@ -2792,4 +2797,4 @@ def forward(self,
             attention_mask=attention_mask,
             max_input_length=max_input_length,
             *args,
-            **kwargs)
+            **kwargs)
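
For context: Gemma 3 mixes global attention layers with local (sliding-window) attention layers, and the two are meant to use different RoPE settings. This commit makes both the registered rotary constants (register_rope_params) and the arguments passed in the forward pass use unscaled RoPE whenever self.is_local is set, instead of the globally configured rotary_embedding_scale / rotary_embedding_scale_type. The snippet below is a minimal NumPy sketch of that selection, not TensorRT-LLM code; rope_inv_freq is a hypothetical helper, and the bases (1e6 for global, 10k for local) and the linear scale factor 8 are illustrative assumptions rather than values taken from this diff.

import numpy as np

def rope_inv_freq(dim: int, base: float, scale: float = 1.0) -> np.ndarray:
    # Inverse frequencies for RoPE. Dividing by `scale` is equivalent to linear
    # position scaling (each position index is effectively divided by `scale`).
    inv_freq = 1.0 / (base ** (np.arange(0, dim, 2, dtype=np.float64) / dim))
    return inv_freq / scale

rotary_dim = 128  # illustrative rotary dimension

# Global (full-attention) layers: keep the configured positional scaling
# (base and factor here are assumptions for illustration).
global_inv_freq = rope_inv_freq(rotary_dim, base=1_000_000.0, scale=8.0)

# Local (sliding-window) layers: plain RoPE, mirroring `local_scale = 1.0`
# and `RotaryScalingType.none` in the diff above.
local_inv_freq = rope_inv_freq(rotary_dim, base=10_000.0, scale=1.0)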
