@@ -118,7 +118,7 @@ def _init_custom(self):
             scaling_type = rope_scaling["type"]
         else:
             raise ValueError(f"Unknown RoPE scaling format {rope_scaling}")
-        if scaling_type == "default":
+        if scaling_type == "default" or "mrope_section" in rope_scaling:
             self._init_to_get_rotary()
         elif scaling_type == "yarn":
             self._init_to_get_yarn_rotary()
@@ -129,7 +129,7 @@ def _init_custom(self):
         elif scaling_type == "llama3":
             self._init_to_get_llama3_rotary()
         elif scaling_type == "mrope":
-            self._init_to_get_mrope_rotary()
+            self._init_to_get_rotary()
         else:
             raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
         return
@@ -373,47 +373,3 @@ def _init_to_get_llama3_rotary(self, default_base=10000):
         self._cos_cached = torch.cos(freqs).to(self.data_type).cuda()
         self._sin_cached = torch.sin(freqs).to(self.data_type).cuda()
         return
-
-    def _init_to_get_mrope_rotary(self, default_base=10000):
-        partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_)
-        if self.config.get("rope_scaling", {}) is None:
-            rope_scaling_factor = 1.0
-        else:
-            rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)
-
-        base = self.config.get("rope_theta", float(default_base))
-
-        if "max_sequence_length" in self.config:
-            max_seq_len = self.config["max_sequence_length"]
-        else:
-            max_position_embeddings = self.config.get(
-                "max_position_embeddings", 2048 if base <= 10000.0 + 1e-5 else 16384
-            )
-            max_seq_len = max_position_embeddings * rope_scaling_factor
-
-        # NTK
-        try:
-            ntk_alpha = float(os.environ.get("LIGHTLLM_NTK_ALPHA", 1))
-            assert ntk_alpha >= 1
-            if ntk_alpha > 1:
-                logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}")
-            max_seq_len *= ntk_alpha
-            base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2)))  # Base change formula
-        except:
-            pass
-
-        self.inv_freq = 1.0 / (
-            base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim)
-        )
-
-        t = (
-            torch.arange(max(max_seq_len + 1024 * 128, self.max_seq_length), device="cpu", dtype=torch.float32)
-            / rope_scaling_factor
-        )
-        freqs = torch.outer(t, self.inv_freq)  # (T, D/2)
-        freqs = torch.cat((freqs, freqs), dim=-1)  # (T, D)
-
-        self._cos_cached = torch.cos(freqs).to(self.data_type).cuda()
-        self._sin_cached = torch.sin(freqs).to(self.data_type).cuda()
-
-        return
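A minimal sketch of the dispatch behaviour after this change: `rope_scaling` configs that carry an `mrope_section` entry, or that report their type as `mrope`, now fall back to the standard `_init_to_get_rotary` cos/sin tables instead of the removed `_init_to_get_mrope_rotary`. The `pick_rotary_init` helper and the example config dicts below are illustrative assumptions (only the branches visible in this diff are mirrored), not code from the repository.

```python
# Hypothetical helper mirroring the dispatch in _init_custom after this commit;
# it only returns the name of the init routine that would be chosen.
def pick_rotary_init(rope_scaling: dict) -> str:
    if "rope_type" in rope_scaling:
        scaling_type = rope_scaling["rope_type"]
    elif "type" in rope_scaling:
        scaling_type = rope_scaling["type"]
    else:
        raise ValueError(f"Unknown RoPE scaling format {rope_scaling}")

    # mRoPE-style configs now reuse the plain rotary tables.
    if scaling_type == "default" or "mrope_section" in rope_scaling:
        return "_init_to_get_rotary"
    elif scaling_type == "yarn":
        return "_init_to_get_yarn_rotary"
    elif scaling_type == "llama3":
        return "_init_to_get_llama3_rotary"
    elif scaling_type == "mrope":
        return "_init_to_get_rotary"
    raise ValueError(f"Unknown RoPE scaling type {scaling_type}")


# Example configs with illustrative values (e.g. a Qwen2-VL style mrope_section).
print(pick_rotary_init({"type": "mrope", "mrope_section": [16, 24, 24]}))  # _init_to_get_rotary
print(pick_rotary_init({"rope_type": "yarn", "factor": 4.0}))              # _init_to_get_yarn_rotary
```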