@@ -1680,6 +1680,102 @@ def set_decoder(self, spec, module):
             self.set_linear(layer_spec.ffn.linear_1, layer.mlp.fc2)


+@register_loader("Phi3Config")
+class Phi3Loader(ModelLoader):
+    @property
+    def architecture_name(self):
+        return "AutoModelForCausalLM"
+
+    def get_model_spec(self, model):
+        num_layers = model.config.num_hidden_layers
+
+        num_heads = model.config.num_attention_heads
+        num_heads_kv = getattr(model.config, "num_key_value_heads", num_heads)
+        if num_heads_kv == num_heads:
+            num_heads_kv = None
+
+        rope_scaling = getattr(model.config, "rope_scaling", None)
+        if rope_scaling:
+            rotary_scaling_type = _SUPPORTED_ROPE_SCALING.get(rope_scaling["type"])
+            rotary_scaling_factor = rope_scaling["factor"]
+
+            if rotary_scaling_type is None:
+                raise NotImplementedError(
+                    "RoPE scaling type '%s' is not yet implemented. "
+                    "The following RoPE scaling types are currently supported: %s"
+                    % (rope_scaling["type"], ", ".join(_SUPPORTED_ROPE_SCALING.keys()))
+                )
+        else:
+            rotary_scaling_type = None
+            rotary_scaling_factor = 1
+
+        spec = transformer_spec.TransformerDecoderModelSpec.from_config(
+            num_layers,
+            num_heads,
+            activation=common_spec.Activation.SWISH,
+            pre_norm=True,
+            ffn_glu=True,
+            rms_norm=True,
+            rotary_dim=0,
+            rotary_interleave=False,
+            rotary_scaling_type=rotary_scaling_type,
+            rotary_scaling_factor=rotary_scaling_factor,
+            rotary_base=getattr(model.config, "rope_theta", 10000),
+            num_heads_kv=num_heads_kv,
+        )
+
+        self.set_decoder(spec.decoder, model.model)
+        self.set_linear(spec.decoder.projection, model.lm_head)
+        return spec
+
+    def get_vocabulary(self, model, tokenizer):
+        tokens = super().get_vocabulary(model, tokenizer)
+
+        extra_ids = model.config.vocab_size - len(tokens)
+        for i in range(extra_ids):
+            tokens.append("<extra_id_%d>" % i)
+
+        return tokens
+
+    def set_vocabulary(self, spec, tokens):
+        spec.register_vocabulary(tokens)
+
+    def set_config(self, config, model, tokenizer):
+        config.bos_token = tokenizer.bos_token
+        config.eos_token = tokenizer.eos_token
+        config.unk_token = tokenizer.unk_token
+
+    def set_layer_norm(self, spec, layer_norm):
+        spec.gamma = layer_norm.weight
+
+    def set_decoder(self, spec, module):
+        spec.scale_embeddings = False
+        self.set_embeddings(spec.embeddings, module.embed_tokens)
+        self.set_layer_norm(spec.layer_norm, module.norm)
+
+        for layer_spec, layer in zip(spec.layer, module.layers):
+            self.set_layer_norm(
+                layer_spec.self_attention.layer_norm, layer.input_layernorm
+            )
+            self.set_layer_norm(
+                layer_spec.ffn.layer_norm, layer.post_attention_layernorm
+            )
+
+            self.set_linear(
+                layer_spec.self_attention.linear[0], layer.self_attn.qkv_proj
+            )
+            self.set_linear(layer_spec.self_attention.linear[1], layer.self_attn.o_proj)
+
+            gate_proj, up_proj = layer.mlp.gate_up_proj.weight.chunk(2, dim=0)
+            layer_spec.ffn.linear_0.weight = gate_proj
+            layer_spec.ffn.linear_0_noact.weight = up_proj
+            self.set_linear(layer_spec.ffn.linear_1, layer.mlp.down_proj)
+
+            delattr(layer, "self_attn")
+            delattr(layer, "mlp")
+            gc.collect()
+
+
 @register_loader("RWConfig")
 class RWLoader(ModelLoader):
     @property
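A note on the only tensor surgery in set_decoder above: Phi-3 checkpoints store the SwiGLU gate and up projections fused into a single gate_up_proj matrix, so the loader chunks its weight in half along the output dimension and assigns the halves to linear_0 (the activated branch) and linear_0_noact. Below is a minimal standalone sketch of that split; the dimensions are Phi-3-mini-like values chosen purely for illustration.

import torch

hidden_size, intermediate_size = 3072, 8192  # illustrative Phi-3-mini-like sizes

# Fused gate/up projection as stored in the checkpoint:
# weight shape is [2 * intermediate_size, hidden_size].
gate_up_proj = torch.nn.Linear(hidden_size, 2 * intermediate_size, bias=False)

# Same split as in set_decoder: chunk along dim 0, the output dimension.
gate_proj, up_proj = gate_up_proj.weight.chunk(2, dim=0)
assert gate_proj.shape == up_proj.shape == (intermediate_size, hidden_size)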
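For context, once the loader is registered under "Phi3Config", a Phi-3 checkpoint should convert through the existing Transformers converter entry points (assuming this diff targets CTranslate2's converters). The sketch below is an illustration only and not part of the commit: the checkpoint name, output directory, and int8 quantization are placeholders.

# Hypothetical conversion example (assumes the ctranslate2 and transformers
# packages are installed; the model name and output path are placeholders).
from ctranslate2.converters import TransformersConverter

converter = TransformersConverter("microsoft/Phi-3-mini-4k-instruct")
converter.convert("phi3_ct2", quantization="int8")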