Commit 48398f6

Add moonshine converter.
Converts a safetensors model definition + tokenizer.json to a CTranslate2 model spec for Moonshine.
1 parent 6c97db4

File tree

5 files changed (+175, -0 lines)


docs/conversion.md

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ The Python module includes a [conversion API](python/ctranslate2.converters.rst)
 
 * [Fairseq](guides/fairseq.md)
 * [Marian](guides/marian.md)
+* [Moonshine](guides/moonshine.md)
 * [OpenNMT-py](guides/opennmt_py.md)
 * [OpenNMT-tf](guides/opennmt_tf.md)
 * [OPUS-MT](guides/opus_mt.md)

docs/guides/moonshine.md

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+# Moonshine
+
+CTranslate2 supports [Moonshine](https://github.com/usefulsensors/moonshine) transcription models. The conversion requires the paths to the model and vocabulary files:
+
+See the following repositories for the Moonshine model.safetensors and tokenizer.json files: [tiny](https://huggingface.co/UsefulSensors/moonshine-tiny/tree/main), [base](https://huggingface.co/UsefulSensors/moonshine-base/tree/main).
+
+```bash
+ct2-moonshine-converter --model_path model.safetensors --vocab_path tokenizer.json --moonshine_variant tiny \
+    --output_dir ct2_model
+```
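
The converter can also be driven from the Python API added in this commit. A minimal sketch, assuming files downloaded from the repositories above and the `convert` method inherited from the base `Converter` class:

```python
from ctranslate2.converters import MoonshineConverter

# Paths are placeholders for the files downloaded from Hugging Face.
converter = MoonshineConverter(
    "model.safetensors",  # safetensor_file
    "tokenizer.json",  # vocab_file
    "tiny",  # moonshine_variant: "tiny" or "base"
)

# convert() is inherited from the base Converter class.
converter.convert("ct2_model")
```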

python/ctranslate2/converters/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 from ctranslate2.converters.converter import Converter
 from ctranslate2.converters.fairseq import FairseqConverter
 from ctranslate2.converters.marian import MarianConverter
+from ctranslate2.converters.moonshine import MoonshineConverter
 from ctranslate2.converters.openai_gpt2 import OpenAIGPT2Converter
 from ctranslate2.converters.opennmt_py import OpenNMTPyConverter
 from ctranslate2.converters.opennmt_tf import OpenNMTTFConverter
python/ctranslate2/converters/moonshine.py

Lines changed: 162 additions & 0 deletions
@@ -0,0 +1,162 @@
+import argparse
+import json
+import re
+
+import numpy as np
+from safetensors.torch import safe_open
+
+from ctranslate2.converters.converter import Converter
+from ctranslate2.specs.moonshine_spec import MoonshineSpec
+
+
+class MoonshineConverter(Converter):
+    def __init__(self, safetensor_file, vocab_file, moonshine_variant):
+        self.safetensor_file = safetensor_file
+        self.vocab_file = vocab_file
+        if moonshine_variant == "tiny":
+            self.layers = 6
+            self.heads = 8
+        elif moonshine_variant == "base":
+            self.layers = 8
+            self.heads = 8
+        else:
+            raise ValueError('moonshine_variant must be one of ["tiny", "base"]')
+
+    def _load(self):
+        spec = MoonshineSpec(
+            num_encoder_layers=self.layers,
+            num_encoder_heads=self.heads,
+            num_decoder_layers=self.layers,
+            num_decoder_heads=self.heads,
+        )
+        self.load_preprocessor(spec.preprocessor)
+        self.load_encoder(spec.encoder)
+        self.load_decoder(spec.decoder)
+        spec.register_vocabulary(self.load_vocab())
+        return spec
+
+    def load_vocab(self):
+        tokens_dict = {}
+        with open(self.vocab_file, encoding="utf-8") as f:
+            tokenizer_dict = json.load(f)
+        d = tokenizer_dict["model"]["vocab"]
+        for token in d.keys():
+            idx = d[token]
+            token = re.sub(r"\\([^x])", r"\1", token)
+            token = token[1:-1]
+            if token.startswith("\\x"):
+                # Convert the digraph \x to the actual escaped sequence.
+                token = chr(int(token[2:], base=16))
+            elif token.startswith("'") and token.endswith("'"):
+                token = token[1:-1]
+                token = token.replace("''", "'")
+            if idx is not None:
+                tokens_dict[idx] = token
+        added_tokens = tokenizer_dict["added_tokens"]
+        for t in added_tokens:
+            tokens_dict[t["id"]] = t["content"]
+
+        return [tokens_dict[idx] for idx in sorted(tokens_dict.keys())]
+
+    def load_attention(self, att_spec, st_prefix, self_attention=True):
+        st = safe_open(self.safetensor_file, framework="pt", device="cpu")
+        attn_w = [
+            st.get_tensor(f"{st_prefix}.to_{dst}.weight") for dst in ["q", "k", "v"]
+        ]
+        if self_attention:
+            # Self-attention uses a single fused Q/K/V projection.
+            att_spec.linear[0].weight = np.concatenate(attn_w)
+        else:
+            # Cross-attention keeps Q separate and fuses only K/V.
+            att_spec.linear[0].weight = attn_w[0]
+            att_spec.linear[1].weight = np.concatenate(attn_w[1:])
+        att_spec.linear[-1].weight = st.get_tensor(f"{st_prefix}.to_out.weight")
+
+    def load_ffn(self, ffn_spec, st_prefix, swiglu=False):
+        st = safe_open(self.safetensor_file, framework="pt", device="cpu")
+        if swiglu:
+            ffn_spec.linear_0_noact.weight = st.get_tensor(f"{st_prefix}.ff_noact.weight")
+            ffn_spec.linear_0.weight = st.get_tensor(f"{st_prefix}.ff_proj.weight")
+            ffn_spec.linear_0_noact.bias = st.get_tensor(f"{st_prefix}.ff_noact.bias")
+            ffn_spec.linear_0.bias = st.get_tensor(f"{st_prefix}.ff_proj.bias")
+            ffn_spec.linear_1.weight = st.get_tensor(f"{st_prefix}.ff_out.weight")
+            ffn_spec.linear_1.bias = st.get_tensor(f"{st_prefix}.ff_out.bias")
+        else:
+            ffn_spec.linear_0.weight = st.get_tensor(f"{st_prefix}.ff.0.weight")
+            ffn_spec.linear_0.bias = st.get_tensor(f"{st_prefix}.ff.0.bias")
+            ffn_spec.linear_1.weight = st.get_tensor(f"{st_prefix}.ff.2.weight")
+            ffn_spec.linear_1.bias = st.get_tensor(f"{st_prefix}.ff.2.bias")
+
+    def load_layernorm(self, ln_spec, ln_prefix):
+        st = safe_open(self.safetensor_file, framework="pt", device="cpu")
+        ln_spec.gamma = st.get_tensor(f"{ln_prefix}.weight")
+        # The checkpoint stores no layer norm bias, so use a zero beta term.
+        ln_spec.beta = np.zeros(ln_spec.gamma.shape)
+
+    def load_embeddings(self, embedding_spec, embedding_prefix):
+        st = safe_open(self.safetensor_file, framework="pt", device="cpu")
+        embedding_spec.weight = st.get_tensor(f"{embedding_prefix}.weight")
+
+    def load_preprocessor(self, preprocess_spec):
+        st = safe_open(self.safetensor_file, framework="pt", device="cpu")
+        preprocess_prefix = "model.preprocessor.audio_preprocess"
+        preprocess_spec.conv1.weight = st.get_tensor(f"{preprocess_prefix}.0.weight")
+        preprocess_spec.layernorm.gamma = st.get_tensor(f"{preprocess_prefix}.2.weight")
+        preprocess_spec.layernorm.beta = st.get_tensor(f"{preprocess_prefix}.2.bias")
+        preprocess_spec.conv2.weight = st.get_tensor(f"{preprocess_prefix}.3.weight")
+        preprocess_spec.conv2.bias = st.get_tensor(f"{preprocess_prefix}.3.bias")
+        preprocess_spec.conv3.weight = st.get_tensor(f"{preprocess_prefix}.5.weight")
+        preprocess_spec.conv3.bias = st.get_tensor(f"{preprocess_prefix}.5.bias")
+
+    def load_encoder(self, encoder_spec):
+        self.load_layernorm(encoder_spec.layer_norm, "model.encoder.post_norm")
+        for idx, l in enumerate(encoder_spec.layer):
+            self.load_attention(l.self_attention, f"model.encoder.layers.{idx}.attention")
+            self.load_layernorm(
+                l.self_attention.layer_norm, f"model.encoder.layers.{idx}.norm1"
+            )
+            self.load_ffn(l.ffn, f"model.encoder.layers.{idx}.ff")
+            self.load_layernorm(l.ffn.layer_norm, f"model.encoder.layers.{idx}.norm2")
+
+    def load_decoder(self, decoder_spec):
+        self.load_layernorm(decoder_spec.layer_norm, "model.decoder.final_norm")
+        self.load_embeddings(decoder_spec.embeddings, "model.decoder.token_embedding")
+        # The output projection shares its weights with the token embedding.
+        decoder_spec.projection.weight = decoder_spec.embeddings.weight
+        for idx, l in enumerate(decoder_spec.layer):
+            self.load_attention(
+                l.self_attention, f"model.decoder.layers.{idx}.self_attention"
+            )
+            self.load_layernorm(
+                l.self_attention.layer_norm, f"model.decoder.layers.{idx}.norm1"
+            )
+            self.load_attention(
+                l.attention,
+                f"model.decoder.layers.{idx}.cross_attention",
+                self_attention=False,
+            )
+            self.load_layernorm(
+                l.attention.layer_norm, f"model.decoder.layers.{idx}.norm2"
+            )
+            self.load_ffn(l.ffn, f"model.decoder.layers.{idx}.ff", swiglu=True)
+            self.load_layernorm(l.ffn.layer_norm, f"model.decoder.layers.{idx}.norm3")
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    parser.add_argument(
+        "--model_path", required=True, help="Path to the model.safetensors file."
+    )
+    parser.add_argument(
+        "--vocab_path",
+        required=True,
+        help="Path to the tokenizer.json file.",
+    )
+    parser.add_argument(
+        "--moonshine_variant",
+        required=True,
+        help="Moonshine variant to convert. Must be one of ['tiny', 'base'].",
+    )
+
+    Converter.declare_arguments(parser)
+    args = parser.parse_args()
+    converter = MoonshineConverter(
+        args.model_path, args.vocab_path, args.moonshine_variant
+    )
+    converter.convert_from_args(args)
+
+
+if __name__ == "__main__":
+    main()
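
A note on `load_attention` above: CTranslate2 stores the query/key/value projections of a self-attention layer as one fused linear weight, while cross-attention keeps the query separate and fuses only key/value (which are computed from the encoder output rather than the decoder input). A minimal sketch of the resulting layouts; the model dimension of 288 is an assumption for illustration, the real shapes come from the checkpoint:

```python
import numpy as np

d_model = 288  # assumed here for illustration only

to_q = np.zeros((d_model, d_model), dtype=np.float32)
to_k = np.zeros((d_model, d_model), dtype=np.float32)
to_v = np.zeros((d_model, d_model), dtype=np.float32)

# Self-attention: one fused Q/K/V projection (linear[0]).
fused_qkv = np.concatenate([to_q, to_k, to_v])
assert fused_qkv.shape == (3 * d_model, d_model)

# Cross-attention: Q stays in linear[0]; K and V are fused into linear[1].
fused_kv = np.concatenate([to_k, to_v])
assert fused_kv.shape == (2 * d_model, d_model)
```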

python/setup.py

Lines changed: 1 addition & 0 deletions
@@ -110,6 +110,7 @@ def _maybe_add_library_root(lib_name):
         "console_scripts": [
             "ct2-fairseq-converter=ctranslate2.converters.fairseq:main",
             "ct2-marian-converter=ctranslate2.converters.marian:main",
+            "ct2-moonshine-converter=ctranslate2.converters.moonshine:main",
             "ct2-openai-gpt2-converter=ctranslate2.converters.openai_gpt2:main",
             "ct2-opennmt-py-converter=ctranslate2.converters.opennmt_py:main",
             "ct2-opennmt-tf-converter=ctranslate2.converters.opennmt_tf:main",
