@@ -702,22 +702,17 @@ def create_attention_const_params(model_cls, config):
                           is_buffer=True))
         else:
 
-            def register_rope_params(rotary_base,
-                                     names_to_register,
-                                     is_local=False):
+            def register_rope_params(rotary_base, scale, scale_type, scaling,
+                                     names_to_register):
                 # Rotary const weights.
                 embed_positions = RopeEmbeddingUtils.create_sinusoidal_positions(
                     max_position_embeddings,
                     rotary_embedding_dim,
                 )
-                # For local attention, use no scaling (consistent with forward pass)
-                local_scale = 1.0 if is_local else rotary_embedding_scale
-                local_scale_type = RotaryScalingType.none if is_local else rotary_embedding_scale_type
-                local_scaling = None if is_local else rotary_embedding_scaling
 
                 rotary_inv_freq, embed_positions_for_gpt_attention = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
                     max_position_embeddings, rotary_embedding_dim, rotary_base,
-                    local_scale, local_scale_type, local_scaling)
+                    scale, scale_type, scaling)
                 model_cls.register_parameter(
                     names_to_register[0],
                     Parameter(embed_positions, dtype='float32', is_buffer=True))
@@ -731,6 +726,9 @@ def register_rope_params(rotary_base,
                               is_buffer=True))
 
             register_rope_params(rotary_base=rotary_embedding_base,
+                                 scale=rotary_embedding_scale,
+                                 scale_type=rotary_embedding_scale_type,
+                                 scaling=rotary_embedding_scaling,
                                  names_to_register=[
                                      'embed_positions', 'rotary_inv_freq',
                                      'embed_positions_for_gpt_attention'
@@ -742,11 +740,13 @@ def register_rope_params(rotary_base,
             if rotary_embedding_base_local is not None:
                 register_rope_params(
                     rotary_base=rotary_embedding_base_local,
+                    scale=1.0,
+                    scale_type=RotaryScalingType.none,
+                    scaling=None,
                     names_to_register=[
                         'embed_positions_local', 'rotary_inv_freq_local',
                         'embed_positions_for_gpt_attention_local'
-                    ],
-                    is_local=True)
+                    ])
 
     @staticmethod
     def fill_attention_params(model_cls, attention_params):