Skip to content

Commit 885082f

Browse files
committed
add pc-nsf training method
1 parent 6aab41f commit 885082f

File tree

6 files changed

+101
-116
lines changed

6 files changed

+101
-116
lines changed

configs/base_hifi.yaml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ pe: 'parselmouth' # 'parselmouth' or 'harvest'
1010
f0_min: 65
1111
f0_max: 1100
1212

13+
pc_aug: false # enable the pc-nsf training method (see also pc_aug_prob, pc_aug_key)
14+
pc_aug_prob: 0.5
15+
pc_aug_key: 5
16+
1317
aug_min: 0.9
1418
aug_max: 1.4
1519
aug_num: 1
@@ -34,15 +38,14 @@ valid_set_name: valid
3438
train_set_name: train
3539

3640

37-
volume_aug: True
41+
volume_aug: true
3842
volume_aug_prob: 0.5
3943

4044

4145
mel_vmin: -6. #-6.
4246
mel_vmax: 1.5
4347

4448

45-
mini_nsf: false
4649
audio_sample_rate: 44100
4750
audio_num_mel_bins: 128
4851
hop_size: 512 # Hop size.
@@ -62,6 +65,7 @@ crop_mel_frames: 20
6265

6366
#model_cls: training.nsf_HiFigan_task.nsf_HiFigan
6467
model_args:
68+
mini_nsf: false
6569
upsample_rates: [ 8, 8, 2, 2, 2 ]
6670
upsample_kernel_sizes: [ 16,16, 4, 4, 4 ]
6771
upsample_initial_channel: 512

configs/ft_hifigan.yaml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ pe: 'parselmouth' # 'parselmouth' or 'harvest'
1010
f0_min: 65
1111
f0_max: 1100
1212

13+
pc_aug: false # enable the pc-nsf training method (see also pc_aug_prob, pc_aug_key)
14+
pc_aug_prob: 0.5
15+
pc_aug_key: 5
16+
1317
aug_min: 0.9
1418
aug_max: 1.4
1519
aug_num: 1
@@ -42,7 +46,6 @@ mel_vmin: -6. #-6.
4246
mel_vmax: 1.5
4347

4448

45-
mini_nsf: false
4649
audio_sample_rate: 44100
4750
audio_num_mel_bins: 128
4851
hop_size: 512 # Hop size.
@@ -63,6 +66,7 @@ crop_mel_frames: 32
6366

6467
#model_cls: training.nsf_HiFigan_task.nsf_HiFigan
6568
model_args:
69+
mini_nsf: false
6670
upsample_rates: [ 8, 8, 2, 2, 2 ]
6771
upsample_kernel_sizes: [ 16,16, 4, 4, 4 ]
6872
upsample_initial_channel: 512

export_ckpt.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,7 @@ def export(exp_name, ckpt_path, save_path, work_dir):
4747
new_config['win_size'] = config['win_size']
4848
new_config['fmin'] = config['fmin']
4949
new_config['fmax'] = config['fmax']
50-
if 'mini_nsf' in config.keys():
51-
new_config['mini_nsf'] = config['mini_nsf']
52-
else:
50+
if 'mini_nsf' not in new_config.keys():
5351
new_config['mini_nsf'] = False
5452
json_file.write(json.dumps(new_config, indent=1))
5553
print("Export configuration file successfully: ", new_config_file)

modules/loss/HiFiloss.py

Lines changed: 22 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -29,26 +29,21 @@ def discriminator_loss(self, disc_real_outputs, disc_generated_outputs):
2929
loss = 0
3030
rlosses = 0
3131
glosses = 0
32-
r_losses = []
33-
g_losses = []
3432

3533
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
3634
r_loss = torch.mean((1 - dr) ** 2)
3735
g_loss = torch.mean(dg ** 2)
3836
loss += r_loss + g_loss
3937
rlosses += r_loss.item()
4038
glosses += g_loss.item()
41-
r_losses.append(r_loss.item())
42-
g_losses.append(g_loss.item())
4339

44-
return loss, rlosses, glosses, r_losses, g_losses
40+
return loss, rlosses, glosses
4541

4642
def Dloss(self, Dfake, Dtrue):
47-
4843
(Fmsd_out, _), (Fmpd_out, _) = Dfake
4944
(Tmsd_out, _), (Tmpd_out, _) = Dtrue
50-
msdloss, msdrlosses, msdglosses, _, _ = self.discriminator_loss(Tmsd_out, Fmsd_out)
51-
mpdloss, mpdrlosses, mpdglosses, _, _ = self.discriminator_loss(Tmpd_out, Fmpd_out)
45+
msdloss, msdrlosses, msdglosses = self.discriminator_loss(Tmsd_out, Fmsd_out)
46+
mpdloss, mpdrlosses, mpdglosses = self.discriminator_loss(Tmpd_out, Fmpd_out)
5247
loss = msdloss + mpdloss
5348
return loss, {'DmsdlossF': msdglosses, 'DmsdlossT': msdrlosses, 'DmpdlossT': mpdrlosses,
5449
'DmpdlossF': mpdglosses}
@@ -57,55 +52,42 @@ def feature_loss(self, fmap_r, fmap_g):
5752
loss = 0
5853
for dr, dg in zip(fmap_r, fmap_g):
5954
for rl, gl in zip(dr, dg):
60-
loss += torch.mean(torch.abs(rl - gl))
61-
55+
b = min(rl.shape[0], gl.shape[0])
56+
loss += torch.mean(torch.abs(rl[: b] - gl[: b]))
6257
return loss * 2
6358

6459
def GDloss(self, GDfake, GDtrue):
6560
loss = 0
66-
gen_losses = []
6761
msd_losses = 0
6862
mpd_losses = 0
63+
6964
(msd_out, Fmsd_feature), (mpd_out, Fmpd_feature) = GDfake
7065
(_, Tmsd_feature), (_, Tmpd_feature) = GDtrue
66+
7167
for dg in msd_out:
72-
l = torch.mean((1 - dg) ** 2)
73-
gen_losses.append(l.item())
74-
# loss += l
75-
msd_losses = l + msd_losses
76-
68+
msd_losses += torch.mean((1 - dg) ** 2)
7769
for dg in mpd_out:
78-
l = torch.mean((1 - dg) ** 2)
79-
gen_losses.append(l.item())
80-
# loss += l
81-
mpd_losses = l + mpd_losses
82-
70+
mpd_losses += torch.mean((1 - dg) ** 2)
71+
8372
msd_feature_loss = self.feature_loss(Tmsd_feature, Fmsd_feature)
8473
mpd_feature_loss = self.feature_loss(Tmpd_feature, Fmpd_feature)
85-
# loss +=msd_feature_loss
86-
# loss +=mpd_feature_loss
74+
8775
loss = msd_feature_loss + mpd_feature_loss + mpd_losses + msd_losses
88-
# (msd_losses, mpd_losses), (msd_feature_loss, mpd_feature_loss), gen_losses
76+
8977
return loss, {'Gmsdloss': msd_losses, 'Gmpdloss': mpd_losses, 'Gmsd_feature_loss': msd_feature_loss,
9078
'Gmpd_feature_loss': mpd_feature_loss}
9179

92-
def Auxloss(self, Goutput, sample):
93-
Gmel = self.mel.dynamic_range_compression_torch(self.mel(Goutput['audio'].squeeze(1)))
94-
Rmel = self.mel.dynamic_range_compression_torch(self.mel(sample['audio'].squeeze(1)))
80+
def Auxloss(self, Goutput, sample):
81+
Gwav = Goutput['audio'].squeeze(1)
82+
Rwav = sample['audio'].squeeze(1)
83+
b = min(Gwav.shape[0], Rwav.shape[0])
84+
Gmel = self.mel.dynamic_range_compression_torch(self.mel(Gwav[: b]))
85+
Rmel = self.mel.dynamic_range_compression_torch(self.mel(Rwav[: b]))
9586
mel_loss = self.L1loss(Gmel, Rmel) * self.lab_aux_mel_loss
9687
if self.use_stftloss:
97-
sc_loss, mag_loss = self.stft.stft(Goutput['audio'].squeeze(1), sample['audio'].squeeze(1))
88+
sc_loss, mag_loss = self.stft.stft(Gwav[: b], Rwav[: b])
9889
stft_loss = (sc_loss + mag_loss) * self.lab_aux_stft_loss
9990
loss = mel_loss + stft_loss
100-
return loss, {'auxloss': loss, 'auxloss_mel': mel_loss, 'auxloss_stft': stft_loss}
101-
return mel_loss, {'auxloss': mel_loss}
102-
103-
# def Auxloss(self,Goutput, sample):
104-
#
105-
# Gmel=self.mel.dynamic_range_compression_torch(self.mel(Goutput['audio'].squeeze(1)))
106-
# # Rmel=sample['mel']
107-
# Rmel = self.mel.dynamic_range_compression_torch(self.mel(sample['audio'].squeeze(1)))
108-
# sc_loss, mag_loss=self.stft.stft(Goutput['audio'].squeeze(1), sample['audio'].squeeze(1))
109-
# loss=(sc_loss+ mag_loss)*self.labauxloss
110-
# return loss,{'auxloss':loss,'auxloss_sc_loss':sc_loss,'auxloss_mag_loss':mag_loss}
111-
#
91+
return loss, {'aux_mel_loss': mel_loss, 'aux_stft_loss': stft_loss}
92+
return mel_loss, {'aux_mel_loss': mel_loss}
93+

modules/loss/stft_loss.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,9 @@ def stft(x, fft_size, hop_size, win_length, window):
2424
2525
"""
2626
x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True)
27-
real = x_stft.real
28-
imag = x_stft.imag
2927

3028
# NOTE(kan-bayashi): clamp is needed to avoid nan or inf
31-
return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1)
29+
return torch.clamp(x_stft.abs(), min=10**(-3.5)).transpose(2, 1)
3230

3331

3432
class SpectralConvergenceLoss(torch.nn.Module):
@@ -108,12 +106,10 @@ def forward(self, x, y):
108106

109107

110108
class warp_stft:
111-
def __init__(self,cfg={},divce='cuda'):
112-
self.stft=MultiResolutionSTFTLoss(**cfg).to(divce)
109+
def __init__(self, cfg={}, device='cuda'):
110+
self.stft = MultiResolutionSTFTLoss(**cfg).to(device)
113111

114-
115-
116-
def loss(self,x, y):
112+
def loss(self, x, y):
117113
return self.stft(x, y)
118114

119115

0 commit comments

Comments
 (0)