|
| 1 | +from copy import deepcopy |
| 2 | + |
| 3 | +import librosa |
| 4 | +import numpy as np |
| 5 | +import torch |
| 6 | + |
| 7 | +from basics.base_augmentation import BaseAugmentation, require_same_keys |
| 8 | +from basics.base_pe import BasePE |
| 9 | +from modules.fastspeech.param_adaptor import VARIANCE_CHECKLIST |
| 10 | +from modules.fastspeech.tts_modules import LengthRegulator |
| 11 | +from utils.binarizer_utils import get_mel_torch, get_mel2ph_torch |
| 12 | +from utils.hparams import hparams |
| 13 | +from utils.infer_utils import resample_align_curve |
| 14 | + |
| 15 | + |
class SpectrogramStretchAugmentation(BaseAugmentation):
    """
    Frequency-domain (key shift) and time-domain (speed) stretching augmentation.

    The mel spectrogram is re-extracted from the source waveform with the
    requested key shift and/or speed factor; durations, f0 and frame-level
    variance curves of the augmented item are adjusted to match.
    """

    def __init__(self, data_dirs: list, augmentation_args: dict, pe: BasePE = None):
        """
        :param data_dirs: raw data directories (forwarded to BaseAugmentation).
        :param augmentation_args: augmentation configuration (forwarded to BaseAugmentation).
        :param pe: pitch extractor; required only when time-domain (speed)
            augmentation is actually performed.
        """
        super().__init__(data_dirs, augmentation_args)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Expands phoneme durations into the frame-level mel2ph alignment.
        self.lr = LengthRegulator().to(self.device)
        self.pe = pe

    @require_same_keys
    def process_item(self, item: dict, key_shift=0., speed=1., replace_spk_id=None) -> dict:
        """
        Build an augmented deep copy of *item*.

        :param item: a binarized data item; must contain 'wav_fn', 'seconds',
            'ph_dur' and optionally variance curves listed in VARIANCE_CHECKLIST.
            (Exact schema is defined by the binarizer — not fully visible here.)
        :param key_shift: pitch shift in semitones (frequency-domain stretch).
        :param speed: speed factor (time-domain stretch); 1.0 means unchanged.
        :param replace_spk_id: if not None, record this speaker id on the
            augmented item instead of storing the key shift value itself.
        :return: the augmented item.
        :raises ValueError: if speed augmentation is requested but no pitch
            extractor was supplied at construction time.
        """
        aug_item = deepcopy(item)
        waveform, _ = librosa.load(aug_item['wav_fn'], sr=hparams['audio_sample_rate'], mono=True)
        mel = get_mel_torch(
            waveform, hparams['audio_sample_rate'], num_mel_bins=hparams['audio_num_mel_bins'],
            hop_size=hparams['hop_size'], win_size=hparams['win_size'], fft_size=hparams['fft_size'],
            fmin=hparams['fmin'], fmax=hparams['fmax'],
            keyshift=key_shift, speed=speed, device=self.device
        )

        aug_item['mel'] = mel

        if speed != 1. or hparams['use_speed_embed']:
            if self.pe is None:
                # Fail early with an actionable message instead of an opaque
                # AttributeError from `self.pe.get_pitch(...)` below.
                raise ValueError(
                    'A pitch extractor (pe) must be provided to '
                    'SpectrogramStretchAugmentation for speed augmentation.'
                )
            aug_item['length'] = mel.shape[0]
            # Snap the requested speed to one representable with an integer
            # number of samples per hop, so frame timing stays exact.
            aug_item['speed'] = int(np.round(hparams['hop_size'] * speed)) / hparams['hop_size']  # real speed
            aug_item['seconds'] /= aug_item['speed']
            aug_item['ph_dur'] /= aug_item['speed']
            aug_item['mel2ph'] = get_mel2ph_torch(
                self.lr, torch.from_numpy(aug_item['ph_dur']), aug_item['length'], self.timestep, device=self.device
            ).cpu().numpy()

            f0, _ = self.pe.get_pitch(
                waveform, samplerate=hparams['audio_sample_rate'], length=aug_item['length'],
                hop_size=hparams['hop_size'], f0_min=hparams['f0_min'], f0_max=hparams['f0_max'],
                speed=speed, interp_uv=True
            )
            aug_item['f0'] = f0.astype(np.float32)

            # NOTE: variance curves are directly resampled according to speed,
            # despite how frequency-domain features change after the augmentation.
            # For acoustic models, this can bring more (but not much) difficulty
            # to learn how variance curves affect the mel spectrograms, since
            # they must realize how the augmentation causes the mismatch.
            #
            # This is a simple way to combine augmentation and variances. However,
            # dealing variance curves like this will decrease the accuracy of
            # variance controls. In most situations, not being ~100% accurate
            # will not ruin the user experience. For example, it does not matter
            # if the energy does not exactly equal the RMS; it is just fine
            # as long as higher energy can bring higher loudness and strength.
            # The neural networks itself cannot be 100% accurate, though.
            #
            # There are yet other choices to simulate variance curves:
            # 1. Re-extract the features from resampled waveforms;
            # 2. Re-extract the features from re-constructed waveforms using
            #    the transformed mel spectrograms through the vocoder.
            # But there are actually no perfect ways to make them all accurate
            # and stable.
            for v_name in VARIANCE_CHECKLIST:
                if v_name in item:
                    aug_item[v_name] = resample_align_curve(
                        aug_item[v_name],
                        original_timestep=self.timestep,
                        target_timestep=self.timestep * aug_item['speed'],
                        align_length=aug_item['length']
                    )

        if key_shift != 0. or hparams['use_key_shift_embed']:
            if replace_spk_id is None:
                aug_item['key_shift'] = key_shift
            else:
                # The shift is baked into a dedicated (virtual) speaker id
                # rather than stored as an explicit key_shift value.
                aug_item['spk_id'] = replace_spk_id
            # Shift f0 by the same amount as the spectrogram: one semitone
            # is a factor of 2^(1/12) in frequency.
            aug_item['f0'] *= 2 ** (key_shift / 12)

        return aug_item
0 commit comments