Skip to content

Commit a0bacd1

Browse files
committed
optimize online data augmentation
1 parent 063a99e commit a0bacd1

File tree

1 file changed

+20
-8
lines changed

1 file changed

+20
-8
lines changed

training/nsf_HiFigan_task.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,23 @@ def __getitem__(self, index):
8383
if random.random() < self.key_aug_prob:
8484
audio = torch.from_numpy(data['audio'])
8585
speed = random.uniform(self.config['aug_min'], self.config['aug_max'])
86-
audiox = wav_aug(audio, self.config["hop_size"], speed=speed)
87-
mel = dynamic_range_compression_torch(self.mel_spec_transform(audiox[None,:]))
88-
f0, uv = get_pitch(audio.numpy(), hparams=self.config, speed=speed, interp_uv=True, length=len(mel[0].T))
86+
crop_mel_frames = int(np.ceil((self.config['crop_mel_frames'] + 4) * speed))
87+
samples_per_frame = self.config['hop_size']
88+
crop_wav_samples = crop_mel_frames * samples_per_frame
89+
if crop_wav_samples < audio.shape[0]:
90+
return {'f0': data['f0'], 'spectrogram': data['mel'], 'audio': data['audio']}
91+
start = random.randint(0, audio.shape[0] - 1 - crop_wav_samples)
92+
end = start + crop_wav_samples
93+
audio = audio[start:end]
94+
f0, uv = get_pitch(audio.numpy(), hparams=self.config, speed=speed, interp_uv=True, length=mel.shape[-1])
8995
if f0 is None:
9096
return {'f0': data['f0'], 'spectrogram': data['mel'], 'audio': data['audio']}
91-
f0 *= speed
92-
return {'f0': f0, 'spectrogram': mel[0].T.numpy(), 'audio': audiox.numpy()}
97+
audio_aug = wav_aug(audio, self.config["hop_size"], speed=speed)
98+
mel_aug = dynamic_range_compression_torch(self.mel_spec_transform(audio_aug[None,:]))
99+
audio_aug = audio_aug[2*samples_per_frame: -2*samples_per_frame].numpy()
100+
mel_aug = mel_aug[0, :, 2:-2].T.numpy()
101+
f0_aug = f0[2:-2] * speed
102+
return {'f0': f0_aug, 'spectrogram': mel_aug, 'audio': audio_aug}
93103

94104
else:
95105
return {'f0': data['f0'], 'spectrogram': data['mel'], 'audio': data['audio']}
@@ -107,13 +117,15 @@ def collater(self, minibatch):
107117
for record in minibatch:
108118

109119
# Filter out records that aren't long enough.
110-
if len(record['spectrogram']) <= crop_mel_frames:
120+
if record['spectrogram'].shape[0] < crop_mel_frames:
111121
del record['spectrogram']
112122
del record['audio']
113123
del record['f0']
114124
continue
115-
116-
start = random.randint(0, record['spectrogram'].shape[0] - 1 - crop_mel_frames)
125+
elif record['spectrogram'].shape[0] == crop_mel_frames:
126+
start = 0
127+
else:
128+
start = random.randint(0, record['spectrogram'].shape[0] - 1 - crop_mel_frames)
117129
end = start + crop_mel_frames
118130
if self.infer:
119131
record['spectrogram'] = record['spectrogram'].T

0 commit comments

Comments
 (0)