
Commit 81ff881

shivghai authored and brb-nv committed
update
Signed-off-by: Shiv Ghai <[email protected]>
1 parent dcfb80a, commit 81ff881

File tree

2 files changed, +15 -12 lines changed

tensorrt_llm/layers/attention.py

Lines changed: 10 additions & 10 deletions
@@ -702,22 +702,17 @@ def create_attention_const_params(model_cls, config):
                     is_buffer=True))
         else:
 
-            def register_rope_params(rotary_base,
-                                     names_to_register,
-                                     is_local=False):
+            def register_rope_params(rotary_base, scale, scale_type, scaling,
+                                     names_to_register):
                 # Rotary const weights.
                 embed_positions = RopeEmbeddingUtils.create_sinusoidal_positions(
                     max_position_embeddings,
                     rotary_embedding_dim,
                 )
-                # For local attention, use no scaling (consistent with forward pass)
-                local_scale = 1.0 if is_local else rotary_embedding_scale
-                local_scale_type = RotaryScalingType.none if is_local else rotary_embedding_scale_type
-                local_scaling = None if is_local else rotary_embedding_scaling
 
                 rotary_inv_freq, embed_positions_for_gpt_attention = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
                     max_position_embeddings, rotary_embedding_dim, rotary_base,
-                    local_scale, local_scale_type, local_scaling)
+                    scale, scale_type, scaling)
                 model_cls.register_parameter(
                     names_to_register[0],
                     Parameter(embed_positions, dtype='float32', is_buffer=True))
@@ -731,6 +726,9 @@ def register_rope_params(rotary_base,
                         is_buffer=True))
 
             register_rope_params(rotary_base=rotary_embedding_base,
+                                 scale=rotary_embedding_scale,
+                                 scale_type=rotary_embedding_scale_type,
+                                 scaling=rotary_embedding_scaling,
                                  names_to_register=[
                                      'embed_positions', 'rotary_inv_freq',
                                      'embed_positions_for_gpt_attention'
@@ -742,11 +740,13 @@ def register_rope_params(rotary_base,
             if rotary_embedding_base_local is not None:
                 register_rope_params(
                     rotary_base=rotary_embedding_base_local,
+                    scale=1.0,
+                    scale_type=RotaryScalingType.none,
+                    scaling=None,
                     names_to_register=[
                         'embed_positions_local', 'rotary_inv_freq_local',
                         'embed_positions_for_gpt_attention_local'
-                    ],
-                    is_local=True)
+                    ])
 
     @staticmethod
     def fill_attention_params(model_cls, attention_params):

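For context on the change above: after this refactor, register_rope_params no longer special-cases is_local; each call site states its own scaling explicitly, and the local-attention call pins the RoPE tables to "no scaling". The snippet below is a minimal, hypothetical sketch (toy_rope_inv_freq and its simplified math are assumptions for illustration, not the TensorRT-LLM implementation) of how a base frequency and a linear scale factor combine into inverse frequencies, which is why the two call sites pass different arguments.

# Minimal sketch, assuming a simplified RoPE inverse-frequency formula;
# the helper name is hypothetical and not part of TensorRT-LLM.
import numpy as np


def toy_rope_inv_freq(base: float, dim: int, scale: float = 1.0) -> np.ndarray:
    """Simplified RoPE inverse frequencies with optional linear scaling."""
    inv_freq = 1.0 / (base ** (np.arange(0, dim, 2, dtype=np.float64) / dim))
    # Linear scaling divides the frequencies, stretching the usable context.
    return inv_freq / scale


# Global attention path: the model's configured base and scaling
# (mirrors passing rotary_embedding_scale / _scale_type / _scaling).
global_inv_freq = toy_rope_inv_freq(base=1_000_000.0, dim=128, scale=8.0)

# Local (sliding-window) attention path: its own base frequency and no
# scaling, mirroring scale=1.0, scale_type=RotaryScalingType.none, scaling=None.
local_inv_freq = toy_rope_inv_freq(base=10_000.0, dim=128, scale=1.0)

print(global_inv_freq[:3])
print(local_inv_freq[:3])

The design point is that the decision about scaling now lives at the call sites rather than behind an is_local flag inside the helper.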
tests/unittest/others/test_layer.py

Lines changed: 5 additions & 2 deletions
@@ -2135,11 +2135,14 @@ class MockGemma3Config:
             head_size = 128
             max_position_embeddings = 32768
             position_embedding_type = PositionEmbeddingType.rope_gpt_neox
-            rotary_base = 1000000.0
+            # Use small rotary base values to avoid numerical instability in tests.
+            # Large bases (e.g. 1000000) get exponentiated, causing potential flakiness
+            # when comparing floating point results.
+            rotary_base = 100.0
             rotary_scaling = {"factor": 8.0, "rope_type": "linear"}
             rotary_pct = 1.0
             # Local attention uses a different base frequency
-            rope_local_base_freq = 10000.0
+            rope_local_base_freq = 10.0
 
             # Create a mock model class to receive registered parameters
             class MockModelCls:

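A rough illustration of the float32 sensitivity the new test comment describes: because the base is exponentiated, a large base amplifies rounding in the exponent, so float32 results tend to drift further from a float64 reference than with a small base. This is a hypothetical check, not part of the test suite, and inv_freq_error is an assumed helper name.

# Hypothetical check, not from tests/unittest/others/test_layer.py: compare
# float32 vs float64 inverse-frequency tables for a large and a small base.
import numpy as np


def inv_freq_error(base: float, dim: int = 128) -> float:
    """Max relative gap between float32 and float64 RoPE inverse frequencies."""
    exponents = np.arange(0, dim, 2, dtype=np.float64) / dim
    ref = 1.0 / (np.float64(base) ** exponents)
    f32 = 1.0 / (np.float32(base) ** exponents.astype(np.float32))
    return float(np.max(np.abs(ref - f32.astype(np.float64)) / ref))


# Exact numbers depend on the platform, but the large base typically shows a
# noticeably larger relative drift; the test sidesteps this by using
# rotary_base = 100.0 and rope_local_base_freq = 10.0.
print("base=1e6:", inv_freq_error(1_000_000.0))
print("base=100:", inv_freq_error(100.0))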