
Commit 025b7c7

shivghaibrb-nv authored and committed
address comments
Signed-off-by: Shiv Ghai <[email protected]>
1 parent 81ff881 commit 025b7c7

File tree: 2 files changed (+20, -17 lines)


tensorrt_llm/layers/attention.py (17 additions, 13 deletions)

```diff
@@ -702,7 +702,9 @@ def create_attention_const_params(model_cls, config):
                     is_buffer=True))
         else:
 
-            def register_rope_params(rotary_base, scale, scale_type, scaling,
+            def register_rope_params(rotary_base, rotary_embedding_scale,
+                                     rotary_embedding_scale_type,
+                                     rotary_embedding_scaling,
                                      names_to_register):
                 # Rotary const weights.
                 embed_positions = RopeEmbeddingUtils.create_sinusoidal_positions(
@@ -712,7 +714,8 @@ def register_rope_params(rotary_base, scale, scale_type, scaling,
 
                 rotary_inv_freq, embed_positions_for_gpt_attention = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
                     max_position_embeddings, rotary_embedding_dim, rotary_base,
-                    scale, scale_type, scaling)
+                    rotary_embedding_scale, rotary_embedding_scale_type,
+                    rotary_embedding_scaling)
                 model_cls.register_parameter(
                     names_to_register[0],
                     Parameter(embed_positions, dtype='float32', is_buffer=True))
@@ -725,24 +728,25 @@ def register_rope_params(rotary_base, scale, scale_type, scaling,
                         dtype='float32',
                         is_buffer=True))
 
-            register_rope_params(rotary_base=rotary_embedding_base,
-                                 scale=rotary_embedding_scale,
-                                 scale_type=rotary_embedding_scale_type,
-                                 scaling=rotary_embedding_scaling,
-                                 names_to_register=[
-                                     'embed_positions', 'rotary_inv_freq',
-                                     'embed_positions_for_gpt_attention'
-                                 ])
+            register_rope_params(
+                rotary_base=rotary_embedding_base,
+                rotary_embedding_scale=rotary_embedding_scale,
+                rotary_embedding_scale_type=rotary_embedding_scale_type,
+                rotary_embedding_scaling=rotary_embedding_scaling,
+                names_to_register=[
+                    'embed_positions', 'rotary_inv_freq',
+                    'embed_positions_for_gpt_attention'
+                ])
 
             # For models with non-homegeneous attention layers requiring a second set of rope params. e.g. Gemma3.
             rotary_embedding_base_local = getattr(config,
                                                   'rope_local_base_freq', None)
             if rotary_embedding_base_local is not None:
                 register_rope_params(
                     rotary_base=rotary_embedding_base_local,
-                    scale=1.0,
-                    scale_type=RotaryScalingType.none,
-                    scaling=None,
+                    rotary_embedding_scale=1.0,
+                    rotary_embedding_scale_type=RotaryScalingType.none,
+                    rotary_embedding_scaling=None,
                     names_to_register=[
                         'embed_positions_local', 'rotary_inv_freq_local',
                         'embed_positions_for_gpt_attention_local'
```
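For orientation, here is a minimal numpy sketch of the inv_freq math behind the two register_rope_params calls above. This is not the RopeEmbeddingUtils implementation: toy_rope_inv_freq is a hypothetical stand-in, and the base values (1,000,000 global, 10,000 local) are Gemma3-style illustrations rather than values read from this diff. With a linear scaling type, rotary_embedding_scale (1/factor) multiplies into inv_freq; the new local-attention call passes scale 1.0 with RotaryScalingType.none, so the second parameter set differs only in its base frequency.

```python
# Hypothetical sketch, not the RopeEmbeddingUtils implementation.
import numpy as np


def toy_rope_inv_freq(dim: int, base: float, scale: float = 1.0) -> np.ndarray:
    # Standard RoPE inverse frequencies; linear scaling multiplies by
    # scale = 1 / factor (scale = 1.0 corresponds to RotaryScalingType.none).
    return scale / (base**(np.arange(0, dim, 2) / dim))


# Global set ('embed_positions', 'rotary_inv_freq', ...): linearly scaled.
global_inv_freq = toy_rope_inv_freq(dim=128, base=1_000_000, scale=1.0 / 8.0)

# Local set ('embed_positions_local', ...): rope_local_base_freq, unscaled.
local_inv_freq = toy_rope_inv_freq(dim=128, base=10_000)
```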

tests/unittest/others/test_layer.py (3 additions, 4 deletions)

```diff
@@ -2184,7 +2184,6 @@ def register_parameter(cls, name, param):
         # The global and local inv_freq should be different because:
         # 1. Global uses rope_scaling with factor=8.0 (linear scaling applies 1/8 to inv_freq)
         # 2. Local uses scale=1.0 (no scaling)
-        # Also they use different base frequencies (1000000 vs 10000)
         self.assertFalse(
             np.allclose(global_inv_freq, local_inv_freq),
             "Global and local rotary_inv_freq should be different "
@@ -2197,8 +2196,8 @@ def register_parameter(cls, name, param):
             "(global has scaling, local does not)")
 
         # Additional verification: Check that local inv_freq matches unscaled calculation
-        # For local attention with scale=1.0 and base=10000:
-        # inv_freq = 1.0 / (10000 ** (arange(0, dim, 2) / dim))
+        # For local attention with scale=1.0 and base=10:
+        # inv_freq = 1.0 / (10 ** (arange(0, dim, 2) / dim))
         dim = config.head_size  # rotary_embedding_dim = head_size * rotary_pct = 128
         expected_local_inv_freq = 1.0 / (config.rope_local_base_freq
                                          **(np.arange(0, dim, 2) / dim))
@@ -2211,7 +2210,7 @@ def register_parameter(cls, name, param):
 
         # For global attention with linear scaling (factor=8.0):
         # scale = 1.0 / 8.0 = 0.125
-        # inv_freq = 0.125 / (1000000 ** (arange(0, dim, 2) / dim))
+        # inv_freq = 0.125 / (100 ** (arange(0, dim, 2) / dim))
         expected_global_inv_freq = (1.0 / 8.0) / (config.rotary_base**
                                                   (np.arange(0, dim, 2) / dim))
 
```
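As a standalone check of what the revised comments describe, the sketch below recomputes both expected arrays with numpy. The toy bases (rope_local_base_freq=10, rotary_base=100) are inferred from the updated comment text, and dim=128 from the head_size comment; neither appears as an explicit config value in this diff.

```python
# Standalone recreation of the test's expected values; the bases are the toy
# values implied by the revised comments, not a real model config.
import numpy as np

dim = 128  # head_size * rotary_pct
expected_local_inv_freq = 1.0 / (10**(np.arange(0, dim, 2) / dim))
expected_global_inv_freq = (1.0 / 8.0) / (100**(np.arange(0, dim, 2) / dim))

# The test asserts this same inequality via np.allclose.
assert not np.allclose(expected_local_inv_freq, expected_global_inv_freq)
```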
