|
| 1 | +""" |
| 2 | + Copyright (c) 2020 Intel Corporation |
| 3 | +
|
| 4 | + Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + you may not use this file except in compliance with the License. |
| 6 | + You may obtain a copy of the License at |
| 7 | +
|
| 8 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +
|
| 10 | + Unless required by applicable law or agreed to in writing, software |
| 11 | + distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + See the License for the specific language governing permissions and |
| 14 | + limitations under the License. |
| 15 | +""" |
| 16 | + |
| 17 | +import os.path as osp |
| 18 | + |
| 19 | +import numpy as np |
| 20 | + |
| 21 | +from utils.text_preprocessing import text_to_sequence, _symbol_to_id |
| 22 | + |
| 23 | + |
class ForwardTacotronIE:
    """Text-to-mel pipeline built from two networks loaded through ``ie``.

    The first network (input ``input_seq``, outputs ``duration`` and
    ``embeddings``) predicts how many frames every input symbol lasts; the
    second one (input ``data``, output ``mel``) turns the duration-aligned
    embeddings into a mel-spectrogram.  Both networks have a fixed input
    length, so longer inputs are processed fragment by fragment.

    ``ie`` must provide ``read_network(model=..., weights=...)`` and
    ``load_network(network=..., device_name=...)``
    (presumably an OpenVINO ``IECore`` instance -- confirm against the caller).
    """

    def __init__(self, model_duration, model_forward, ie, device='CPU', verbose=False):
        """Load both networks onto ``device`` and read their fixed input lengths.

        Args:
            model_duration: path to the ``.xml`` file of the duration predictor.
            model_forward: path to the ``.xml`` file of the mel generator.
            ie: inference engine object used to read and load the networks.
            device: device name passed to ``ie.load_network``.
            verbose: when true, print shapes and intermediate values.
        """
        self.verbose = verbose
        self.device = device

        self.ie = ie

        self.duration_predictor_net = self.load_network(model_duration)
        self.duration_predictor_exec = self.create_exec_network(self.duration_predictor_net)

        self.forward_net = self.load_network(model_forward)
        self.forward_exec = self.create_exec_network(self.forward_net)

        # fixed length of the sequence of symbols
        self.duration_len = self.duration_predictor_net.input_info['input_seq'].input_data.shape[1]
        # fixed length of the input embeddings for forward
        self.forward_len = self.forward_net.input_info['data'].input_data.shape[1]
        if self.verbose:
            print('Forward limitations : {0} symbols and {1} embeddings'.format(self.duration_len, self.forward_len))

    def seq_to_indexes(self, text):
        """Convert ``text`` to a sequence of symbol ids with ``text_to_sequence``."""
        res = text_to_sequence(text)
        if self.verbose:
            print(res)
        return res

    @staticmethod
    def build_index(duration, x):
        """Build the gather index that repeats every symbol for its duration.

        Args:
            duration: array of shape ``(batch, symbols)`` with per-symbol frame
                counts.  Negative values are treated as zero; the array itself
                is not modified.
            x: array whose axis 0 (batch) and axis 2 (features) sizes define the
                matching axes of the result.

        Returns:
            Integer array of shape ``(x.shape[0], frames, x.shape[2])`` where
            ``frames`` is the largest total duration in the batch.  Entry
            ``[i, t, :]`` is the index of the symbol that covers frame ``t`` of
            row ``i``; frames past the end of a shorter row point at the last
            symbol.  The frame axis is empty when there are no symbols or all
            durations are zero.
        """
        # np.maximum allocates a new array, so the caller's data stays untouched.
        duration = np.maximum(np.asarray(duration), 0)
        tot_duration = np.cumsum(duration, 1)
        n_symbols = tot_duration.shape[1]
        max_duration = int(tot_duration.max()) if tot_duration.size else 0
        # Fixed-width integer dtype: the index is only used for fancy indexing.
        index = np.zeros([x.shape[0], max_duration, x.shape[2]], dtype=np.int64)
        if max_duration == 0:
            return index

        frames = np.arange(max_duration)
        for i in range(tot_duration.shape[0]):
            # Frame t belongs to the first symbol whose cumulative end exceeds t,
            # i.e. to the number of cumulative ends that are <= t.
            symbol_ids = np.searchsorted(tot_duration[i], frames, side='right')
            # Rows shorter than max_duration are padded with their last symbol.
            index[i] = np.minimum(symbol_ids, n_symbols - 1)[:, np.newaxis]
        return index

    @staticmethod
    def gather(a, dim, index):
        """Pick values of ``a`` along axis ``dim`` according to ``index``.

        For ``dim == 1`` and a 3-D input the result satisfies
        ``out[i, t, k] == a[i, index[i, t, k], k]``.  ``index`` must have the
        same number of dimensions as ``a`` and be broadcastable against the
        sizes of the other axes of ``a``.
        """
        expanded_index = []
        for axis in range(a.ndim):
            if axis == dim:
                expanded_index.append(index)
                continue
            # An arange laid out along its own axis so it broadcasts with `index`.
            shape = [1] * a.ndim
            shape[axis] = -1
            expanded_index.append(np.arange(a.shape[axis]).reshape(shape))
        return a[tuple(expanded_index)]

    def load_network(self, model_xml):
        """Read a network from ``model_xml`` and the ``.bin`` file next to it.

        The weights path is the model path with its last extension replaced by
        ``.bin`` (a path without an extension just gets ``.bin`` appended).
        """
        model_bin_name = osp.splitext(osp.basename(model_xml))[0] + ".bin"
        model_bin = osp.join(osp.dirname(model_xml), model_bin_name)
        print("Loading network files:\n\t{}\n\t{}".format(model_xml, model_bin))
        net = self.ie.read_network(model=model_xml, weights=model_bin)
        return net

    def create_exec_network(self, net):
        """Load ``net`` onto ``self.device`` and return the executable network."""
        exec_net = self.ie.load_network(network=net, device_name=self.device)
        return exec_net

    def infer_duration(self, sequence, alpha=1.0, non_empty_symbols=None):
        """Run the duration predictor and expand embeddings to frame resolution.

        Args:
            sequence: symbol ids fed to the ``input_seq`` input.
            alpha: multiplier applied to the predicted durations.
            non_empty_symbols: number of leading symbols that are real input;
                durations and embeddings after them (padding) are dropped.
                ``None`` keeps everything.

        Returns:
            Embeddings gathered along axis 1 so that every symbol is repeated
            for its predicted number of frames.
        """
        out = self.duration_predictor_exec.infer(inputs={"input_seq": sequence})
        duration = out["duration"] * alpha

        # Round to the nearest integer; the result is flattened into a single
        # row, so the batch is treated as one sequence.
        duration = (duration + 0.5).astype('int').flatten()
        duration = np.expand_dims(duration, axis=0)
        preprocessed_embeddings = out["embeddings"]

        if non_empty_symbols is not None:
            duration = duration[:, :non_empty_symbols]
            preprocessed_embeddings = preprocessed_embeddings[:, :non_empty_symbols]
        indexes = self.build_index(duration, preprocessed_embeddings)
        if self.verbose:
            print("Index: {0}, duration: {1}, embeddings: {2}, non_empty_symbols: {3}"
                  .format(indexes.shape, duration.shape, preprocessed_embeddings.shape, non_empty_symbols))

        return self.gather(preprocessed_embeddings, 1, indexes)

    def infer_mel(self, aligned_emb):
        """Run the mel generator on ``aligned_emb`` and return its ``mel`` output."""
        out = self.forward_exec.infer(inputs={"data": aligned_emb})
        return out['mel']

    def _pad_sequence(self, sequence):
        """Pad ``sequence`` with the space symbol up to ``duration_len``.

        Returns a ``(batch, non_empty_symbols)`` pair: ``batch`` is the padded
        sequence with a leading batch axis, ``non_empty_symbols`` is the
        original length when padding was added and ``None`` otherwise.
        """
        non_empty_symbols = None
        padded = list(sequence)
        if len(padded) < self.duration_len:
            non_empty_symbols = len(padded)
            padded += [_symbol_to_id[' ']] * (self.duration_len - len(padded))
        return np.expand_dims(np.array(padded), axis=0), non_empty_symbols

    @staticmethod
    def _split_points(sequence, delimiters, max_len):
        """Return ascending end offsets of fragments no longer than ``max_len``.

        Every fragment ends right after the last delimiter that still fits into
        ``max_len`` symbols; if there is no delimiter in that window the
        fragment is cut at exactly ``max_len``.  The last offset is always
        ``len(sequence)``.
        """
        if max_len < 1:
            raise ValueError("max_len must be positive, got {}".format(max_len))
        total = len(sequence)
        # Candidate cut positions: the offset right after each delimiter.
        candidates = [pos + 1 for pos, symbol in enumerate(sequence) if symbol in delimiters]
        edges = []
        begin = 0
        cursor = 0
        while total - begin > max_len:
            limit = begin + max_len
            best = None
            while cursor < len(candidates) and candidates[cursor] <= limit:
                best = candidates[cursor]
                cursor += 1
            if best is None:
                # No delimiter inside the window: cut mid-run rather than
                # exceed the fixed input length of the network.
                best = limit
            edges.append(best)
            begin = best
        edges.append(total)
        return edges

    def forward(self, text, alpha=1.0):
        """Synthesize a mel-spectrogram for ``text``.

        Args:
            text: input text.
            alpha: multiplier applied to the predicted symbol durations.

        Returns:
            Mel outputs of all processed windows concatenated along axis 1.

        Raises:
            ValueError: if no frames were produced (for example for empty text
                or when every predicted duration is zero).
        """
        sequence = self.seq_to_indexes(text)
        if len(sequence) <= self.duration_len:
            batch, non_empty_symbols = self._pad_sequence(sequence)
            if self.verbose:
                print("Seq shape: {0}".format(batch.shape))
            aligned_emb = self.infer_duration(batch, alpha, non_empty_symbols=non_empty_symbols)
            if self.verbose:
                print("AEmb shape: {0}".format(aligned_emb.shape))
        else:
            punctuation = '!\'(),.:;? '
            delimiters = {_symbol_to_id[p] for p in punctuation}
            # try to find optimal fragmentation for inference
            optimal_ranges = self._split_points(sequence, delimiters, self.duration_len)
            if self.verbose:
                print(optimal_ranges)

            outputs = []
            start_idx = 0
            for edge in optimal_ranges:
                batch, non_empty_symbols = self._pad_sequence(sequence[start_idx:edge])
                start_idx = edge
                if self.verbose:
                    print("Sub seq shape: {0}".format(batch.shape))
                outputs.append(self.infer_duration(batch, alpha, non_empty_symbols=non_empty_symbols))

                if self.verbose:
                    print("Sub AEmb: {0}".format(outputs[-1].shape))

            aligned_emb = np.concatenate(outputs, axis=1)

        mels = []
        total_frames = aligned_emb.shape[1]
        for start_idx in range(0, total_frames, self.forward_len):
            end_idx = min(start_idx + self.forward_len, total_frames)
            sub_aligned_emb = aligned_emb[:, start_idx:end_idx, :]
            if sub_aligned_emb.shape[1] < self.forward_len:
                # The last window is zero-padded up to the fixed input length.
                sub_aligned_emb = np.pad(sub_aligned_emb,
                                         ((0, 0), (0, self.forward_len - sub_aligned_emb.shape[1]), (0, 0)),
                                         'constant', constant_values=0)
            if self.verbose:
                print("SAEmb shape: {0}".format(sub_aligned_emb.shape))
            # Drop the part of the output that corresponds to the padding.
            mel = self.infer_mel(sub_aligned_emb)[:, :end_idx - start_idx]
            mels.append(mel)

        if not mels:
            raise ValueError("No frames to synthesize: the input text produced an empty alignment")

        res = np.concatenate(mels, axis=1)
        if self.verbose:
            print("MEL shape :{0}".format(res.shape))

        return res
0 commit comments