Skip to content

Commit b679318

Browse files
authored
onnx
1 parent 2064528 commit b679318

File tree

2 files changed

+405
-0
lines changed

2 files changed

+405
-0
lines changed

model_onnx.py

Lines changed: 295 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,295 @@
1+
#!/usr/bin/env python3
2+
from pathlib import Path
3+
from typing import Any, Dict
4+
5+
import math
6+
import onnx
7+
import torch
8+
import argparse
9+
10+
from onnxruntime.quantization import QuantType, quantize_dynamic
11+
12+
import utils
13+
import commons
14+
import attentions
15+
from torch import nn
16+
from models import DurationPredictor, ResidualCouplingBlock, Generator
17+
from text.symbols import symbols
18+
19+
20+
class TextEncoder(nn.Module):
    """Embed token ids, run them through an attention encoder and project
    the result to per-frame statistics.

    The projection produces ``2 * out_channels`` channels which ``forward``
    splits into two ``out_channels``-wide halves (``m`` and ``logs``).
    """

    def __init__(
        self,
        n_vocab,
        out_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
    ):
        super().__init__()

        # Keep the hyper-parameters around as plain attributes.
        self.n_vocab = n_vocab
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout

        # Token embedding, re-initialised with std = hidden_channels ** -0.5.
        self.emb = nn.Embedding(n_vocab, hidden_channels)
        nn.init.normal_(self.emb.weight, mean=0.0, std=hidden_channels**-0.5)

        self.encoder = attentions.Encoder(
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
        )

        # 1x1 convolution emitting both halves of the statistics at once.
        self.proj = nn.Conv1d(hidden_channels, 2 * out_channels, 1)

    def forward(self, x, x_lengths):
        """Encode a batch of token id sequences.

        Args:
          x:
            Token ids fed to the embedding table.
          x_lengths:
            Lengths used to build the sequence mask.

        Returns:
          A tuple ``(x, m, logs, x_mask)``: the encoder output, the two
          halves of the projected statistics and the mask that was applied.
        """
        scale = math.sqrt(self.hidden_channels)
        hidden = self.emb(x) * scale  # [b, t, h]
        hidden = hidden.transpose(1, -1)  # [b, h, t]

        valid = commons.sequence_mask(x_lengths, hidden.size(2))
        x_mask = valid.unsqueeze(1).to(hidden.dtype)

        hidden = self.encoder(hidden * x_mask, x_mask)
        stats = self.proj(hidden) * x_mask

        m, logs = torch.split(stats, self.out_channels, dim=1)
        return hidden, m, logs, x_mask
66+
67+
68+
class SynthesizerEval(nn.Module):
    """
    Synthesizer for inference.

    Bundles the text encoder (``enc_p``), the duration predictor (``dp``),
    the flow (``flow``) and the decoder (``dec``), and exposes ``infer`` to
    run them end to end. A speaker embedding table (``emb_g``) is created
    only when ``n_speakers > 1``.
    """

    def __init__(
        self,
        n_vocab,
        spec_channels,
        segment_size,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        n_speakers=0,
        gin_channels=0,
        use_sdp=False,
        **kwargs
    ):
        # ``use_sdp`` and ``**kwargs`` are accepted but not used here, so a
        # full model config can be splatted into the constructor.
        super().__init__()
        self.n_vocab = n_vocab
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.n_speakers = n_speakers
        self.gin_channels = gin_channels

        self.enc_p = TextEncoder(
            n_vocab,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
        )
        self.dec = Generator(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(
            inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
        )
        self.dp = DurationPredictor(
            hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
        )
        # The speaker table exists only for genuinely multi-speaker models;
        # ``infer`` must use the same ``> 1`` condition before touching it.
        if n_speakers > 1:
            self.emb_g = nn.Embedding(n_speakers, gin_channels)

    def remove_weight_norm(self):
        """Remove weight normalization from the flow."""
        self.flow.remove_weight_norm()

    def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1):
        """Run the full synthesis pipeline.

        Args:
          x:
            Token ids passed to the text encoder.
          x_lengths:
            Lengths of the sequences in ``x``.
          sid:
            Speaker ids. Required when the model has more than one speaker,
            ignored otherwise.
          noise_scale:
            Multiplier on the random noise added to the prior mean.
          length_scale:
            Multiplier on the predicted durations.

        Returns:
          The decoder output with all size-1 dimensions squeezed away.

        Raises:
          ValueError: If the model is multi-speaker and ``sid`` is None.
        """
        # Must match the condition under which ``emb_g`` is created in
        # ``__init__``; testing ``> 0`` here would hit a missing attribute
        # when ``n_speakers == 1``.
        use_speaker_embedding = self.n_speakers > 1
        if use_speaker_embedding and sid is None:
            raise ValueError(
                "sid must be provided when the model has more than one speaker"
            )

        x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
        if use_speaker_embedding:
            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
        else:
            g = None

        logw = self.dp(x, x_mask, g=g)
        w = torch.exp(logw) * x_mask * length_scale
        w_ceil = torch.ceil(w)
        # Guarantee at least one output frame per utterance.
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
            x_mask.dtype
        )
        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
        attn = commons.generate_path(w_ceil, attn_mask)

        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
            1, 2
        )  # [b, t', t], [b, t, d] -> [b, d, t']
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
            1, 2
        )  # [b, t', t], [b, t, d] -> [b, d, t']

        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
        z = self.flow(z_p, y_mask, g=g, reverse=True)
        o = self.dec((z * y_mask), g=g)
        return o.squeeze()
177+
178+
179+
class OnnxModel(torch.nn.Module):
    """Adapter that exposes ``model.infer`` as ``forward``.

    Args:
      model:
        The wrapped synthesizer; it must provide an
        ``infer(x, x_lengths, sid, noise_scale, length_scale)`` method.
    """

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(
        self,
        x,
        x_lengths,
        noise_scale=1,
        length_scale=1,
        sid=None,
    ):
        """Delegate to ``self.model.infer`` and return its result.

        ``sid`` is last and optional so existing four-argument callers keep
        working; pass it for multi-speaker models.
        """
        return self.model.infer(
            x=x,
            x_lengths=x_lengths,
            sid=sid,
            noise_scale=noise_scale,
            length_scale=length_scale,
        )
197+
198+
199+
def add_meta_data(filename: str, meta_data: Dict[str, Any]):
    """Attach key-value metadata to an ONNX model file, rewriting it in-place.

    Args:
      filename:
        Path of the ONNX model to load and overwrite.
      meta_data:
        Entries to append to the model's ``metadata_props``; every value is
        stored as ``str(value)``.
    """
    onnx_model = onnx.load(filename)

    for name in meta_data:
        entry = onnx_model.metadata_props.add()
        entry.key, entry.value = name, str(meta_data[name])

    onnx.save(onnx_model, filename)
215+
216+
217+
@torch.no_grad()
def main():
    """Export a checkpoint to ONNX and write an int8-quantized copy.

    Steps: parse ``--config``/``--model``, build ``SynthesizerEval`` from the
    hyper-parameters, load the checkpoint, trace the model with a dummy
    input via ``torch.onnx.export``, attach metadata to the exported file and
    finally run dynamic quantization on it.

    Writes ``vits-chinese.onnx`` and ``vits-chinese.int8.onnx`` to the
    current working directory.
    """
    parser = argparse.ArgumentParser(
        description="Export a bert vits model to ONNX (float32 and int8)"
    )
    parser.add_argument('--config', type=str, required=True)
    parser.add_argument('--model', type=str, required=True)
    args = parser.parse_args()
    config_file = args.config
    checkpoint = args.model

    hps = utils.get_hparams_from_file(config_file)
    print(hps)

    net_g = SynthesizerEval(
        len(symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model,
    )

    _ = net_g.eval()
    _ = utils.load_model(checkpoint, net_g)
    net_g.remove_weight_norm()

    # Dummy token ids used only for tracing. The upper bound must not exceed
    # the embedding table size, otherwise the lookup can go out of range
    # whenever the symbol set is smaller than the hard-coded bound.
    x = torch.randint(low=0, high=len(symbols), size=(50,), dtype=torch.int64)
    x = x.unsqueeze(0)

    x_length = torch.tensor([x.shape[1]], dtype=torch.int64)
    noise_scale = torch.tensor([1], dtype=torch.float32)
    length_scale = torch.tensor([1], dtype=torch.float32)

    model = OnnxModel(net_g)

    opset_version = 13

    filename = "vits-chinese.onnx"

    # NOTE(review): SynthesizerEval.infer returns ``o.squeeze()``, so the
    # traced output may have fewer than 3 dims; confirm that axis 2 is the
    # intended dynamic axis for "y".
    torch.onnx.export(
        model,
        (x, x_length, noise_scale, length_scale),
        filename,
        opset_version=opset_version,
        input_names=[
            "x",
            "x_length",
            "noise_scale",
            "length_scale",
        ],
        output_names=["y"],
        dynamic_axes={
            "x": {0: "N", 1: "L"},  # n_audio is also known as batch_size
            "x_length": {0: "N"},
            "y": {0: "N", 2: "L"},
        },
    )
    meta_data = {
        "model_type": "vits",
        "comment": "csukuangfj",
        "language": "Chinese",
        "add_blank": int(hps.data.add_blank),
        "n_speakers": int(hps.data.n_speakers),
        "sample_rate": hps.data.sampling_rate,
        "punctuation": "",
    }
    print("meta_data", meta_data)
    add_meta_data(filename=filename, meta_data=meta_data)

    print("Generate int8 quantization models")
    filename_int8 = "vits-chinese.int8.onnx"
    quantize_dynamic(
        model_input=filename,
        model_output=filename_int8,
        weight_type=QuantType.QUInt8,
    )
    print(f"Saved to {filename} and {filename_int8}")
292+
293+
294+
# Run the export only when this file is executed as a script.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)