Skip to content

Commit aea2d2c

Browse files
update model reading in text_to_speech_demo
1 parent 87f0fe2 commit aea2d2c

File tree

2 files changed

+9
-20
lines changed

2 files changed

+9
-20
lines changed

demos/text_to_speech_demo/python/models/forward_tacotron_ie.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
"""
1616

1717
import logging as log
18-
import os.path as osp
1918

2019
import numpy as np
2120

@@ -110,12 +109,9 @@ def gather(a, dim, index):
110109
[-1 if i == j else 1 for j in range(a.ndim)]) for i in range(a.ndim)]
111110
return a[tuple(expanded_index)]
112111

113-
def load_network(self, model_xml):
114-
model_bin_name = ".".join(osp.basename(model_xml).split('.')[:-1]) + ".bin"
115-
model_bin = osp.join(osp.dirname(model_xml), model_bin_name)
116-
log.info('Reading ForwardTacotron model {}'.format(model_xml))
117-
model = self.ie.read_model(model=model_xml, weights=model_bin)
118-
return model
112+
def load_network(self, model_path):
113+
log.info('Reading ForwardTacotron model {}'.format(model_path))
114+
return self.ie.read_model(model_path)
119115

120116
def create_infer_request(self, model, path):
121117
compiled_model = self.ie.compile_model(model, device_name=self.device)

demos/text_to_speech_demo/python/models/mel2wave_ie.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
"""
1616

1717
import logging as log
18-
import os.path as osp
1918

2019
import numpy as np
2120

@@ -64,12 +63,9 @@ def __init__(self, model_upsample, model_rnn, ie, target=11000, overlap=550, hop
6463
self.mel_len = self.upsample_model.input('mels').shape[1] - 2 * self.pad
6564
self.rnn_width = self.rnn_model.input('h1.1').shape[1]
6665

67-
def load_network(self, model_xml):
68-
model_bin_name = ".".join(osp.basename(model_xml).split('.')[:-1]) + ".bin"
69-
model_bin = osp.join(osp.dirname(model_xml), model_bin_name)
70-
log.info('Reading WaveRNN model {}'.format(model_xml))
71-
model = self.ie.read_model(model=model_xml, weights=model_bin)
72-
return model
66+
def load_network(self, model_path):
67+
log.info('Reading WaveRNN model {}'.format(model_path))
68+
return self.ie.read_model(model_path)
7369

7470
def create_infer_requests(self, model, path, batch_sizes=None):
7571
if batch_sizes is not None:
@@ -221,12 +217,9 @@ def __init__(self, model, ie, device='CPU', default_width=800):
221217
self.mel_len = self.model.input('mel').shape[2]
222218
self.widths = [self.mel_len * (i + 1) for i in range(self.scales)]
223219

224-
def load_network(self, model_xml):
225-
model_bin_name = ".".join(osp.basename(model_xml).split('.')[:-1]) + ".bin"
226-
model_bin = osp.join(osp.dirname(model_xml), model_bin_name)
227-
log.info('Reading MelGAN model {}'.format(model_xml))
228-
model = self.ie.read_model(model=model_xml, weights=model_bin)
229-
return model
220+
def load_network(self, model_path):
221+
log.info('Reading MelGAN model {}'.format(model_path))
222+
return self.ie.read_model(model_path)
230223

231224
def create_infer_requests(self, model, path, scales=None):
232225
if scales is not None:

0 commit comments

Comments
 (0)