# -*- coding: utf-8 -*-
# Copyright 2020 TensorFlowTTS Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataset modules."""

import os
import numpy as np
import tensorflow as tf

from tensorflow_tts.datasets.abstract_dataset import AbstractDataset
from tensorflow_tts.utils import find_files, remove_outlier


def average_by_duration(x, durs):
    """Average frame-level f0/energy over each charactor's duration segment."""
    mel_len = durs.sum()
    durs_cum = np.cumsum(np.pad(durs, (1, 0)))

    # compute per-charactor f0/energy by averaging the non-zero (voiced)
    # frames inside each charactor's duration segment.
    x_char = np.zeros((durs.shape[0],), dtype=np.float32)
    for idx, start, end in zip(range(mel_len), durs_cum[:-1], durs_cum[1:]):
        values = x[start:end][np.where(x[start:end] != 0.0)[0]]
        x_char[idx] = np.mean(values) if len(values) > 0 else 0.0  # np.mean([]) = nan.

    return x_char.astype(np.float32)

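# A minimal sketch of what average_by_duration computes, using made-up values
# (the toy f0 contour and durations below are illustrative, not from any dump):
#
#     f0 = np.array([0.0, 110.0, 115.0, 0.0, 220.0], dtype=np.float32)
#     durs = np.array([2, 3], dtype=np.int32)
#     average_by_duration(f0, durs)
#     # -> array([110.0, 167.5], dtype=float32); zero (unvoiced) frames are
#     #    excluded from each charactor's mean.
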
# tf wrapper around the NumPy implementation above; tf.numpy_function executes
# it eagerly, so the returned tensor has no static shape.
@tf.function(
    input_signature=[tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.int32)]
)
def tf_average_by_duration(x, durs):
    outs = tf.numpy_function(average_by_duration, [x, durs], tf.float32)
    return outs


class CharactorDurationF0EnergyMelDataset(AbstractDataset):
    """Tensorflow Charactor Duration F0 Energy Mel dataset."""

    def __init__(
        self,
        root_dir,
        charactor_query="*-ids.npy",
        mel_query="*-norm-feats.npy",
        duration_query="*-durations.npy",
        f0_query="*-raw-f0.npy",
        energy_query="*-raw-energy.npy",
        f0_stat="./dump/stats_f0.npy",
        energy_stat="./dump/stats_energy.npy",
        charactor_load_fn=np.load,
        mel_load_fn=np.load,
        duration_load_fn=np.load,
        f0_load_fn=np.load,
        energy_load_fn=np.load,
        mel_length_threshold=0,
    ):
| 66 | + """Initialize dataset. |
| 67 | +
|
| 68 | + Args: |
| 69 | + root_dir (str): Root directory including dumped files. |
| 70 | + charactor_query (str): Query to find charactor files in root_dir. |
| 71 | + mel_query (str): Query to find feature files in root_dir. |
| 72 | + duration_query (str): Query to find duration files in root_dir. |
| 73 | + f0_query (str): Query to find f0 files in root_dir. |
| 74 | + energy_query (str): Query to find energy files in root_dir. |
| 75 | + f0_stat (str): str path of f0_stat. |
| 76 | + energy_stat (str): str path of energy_stat. |
| 77 | + charactor_load_fn (func): Function to load charactor file. |
| 78 | + mel_load_fn (func): Function to load feature file. |
| 79 | + duration_load_fn (func): Function to load duration file. |
| 80 | + f0_load_fn (func): Function to load f0 file. |
| 81 | + energy_load_fn (func): Function to load energy file. |
| 82 | + mel_length_threshold (int): Threshold to remove short feature files. |
| 83 | +
|
| 84 | + """ |
        # find all charactor, mel, duration, f0 and energy files.
        charactor_files = sorted(find_files(root_dir, charactor_query))
        mel_files = sorted(find_files(root_dir, mel_query))
        duration_files = sorted(find_files(root_dir, duration_query))
        f0_files = sorted(find_files(root_dir, f0_query))
        energy_files = sorted(find_files(root_dir, energy_query))

        # assert the number of files
        assert len(mel_files) != 0, f"No mel files found in {root_dir}."
        assert (
            len(mel_files)
            == len(charactor_files)
            == len(duration_files)
            == len(f0_files)
            == len(energy_files)
        ), "Number of charactor, mel, duration, f0 and energy files does not match."

        # extract utterance ids from file names (queries are assumed to end with .npy).
        if ".npy" in charactor_query:
            suffix = charactor_query[1:]
            utt_ids = [os.path.basename(f).replace(suffix, "") for f in charactor_files]

        # set global params
        self.utt_ids = utt_ids
        self.mel_files = mel_files
        self.charactor_files = charactor_files
        self.duration_files = duration_files
        self.f0_files = f0_files
        self.energy_files = energy_files
        self.mel_load_fn = mel_load_fn
        self.charactor_load_fn = charactor_load_fn
        self.duration_load_fn = duration_load_fn
        self.f0_load_fn = f0_load_fn
        self.energy_load_fn = energy_load_fn
        self.mel_length_threshold = mel_length_threshold

        # map speaker names to integer ids.
        # TODO: change this; at the moment the utt_id prefix (the MFA folder name) is used as the speaker name.
        self.speakers_map = {}
        sp_id = 0
        for i in self.utt_ids:
            sp_name = i.split("_")[0]
            if sp_name not in self.speakers_map:
                self.speakers_map[sp_name] = sp_id
                sp_id += 1
        self.speakers = [self.speakers_map[i.split("_")[0]] for i in self.utt_ids]
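        # A small illustration with hypothetical utt_ids (names are made up):
        # ["ljspeech_0001", "ljspeech_0002", "mary_0001"] would give
        # speakers_map == {"ljspeech": 0, "mary": 1} and speakers == [0, 0, 1].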

        self.f0_stat = np.load(f0_stat)
        self.energy_stat = np.load(energy_stat)

    def get_args(self):
        return [self.utt_ids]

    def _norm_mean_std(self, x, mean, std):
        """Mean/std-normalize x while keeping zero (unvoiced) frames at zero."""
        x = remove_outlier(x)
        zero_idxs = np.where(x == 0.0)[0]
        x = (x - mean) / std
        x[zero_idxs] = 0.0
        return x

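    # A toy illustration of _norm_mean_std with made-up numbers (and ignoring
    # the remove_outlier step): for x = [0.0, 100.0, 120.0], mean = 110.0,
    # std = 10.0, plain z-normalization gives [-11.0, -1.0, 1.0], but the
    # unvoiced frame at index 0 is reset, so the result is [0.0, -1.0, 1.0].
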
    def generator(self, utt_ids):
        for i, utt_id in enumerate(utt_ids):
            mel_file = self.mel_files[i]
            charactor_file = self.charactor_files[i]
            duration_file = self.duration_files[i]
            f0_file = self.f0_files[i]
            energy_file = self.energy_files[i]
            mel = self.mel_load_fn(mel_file)
            charactor = self.charactor_load_fn(charactor_file)
            duration = self.duration_load_fn(duration_file)
            f0 = self.f0_load_fn(f0_file)
            energy = self.energy_load_fn(energy_file)

            f0 = self._norm_mean_std(f0, self.f0_stat[0], self.f0_stat[1])
            energy = self._norm_mean_std(
                energy, self.energy_stat[0], self.energy_stat[1]
            )

            # calculate charactor f0/energy
            f0 = tf_average_by_duration(f0, duration)
            energy = tf_average_by_duration(energy, duration)
            speaker_id = self.speakers[i]
            items = {
                "utt_ids": utt_id,
                "input_ids": charactor,
                "speaker_ids": speaker_id,
                "duration_gts": duration,
                "f0_gts": f0,
                "energy_gts": energy,
                "mel_gts": mel,
                "mel_lengths": len(mel),
            }

            yield items

    def create(
        self,
        allow_cache=False,
        batch_size=1,
        is_shuffle=False,
        map_fn=None,
        reshuffle_each_iteration=True,
    ):
        """Create tf.dataset function."""
        output_types = self.get_output_dtypes()
        datasets = tf.data.Dataset.from_generator(
            self.generator, output_types=output_types, args=(self.get_args())
        )

        datasets = datasets.filter(
            lambda x: x["mel_lengths"] > self.mel_length_threshold
        )

        if allow_cache:
            datasets = datasets.cache()

        if is_shuffle:
            datasets = datasets.shuffle(
                self.get_len_dataset(),
                reshuffle_each_iteration=reshuffle_each_iteration,
            )

        # define padded shapes
        padded_shapes = {
            "utt_ids": [],
            "input_ids": [None],
            "speaker_ids": [],
            "duration_gts": [None],
            "f0_gts": [None],
            "energy_gts": [None],
            "mel_gts": [None, None],
            "mel_lengths": [],
        }

        datasets = datasets.padded_batch(batch_size, padded_shapes=padded_shapes)
        datasets = datasets.prefetch(tf.data.experimental.AUTOTUNE)
        return datasets

    def get_output_dtypes(self):
        output_types = {
            "utt_ids": tf.string,
            "input_ids": tf.int32,
            "speaker_ids": tf.int32,
            "duration_gts": tf.int32,
            "f0_gts": tf.float32,
            "energy_gts": tf.float32,
            "mel_gts": tf.float32,
            "mel_lengths": tf.int32,
        }
        return output_types

    def get_len_dataset(self):
        return len(self.utt_ids)

    def __name__(self):
        return "CharactorDurationF0EnergyMelDataset"
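
# A minimal usage sketch, assuming a ./dump/train/ directory and stats files
# produced by the TensorFlowTTS preprocessing scripts (the paths below are
# assumptions, not part of this module):
#
#     dataset = CharactorDurationF0EnergyMelDataset(
#         root_dir="./dump/train/",
#         f0_stat="./dump/stats_f0.npy",
#         energy_stat="./dump/stats_energy.npy",
#     ).create(batch_size=16, is_shuffle=True)
#     for batch in dataset.take(1):
#         print(batch["input_ids"].shape, batch["mel_gts"].shape)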