@@ -40,6 +40,7 @@
 
 _SUPPORTED_ROPE_SCALING = {
     "linear": attention_spec.RotaryScalingType.Linear,
+    "su": attention_spec.RotaryScalingType.Su,
 }
 
 _MODEL_LOADERS = {}
@@ -346,9 +347,11 @@ def set_common_layers(self, spec, module):
         spec.scale_embeddings = module.embed_scale
         self.set_position_encodings(spec.position_encodings, module.embed_positions)
         self.set_embeddings(
-            spec.embeddings[0]
-            if isinstance(spec.embeddings, list)
-            else spec.embeddings,
+            (
+                spec.embeddings[0]
+                if isinstance(spec.embeddings, list)
+                else spec.embeddings
+            ),
             module.embed_tokens,
         )
 
@@ -1066,9 +1069,11 @@ def set_config(self, config, model, tokenizer):
     def set_stack(self, spec, module, is_decoder=False):
         self.set_layer_norm(spec.layer_norm, module.final_layer_norm)
         self.set_embeddings(
-            spec.embeddings[0]
-            if isinstance(spec.embeddings, list)
-            else spec.embeddings,
+            (
+                spec.embeddings[0]
+                if isinstance(spec.embeddings, list)
+                else spec.embeddings
+            ),
             module.embed_tokens,
         )
 
@@ -1298,9 +1303,11 @@ def get_model_spec(self, model):
         spec = transformer_spec.TransformerDecoderModelSpec.from_config(
             num_layers,
             num_heads,
-            activation=common_spec.Activation.GELU
-            if activation_config == "gelu"
-            else common_spec.Activation.GELUTanh,
+            activation=(
+                common_spec.Activation.GELU
+                if activation_config == "gelu"
+                else common_spec.Activation.GELUTanh
+            ),
             pre_norm=True,
             ffn_glu=True,
             rms_norm=True,
@@ -1694,10 +1701,14 @@ def get_model_spec(self, model):
         if num_heads_kv == num_heads:
             num_heads_kv = None
 
+        original_max_position_embeddings = getattr(
+            model.config, "original_max_position_embeddings", 0
+        )
+        max_position_embeddings = getattr(model.config, "max_position_embeddings", 0)
         rope_scaling = getattr(model.config, "rope_scaling", None)
         if rope_scaling:
             rotary_scaling_type = _SUPPORTED_ROPE_SCALING.get(rope_scaling["type"])
-            rotary_scaling_factor = rope_scaling["factor"]
+            rotary_scaling_factor = rope_scaling.get("factor", 1)
 
             if rotary_scaling_type is None:
                 raise NotImplementedError(
@@ -1721,6 +1732,8 @@ def get_model_spec(self, model):
             rotary_scaling_type=rotary_scaling_type,
             rotary_scaling_factor=rotary_scaling_factor,
             rotary_base=getattr(model.config, "rope_theta", 10000),
+            original_max_position_embeddings=original_max_position_embeddings,
+            max_position_embeddings=max_position_embeddings,
             num_heads_kv=num_heads_kv,
         )
 
@@ -1748,6 +1761,16 @@ def set_config(self, config, model, tokenizer):
     def set_layer_norm(self, spec, layer_norm):
         spec.gamma = layer_norm.weight
 
+    def set_rotary_embeddings(
+        self, spec, rotary_scaling_long_factor, rotary_scaling_short_factor
+    ):
+        spec.rotary_scaling_long_factor = torch.tensor(
+            rotary_scaling_long_factor, dtype=torch.float32
+        )
+        spec.rotary_scaling_short_factor = torch.tensor(
+            rotary_scaling_short_factor, dtype=torch.float32
+        )
+
     def set_decoder(self, spec, module):
         spec.scale_embeddings = False
         self.set_embeddings(spec.embeddings, module.embed_tokens)
@@ -1765,6 +1788,15 @@ def set_decoder(self, spec, module):
                 layer_spec.self_attention.linear[0], layer.self_attn.qkv_proj
             )
             self.set_linear(layer_spec.self_attention.linear[1], layer.self_attn.o_proj)
+            if (
+                layer.self_attn.rotary_emb.long_factor is not None
+                and layer.self_attn.rotary_emb.short_factor is not None
+            ):
+                self.set_rotary_embeddings(
+                    layer_spec.self_attention,
+                    layer.self_attn.rotary_emb.long_factor,
+                    layer.self_attn.rotary_emb.short_factor,
+                )
 
             gate_proj, up_proj = layer.mlp.gate_up_proj.weight.chunk(2, dim=0)
             layer_spec.ffn.linear_0.weight = gate_proj