|
| 1 | +import os |
| 2 | +import random |
| 3 | +from copy import deepcopy |
| 4 | +import pandas as pd |
| 5 | +import logging |
| 6 | +from tqdm import tqdm |
| 7 | +import json |
| 8 | +import glob |
| 9 | +from resemblyzer import VoiceEncoder |
| 10 | +import traceback |
| 11 | +import numpy as np |
| 12 | +import pretty_midi |
| 13 | +import librosa |
| 14 | +from scipy.interpolate import interp1d |
| 15 | + |
| 16 | +from utils.hparams import hparams |
| 17 | +from data_gen.tts.data_gen_utils import build_phone_encoder |
| 18 | +from utils.pitch_utils import f0_to_coarse |
| 19 | +from data_gen.tts.base_binarizer import BaseBinarizer, BinarizationError |
| 20 | +from data_gen.tts.binarizer_zh import ZhBinarizer |
| 21 | +from vocoders.base_vocoder import VOCODERS |
| 22 | + |
| 23 | + |
| 24 | +def split_train_test_set(item_names): |
| 25 | + item_names = deepcopy(item_names) |
| 26 | + test_item_names = [x for x in item_names if any([ts in x for ts in hparams['test_prefixes']])] |
| 27 | + train_item_names = [x for x in item_names if x not in set(test_item_names)] |
| 28 | + logging.info("train {}".format(len(train_item_names))) |
| 29 | + logging.info("test {}".format(len(test_item_names))) |
| 30 | + return train_item_names, test_item_names |
| 31 | + |
| 32 | + |
| 33 | +class SingingBinarizer(BaseBinarizer): |
| 34 | + def __init__(self, processed_data_dir=None): |
| 35 | + if processed_data_dir is None: |
| 36 | + processed_data_dir = hparams['processed_data_dir'] |
| 37 | + self.processed_data_dirs = processed_data_dir.split(",") |
| 38 | + self.binarization_args = hparams['binarization_args'] |
| 39 | + self.pre_align_args = hparams['pre_align_args'] |
| 40 | + self.item2txt = {} |
| 41 | + self.item2ph = {} |
| 42 | + self.item2wavfn = {} |
| 43 | + self.item2f0fn = {} |
| 44 | + self.item2tgfn = {} |
| 45 | + self.item2spk = {} |
| 46 | + |
| 47 | + def load_meta_data(self): |
| 48 | + for ds_id, processed_data_dir in enumerate(self.processed_data_dirs): |
| 49 | + wav_suffix = '_wf0.wav' |
| 50 | + txt_suffix = '.txt' |
| 51 | + ph_suffix = '_ph.txt' |
| 52 | + tg_suffix = '.TextGrid' |
| 53 | + all_wav_pieces = glob.glob(f'{processed_data_dir}/*/*{wav_suffix}') |
| 54 | + |
| 55 | + for piece_path in all_wav_pieces: |
| 56 | + item_name = raw_item_name = piece_path[len(processed_data_dir)+1:].replace('/', '-')[:-len(wav_suffix)] |
| 57 | + if len(self.processed_data_dirs) > 1: |
| 58 | + item_name = f'ds{ds_id}_{item_name}' |
| 59 | + self.item2txt[item_name] = open(f'{piece_path.replace(wav_suffix, txt_suffix)}').readline() |
| 60 | + self.item2ph[item_name] = open(f'{piece_path.replace(wav_suffix, ph_suffix)}').readline() |
| 61 | + self.item2wavfn[item_name] = piece_path |
| 62 | + |
| 63 | + self.item2spk[item_name] = 'SPK1' |
| 64 | + if len(self.processed_data_dirs) > 1: |
| 65 | + self.item2spk[item_name] = f"ds{ds_id}_{self.item2spk[item_name]}" |
| 66 | + self.item2tgfn[item_name] = piece_path.replace(wav_suffix, tg_suffix) |
| 67 | + |
| 68 | + self.item_names = sorted(list(self.item2txt.keys())) |
| 69 | + if self.binarization_args['shuffle']: |
| 70 | + random.seed(1234) |
| 71 | + random.shuffle(self.item_names) |
| 72 | + self._train_item_names, self._test_item_names = split_train_test_set(self.item_names) |
| 73 | + |
| 74 | + @property |
| 75 | + def train_item_names(self): |
| 76 | + return self._train_item_names |
| 77 | + |
| 78 | + @property |
| 79 | + def valid_item_names(self): |
| 80 | + return self._test_item_names |
| 81 | + |
| 82 | + @property |
| 83 | + def test_item_names(self): |
| 84 | + return self._test_item_names |
| 85 | + |
| 86 | + def process(self): |
| 87 | + self.load_meta_data() |
| 88 | + os.makedirs(hparams['binary_data_dir'], exist_ok=True) |
| 89 | + self.spk_map = self.build_spk_map() |
| 90 | + print("| spk_map: ", self.spk_map) |
| 91 | + spk_map_fn = f"{hparams['binary_data_dir']}/spk_map.json" |
| 92 | + json.dump(self.spk_map, open(spk_map_fn, 'w')) |
| 93 | + |
| 94 | + self.phone_encoder = self._phone_encoder() |
| 95 | + self.process_data('valid') |
| 96 | + self.process_data('test') |
| 97 | + self.process_data('train') |
| 98 | + |
| 99 | + def _phone_encoder(self): |
| 100 | + ph_set_fn = f"{hparams['binary_data_dir']}/phone_set.json" |
| 101 | + ph_set = [] |
| 102 | + if hparams['reset_phone_dict'] or not os.path.exists(ph_set_fn): |
| 103 | + for ph_sent in self.item2ph.values(): |
| 104 | + ph_set += ph_sent.split(' ') |
| 105 | + ph_set = sorted(set(ph_set)) |
| 106 | + json.dump(ph_set, open(ph_set_fn, 'w')) |
| 107 | + print("| Build phone set: ", ph_set) |
| 108 | + else: |
| 109 | + ph_set = json.load(open(ph_set_fn, 'r')) |
| 110 | + print("| Load phone set: ", ph_set) |
| 111 | + return build_phone_encoder(hparams['binary_data_dir']) |
| 112 | + |
| 113 | + # @staticmethod |
| 114 | + # def get_pitch(wav_fn, spec, res): |
| 115 | + # wav_suffix = '_wf0.wav' |
| 116 | + # f0_suffix = '_f0.npy' |
| 117 | + # f0fn = wav_fn.replace(wav_suffix, f0_suffix) |
| 118 | + # pitch_info = np.load(f0fn) |
| 119 | + # f0 = [x[1] for x in pitch_info] |
| 120 | + # spec_x_coor = np.arange(0, 1, 1 / len(spec))[:len(spec)] |
| 121 | + # f0_x_coor = np.arange(0, 1, 1 / len(f0))[:len(f0)] |
| 122 | + # f0 = interp1d(f0_x_coor, f0, 'nearest', fill_value='extrapolate')(spec_x_coor)[:len(spec)] |
| 123 | + # # f0_x_coor = np.arange(0, 1, 1 / len(f0)) |
| 124 | + # # f0_x_coor[-1] = 1 |
| 125 | + # # f0 = interp1d(f0_x_coor, f0, 'nearest')(spec_x_coor)[:len(spec)] |
| 126 | + # if sum(f0) == 0: |
| 127 | + # raise BinarizationError("Empty f0") |
| 128 | + # assert len(f0) == len(spec), (len(f0), len(spec)) |
| 129 | + # pitch_coarse = f0_to_coarse(f0) |
| 130 | + # |
| 131 | + # # vis f0 |
| 132 | + # # import matplotlib.pyplot as plt |
| 133 | + # # from textgrid import TextGrid |
| 134 | + # # tg_fn = wav_fn.replace(wav_suffix, '.TextGrid') |
| 135 | + # # fig = plt.figure(figsize=(12, 6)) |
| 136 | + # # plt.pcolor(spec.T, vmin=-5, vmax=0) |
| 137 | + # # ax = plt.gca() |
| 138 | + # # ax2 = ax.twinx() |
| 139 | + # # ax2.plot(f0, color='red') |
| 140 | + # # ax2.set_ylim(0, 800) |
| 141 | + # # itvs = TextGrid.fromFile(tg_fn)[0] |
| 142 | + # # for itv in itvs: |
| 143 | + # # x = itv.maxTime * hparams['audio_sample_rate'] / hparams['hop_size'] |
| 144 | + # # plt.vlines(x=x, ymin=0, ymax=80, color='black') |
| 145 | + # # plt.text(x=x, y=20, s=itv.mark, color='black') |
| 146 | + # # plt.savefig('tmp/20211229_singing_plots_test.png') |
| 147 | + # |
| 148 | + # res['f0'] = f0 |
| 149 | + # res['pitch'] = pitch_coarse |
| 150 | + |
| 151 | + @classmethod |
| 152 | + def process_item(cls, item_name, ph, txt, tg_fn, wav_fn, spk_id, encoder, binarization_args): |
| 153 | + if hparams['vocoder'] in VOCODERS: |
| 154 | + wav, mel = VOCODERS[hparams['vocoder']].wav2spec(wav_fn) |
| 155 | + else: |
| 156 | + wav, mel = VOCODERS[hparams['vocoder'].split('.')[-1]].wav2spec(wav_fn) |
| 157 | + res = { |
| 158 | + 'item_name': item_name, 'txt': txt, 'ph': ph, 'mel': mel, 'wav': wav, 'wav_fn': wav_fn, |
| 159 | + 'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0], 'spk_id': spk_id |
| 160 | + } |
| 161 | + try: |
| 162 | + if binarization_args['with_f0']: |
| 163 | + # cls.get_pitch(wav_fn, mel, res) |
| 164 | + cls.get_pitch(wav, mel, res) |
| 165 | + if binarization_args['with_txt']: |
| 166 | + try: |
| 167 | + # print(ph) |
| 168 | + phone_encoded = res['phone'] = encoder.encode(ph) |
| 169 | + except: |
| 170 | + traceback.print_exc() |
| 171 | + raise BinarizationError(f"Empty phoneme") |
| 172 | + if binarization_args['with_align']: |
| 173 | + cls.get_align(tg_fn, ph, mel, phone_encoded, res) |
| 174 | + except BinarizationError as e: |
| 175 | + print(f"| Skip item ({e}). item_name: {item_name}, wav_fn: {wav_fn}") |
| 176 | + return None |
| 177 | + return res |
| 178 | + |
| 179 | + |
| 180 | +class MidiSingingBinarizer(SingingBinarizer): |
| 181 | + @staticmethod |
| 182 | + def get_pitch(wav_fn, spec, res): |
| 183 | + wav_suffix = '_wf0.wav' |
| 184 | + midi_suffix = '.mid' |
| 185 | + |
| 186 | + ## aux f0 |
| 187 | + # f0_suffix = '_f0.npy' |
| 188 | + # f0fn = wav_fn.replace(wav_suffix, f0_suffix) |
| 189 | + # pitch_info = np.load(f0fn) |
| 190 | + # f0 = [x[1] for x in pitch_info] |
| 191 | + # spec_x_coor = np.arange(0, 1, 1 / len(spec))[:len(spec)] |
| 192 | + # |
| 193 | + # f0_x_coor = np.arange(0, 1, 1 / len(f0))[:len(f0)] |
| 194 | + # f0 = interp1d(f0_x_coor, f0, 'nearest', fill_value='extrapolate')(spec_x_coor)[:len(spec)] |
| 195 | + |
| 196 | + ## read midi |
| 197 | + midi_fn = wav_fn.replace(wav_suffix, midi_suffix) |
| 198 | + pm = pretty_midi.PrettyMIDI(midi_fn) |
| 199 | + notes = np.zeros([len(spec)]) |
| 200 | + for n in pm.instruments[0].notes: |
| 201 | + sps = hparams['audio_sample_rate'] / hparams['hop_size'] |
| 202 | + notes[int(n.start * sps):int(n.end * sps)] = librosa.midi_to_hz(n.pitch) |
| 203 | + |
| 204 | + # spec_x_coor = np.arange(0, 1, 1 / len(spec))[:len(spec)] |
| 205 | + # note_x_coor = np.arange(0, 1, 1 / len(notes))[:len(notes)] |
| 206 | + # notes = interp1d(note_x_coor, notes, 'nearest', fill_value='extrapolate')(spec_x_coor)[:len(spec)] |
| 207 | + |
| 208 | + f0 = notes |
| 209 | + |
| 210 | + if sum(f0) == 0: |
| 211 | + raise BinarizationError("Empty f0") |
| 212 | + assert len(f0) == len(spec), (len(f0), len(spec)) |
| 213 | + pitch_coarse = f0_to_coarse(f0) |
| 214 | + |
| 215 | + # # vis f0 |
| 216 | + # import matplotlib.pyplot as plt |
| 217 | + # from textgrid import TextGrid |
| 218 | + # tg_fn = wav_fn.replace(wav_suffix, '.TextGrid') |
| 219 | + # fig = plt.figure(figsize=(12, 6)) |
| 220 | + # plt.pcolor(spec.T, vmin=-5, vmax=0) |
| 221 | + # ax = plt.gca() |
| 222 | + # ax2 = ax.twinx() |
| 223 | + # ax2.plot(f0, color='red') |
| 224 | + # ax2.plot(notes, color='white') |
| 225 | + # ax2.set_ylim(0, 800) |
| 226 | + # itvs = TextGrid.fromFile(tg_fn)[0] |
| 227 | + # for itv in itvs: |
| 228 | + # x = itv.maxTime * hparams['audio_sample_rate'] / hparams['hop_size'] |
| 229 | + # plt.vlines(x=x, ymin=0, ymax=80, color='black') |
| 230 | + # plt.text(x=x, y=20, s=itv.mark, color='black') |
| 231 | + # plt.savefig('tmp/1231_singing_plots_test.png') |
| 232 | + |
| 233 | + res['f0'] = f0 |
| 234 | + res['pitch'] = pitch_coarse |
| 235 | + |
| 236 | + @classmethod |
| 237 | + def process_item(cls, item_name, ph, txt, tg_fn, wav_fn, spk_id, encoder, binarization_args): |
| 238 | + if hparams['vocoder'] in VOCODERS: |
| 239 | + wav, mel = VOCODERS[hparams['vocoder']].wav2spec(wav_fn) |
| 240 | + else: |
| 241 | + wav, mel = VOCODERS[hparams['vocoder'].split('.')[-1]].wav2spec(wav_fn) |
| 242 | + res = { |
| 243 | + 'item_name': item_name, 'txt': txt, 'ph': ph, 'mel': mel, 'wav': wav, 'wav_fn': wav_fn, |
| 244 | + 'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0], 'spk_id': spk_id |
| 245 | + } |
| 246 | + try: |
| 247 | + if binarization_args['with_f0']: |
| 248 | + cls.get_pitch(wav_fn, mel, res) |
| 249 | + if binarization_args['with_txt']: |
| 250 | + try: |
| 251 | + # print(ph) |
| 252 | + phone_encoded = res['phone'] = encoder.encode(ph) |
| 253 | + except: |
| 254 | + traceback.print_exc() |
| 255 | + raise BinarizationError(f"Empty phoneme") |
| 256 | + if binarization_args['with_align']: |
| 257 | + cls.get_align(tg_fn, ph, mel, phone_encoded, res) |
| 258 | + except BinarizationError as e: |
| 259 | + print(f"| Skip item ({e}). item_name: {item_name}, wav_fn: {wav_fn}") |
| 260 | + return None |
| 261 | + return res |
| 262 | + |
| 263 | + |
| 264 | +class ZhSingingBinarizer(ZhBinarizer, SingingBinarizer): |
| 265 | + pass |
| 266 | + |
| 267 | + |
| 268 | +if __name__ == "__main__": |
| 269 | + SingingBinarizer().process() |
0 commit comments