@@ -702,7 +702,9 @@ def create_attention_const_params(model_cls, config):
702702 is_buffer = True ))
703703 else :
704704
705- def register_rope_params (rotary_base , scale , scale_type , scaling ,
705+ def register_rope_params (rotary_base , rotary_embedding_scale ,
706+ rotary_embedding_scale_type ,
707+ rotary_embedding_scaling ,
706708 names_to_register ):
707709 # Rotary const weights.
708710 embed_positions = RopeEmbeddingUtils .create_sinusoidal_positions (
@@ -712,7 +714,8 @@ def register_rope_params(rotary_base, scale, scale_type, scaling,
712714
713715 rotary_inv_freq , embed_positions_for_gpt_attention = RopeEmbeddingUtils .create_sinusoidal_positions_for_attention_plugin (
714716 max_position_embeddings , rotary_embedding_dim , rotary_base ,
715- scale , scale_type , scaling )
717+ rotary_embedding_scale , rotary_embedding_scale_type ,
718+ rotary_embedding_scaling )
716719 model_cls .register_parameter (
717720 names_to_register [0 ],
718721 Parameter (embed_positions , dtype = 'float32' , is_buffer = True ))
@@ -725,24 +728,25 @@ def register_rope_params(rotary_base, scale, scale_type, scaling,
725728 dtype = 'float32' ,
726729 is_buffer = True ))
727730
728- register_rope_params (rotary_base = rotary_embedding_base ,
729- scale = rotary_embedding_scale ,
730- scale_type = rotary_embedding_scale_type ,
731- scaling = rotary_embedding_scaling ,
732- names_to_register = [
733- 'embed_positions' , 'rotary_inv_freq' ,
734- 'embed_positions_for_gpt_attention'
735- ])
731+ register_rope_params (
732+ rotary_base = rotary_embedding_base ,
733+ rotary_embedding_scale = rotary_embedding_scale ,
734+ rotary_embedding_scale_type = rotary_embedding_scale_type ,
735+ rotary_embedding_scaling = rotary_embedding_scaling ,
736+ names_to_register = [
737+ 'embed_positions' , 'rotary_inv_freq' ,
738+ 'embed_positions_for_gpt_attention'
739+ ])
736740
737741 # For models with non-homogeneous attention layers requiring a second set of rope params. e.g. Gemma3.
738742 rotary_embedding_base_local = getattr (config ,
739743 'rope_local_base_freq' , None )
740744 if rotary_embedding_base_local is not None :
741745 register_rope_params (
742746 rotary_base = rotary_embedding_base_local ,
743- scale = 1.0 ,
744- scale_type = RotaryScalingType .none ,
745- scaling = None ,
747+ rotary_embedding_scale = 1.0 ,
748+ rotary_embedding_scale_type = RotaryScalingType .none ,
749+ rotary_embedding_scaling = None ,
746750 names_to_register = [
747751 'embed_positions_local' , 'rotary_inv_freq_local' ,
748752 'embed_positions_for_gpt_attention_local'
0 commit comments