
Commit 9d54f5d

Add phi3 converter (#1680)
* add phi3 converter
* PhiLoader to Phi3Loader
* fix black
1 parent 0527ef7 commit 9d54f5d
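
With this loader registered, a Phi-3 checkpoint converts through the usual Transformers entry points: the ct2-transformers-converter command line, or the Python API. A minimal sketch, with an illustrative checkpoint name and output directory (neither is specified by this commit):

from ctranslate2.converters import TransformersConverter

# Illustrative model and output path; any checkpoint whose Hugging Face
# config class is Phi3Config is routed to the new Phi3Loader.
converter = TransformersConverter("microsoft/Phi-3-mini-4k-instruct")
converter.convert("phi3_ct2", quantization="int8")  # quantization is optional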

File tree

1 file changed: +96 −0 lines changed


python/ctranslate2/converters/transformers.py

Lines changed: 96 additions & 0 deletions
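
For orientation before the diff: register_loader keys each loader to the name of a Hugging Face config class, so the converter can pick Phi3Loader whenever model.config is a Phi3Config. A simplified sketch of that registry pattern, not the actual transformers.py implementation:

# Toy registry sketch; names are illustrative.
_MODEL_LOADERS = {}

def register_loader(config_name):
    def decorator(cls):
        _MODEL_LOADERS[config_name] = cls()
        return cls
    return decorator

# The converter then selects a loader with something like:
# loader = _MODEL_LOADERS[model.config.__class__.__name__]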
@@ -1680,6 +1680,102 @@ def set_decoder(self, spec, module):
            self.set_linear(layer_spec.ffn.linear_1, layer.mlp.fc2)


@register_loader("Phi3Config")
class Phi3Loader(ModelLoader):
    @property
    def architecture_name(self):
        return "AutoModelForCausalLM"

    def get_model_spec(self, model):
        num_layers = model.config.num_hidden_layers

        num_heads = model.config.num_attention_heads
        num_heads_kv = getattr(model.config, "num_key_value_heads", num_heads)
        if num_heads_kv == num_heads:
            num_heads_kv = None

        rope_scaling = getattr(model.config, "rope_scaling", None)
        if rope_scaling:
            rotary_scaling_type = _SUPPORTED_ROPE_SCALING.get(rope_scaling["type"])
            rotary_scaling_factor = rope_scaling["factor"]

            if rotary_scaling_type is None:
                raise NotImplementedError(
                    "RoPE scaling type '%s' is not yet implemented. "
                    "The following RoPE scaling types are currently supported: %s"
                    % (rope_scaling["type"], ", ".join(_SUPPORTED_ROPE_SCALING.keys()))
                )
        else:
            rotary_scaling_type = None
            rotary_scaling_factor = 1

        spec = transformer_spec.TransformerDecoderModelSpec.from_config(
            num_layers,
            num_heads,
            activation=common_spec.Activation.SWISH,
            pre_norm=True,
            ffn_glu=True,
            rms_norm=True,
            rotary_dim=0,
            rotary_interleave=False,
            rotary_scaling_type=rotary_scaling_type,
            rotary_scaling_factor=rotary_scaling_factor,
            rotary_base=getattr(model.config, "rope_theta", 10000),
            num_heads_kv=num_heads_kv,
        )

        self.set_decoder(spec.decoder, model.model)
        self.set_linear(spec.decoder.projection, model.lm_head)
        return spec

    def get_vocabulary(self, model, tokenizer):
        tokens = super().get_vocabulary(model, tokenizer)

        extra_ids = model.config.vocab_size - len(tokens)
        for i in range(extra_ids):
            tokens.append("<extra_id_%d>" % i)

        return tokens

    def set_vocabulary(self, spec, tokens):
        spec.register_vocabulary(tokens)

    def set_config(self, config, model, tokenizer):
        config.bos_token = tokenizer.bos_token
        config.eos_token = tokenizer.eos_token
        config.unk_token = tokenizer.unk_token

    def set_layer_norm(self, spec, layer_norm):
        spec.gamma = layer_norm.weight

    def set_decoder(self, spec, module):
        spec.scale_embeddings = False
        self.set_embeddings(spec.embeddings, module.embed_tokens)
        self.set_layer_norm(spec.layer_norm, module.norm)

        for layer_spec, layer in zip(spec.layer, module.layers):
            self.set_layer_norm(
                layer_spec.self_attention.layer_norm, layer.input_layernorm
            )
            self.set_layer_norm(
                layer_spec.ffn.layer_norm, layer.post_attention_layernorm
            )

            self.set_linear(
                layer_spec.self_attention.linear[0], layer.self_attn.qkv_proj
            )
            self.set_linear(layer_spec.self_attention.linear[1], layer.self_attn.o_proj)

            gate_proj, up_proj = layer.mlp.gate_up_proj.weight.chunk(2, dim=0)
            layer_spec.ffn.linear_0.weight = gate_proj
            layer_spec.ffn.linear_0_noact.weight = up_proj
            self.set_linear(layer_spec.ffn.linear_1, layer.mlp.down_proj)

            delattr(layer, "self_attn")
            delattr(layer, "mlp")
            gc.collect()


@register_loader("RWConfig")
class RWLoader(ModelLoader):
    @property
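
Two details in set_decoder are worth unpacking. Phi-3 fuses Q/K/V into a single qkv_proj, which is mapped as one fused linear (self_attention.linear[0]), and it fuses the SwiGLU gate and up projections into one gate_up_proj, which the loader splits with chunk(2, dim=0) along the output dimension. A runnable toy check of that split, using made-up sizes rather than the real Phi-3 dimensions:

import torch

hidden_size, intermediate_size = 8, 16
gate_up = torch.nn.Linear(hidden_size, 2 * intermediate_size, bias=False)

# chunk(2, dim=0) splits the fused weight along the output dimension:
# rows [0, intermediate_size) are the gate, the rest the up projection.
gate_w, up_w = gate_up.weight.chunk(2, dim=0)

x = torch.randn(1, hidden_size)
fused = gate_up(x)
assert torch.allclose(fused[:, :intermediate_size], x @ gate_w.T)
assert torch.allclose(fused[:, intermediate_size:], x @ up_w.T)

This is why linear_0 (the activated gate branch, SWISH here) and linear_0_noact (the up branch) each receive half of the fused weight, matching ffn_glu=True in the spec.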

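Once converted, the model loads like any other CTranslate2 generator. A minimal sketch, reusing the illustrative paths from the conversion example above:

import ctranslate2
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct"
)
generator = ctranslate2.Generator("phi3_ct2")

prompt = tokenizer.convert_ids_to_tokens(tokenizer.encode("Hello, world"))
results = generator.generate_batch([prompt], max_length=32, sampling_topk=10)
print(tokenizer.decode(tokenizer.convert_tokens_to_ids(results[0].sequences[0])))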