@@ -702,16 +702,20 @@ def create_attention_const_params(model_cls, config):
                           is_buffer=True))
         else:
 
-            def register_rope_params(rotary_base, names_to_register):
+            def register_rope_params(rotary_base, names_to_register, is_local=False):
                 # Rotary const weights.
                 embed_positions = RopeEmbeddingUtils.create_sinusoidal_positions(
                     max_position_embeddings,
                     rotary_embedding_dim,
                 )
+                # For local attention, use no scaling (consistent with forward pass)
+                local_scale = 1.0 if is_local else rotary_embedding_scale
+                local_scale_type = RotaryScalingType.none if is_local else rotary_embedding_scale_type
+                local_scaling = None if is_local else rotary_embedding_scaling
+
                 rotary_inv_freq, embed_positions_for_gpt_attention = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
                     max_position_embeddings, rotary_embedding_dim, rotary_base,
-                    rotary_embedding_scale, rotary_embedding_scale_type,
-                    rotary_embedding_scaling)
+                    local_scale, local_scale_type, local_scaling)
                 model_cls.register_parameter(
                     names_to_register[0],
                     Parameter(embed_positions, dtype='float32', is_buffer=True))
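The added lines above pick unscaled RoPE parameters whenever the const weights are registered for a local (sliding-window) attention layer, and keep the configured scaling for global layers. Below is a minimal, self-contained sketch of that selection logic under stated assumptions: the helper name select_rope_scaling is hypothetical, and RotaryScalingType is stubbed here rather than imported from TensorRT-LLM.

# Sketch only: stand-in for the scaling selection added in register_rope_params.
# RotaryScalingType is stubbed; the real enum lives in TensorRT-LLM.
from enum import IntEnum
from typing import Optional, Tuple


class RotaryScalingType(IntEnum):
    none = 0
    linear = 1
    dynamic = 2


def select_rope_scaling(
        is_local: bool, scale: float, scale_type: RotaryScalingType,
        scaling: Optional[dict]
) -> Tuple[float, RotaryScalingType, Optional[dict]]:
    # Local (sliding-window) layers skip RoPE scaling so the registered const
    # weights match the forward pass; global layers keep the config values.
    if is_local:
        return 1.0, RotaryScalingType.none, None
    return scale, scale_type, scaling


# Example: a global layer keeps linear scaling, a local layer drops it.
print(select_rope_scaling(False, 8.0, RotaryScalingType.linear, {"factor": 8.0}))
print(select_rope_scaling(True, 8.0, RotaryScalingType.linear, {"factor": 8.0}))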
@@ -739,7 +743,8 @@ def register_rope_params(rotary_base, names_to_register):
                 names_to_register=[
                     'embed_positions_local', 'rotary_inv_freq_local',
                     'embed_positions_for_gpt_attention_local'
-                ])
+                ],
+                is_local=True)
 
     @staticmethod
     def fill_attention_params(model_cls, attention_params):
@@ -1141,10 +1146,10 @@ def compute_cross_kv(encoder_output):
                 rotary_embedding_dim=self.rotary_embedding_dim,
                 rotary_embedding_base=self.rotary_embedding_base
                 if not self.is_local else self.rotary_embedding_base_local,
-                rotary_embedding_scale_type=self.rotary_embedding_scale_type,
+                rotary_embedding_scale_type=self.rotary_embedding_scale_type if not self.is_local else RotaryScalingType.none,
                 rotary_embedding_short_m_scale=attention_params.short_mscale,
                 rotary_embedding_long_m_scale=attention_params.long_mscale,
-                rotary_embedding_scale=self.rotary_embedding_scale,
+                rotary_embedding_scale=self.rotary_embedding_scale if not self.is_local else 1.0,
                 rotary_embedding_max_positions=self.max_position_embeddings,
                 rotary_embedding_original_max_positions=self.
                 original_max_position_embeddings,
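The forward-pass hunk applies the same rule inline: local layers pass scale 1.0 and RotaryScalingType.none, global layers forward the configured values. The short sketch below illustrates why scale 1.0 and "no scaling" are interchangeable, assuming the common linear (position-interpolation) convention where the position index is divided by the scale factor before the rotary angle is computed; rope_angle is an illustrative helper, not the plugin's internals.

# Sketch: with scale=1.0 the RoPE angles are identical to the unscaled case.
import math


def rope_angle(position: int, dim_pair: int, rotary_dim: int,
               base: float = 10000.0, scale: float = 1.0) -> float:
    # Assumed linear scaling convention: divide the position by the factor.
    inv_freq = base ** (-2.0 * dim_pair / rotary_dim)
    return (position / scale) * inv_freq


p, d, rd = 1000, 4, 128
assert math.isclose(rope_angle(p, d, rd, scale=1.0), rope_angle(p, d, rd))
assert rope_angle(p, d, rd, scale=8.0) < rope_angle(p, d, rd)  # scaled-down angle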
@@ -2792,4 +2797,4 @@ def forward(self,
             attention_mask=attention_mask,
             max_input_length=max_input_length,
             *args,
-            **kwargs)
+            **kwargs)