Commit a2b1a55

add a mini-nsf-hifigan model
1 parent 1de8315 commit a2b1a55

File tree: 9 files changed, +73 -928 lines changed

configs/base_hifi.yaml

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@ mel_vmin: -6. #-6.
 mel_vmax: 1.5


+mini_nsf: false
 audio_sample_rate: 44100
 audio_num_mel_bins: 128
 hop_size: 512 # Hop size.

configs/base_hifi_chroma.yaml

Lines changed: 0 additions & 123 deletions
This file was deleted.

configs/ft_hifigan.yaml

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@ mel_vmin: -6. #-6.
 mel_vmax: 1.5


+mini_nsf: false
 audio_sample_rate: 44100
 audio_num_mel_bins: 128
 hop_size: 512 # Hop size.

export_ckpt.py

Lines changed: 4 additions & 0 deletions
@@ -47,6 +47,10 @@ def export(exp_name, ckpt_path, save_path, work_dir):
     new_config['win_size'] = config['win_size']
     new_config['fmin'] = config['fmin']
     new_config['fmax'] = config['fmax']
+    if 'mini_nsf' in config.keys():
+        new_config['mini_nsf'] = config['mini_nsf']
+    else:
+        new_config['mini_nsf'] = False
     json_file.write(json.dumps(new_config, indent=1))
     print("Export configuration file successfully: ", new_config_file)

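
The added branch keeps exported configurations backward compatible: a training config saved before this commit has no mini_nsf key, so the exporter writes False. As a minimal sketch only (not the commit's code), the same default could be expressed with dict.get:

    # Sketch: equivalent backward-compatible default using dict.get();
    # `config` is the loaded training config, `new_config` the exported dict.
    new_config['mini_nsf'] = config.get('mini_nsf', False)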

models/nsf_HiFigan/models.py

Lines changed: 55 additions & 26 deletions
@@ -18,11 +18,13 @@ def init_weights(m, mean=0.0, std=0.01):
     classname = m.__class__.__name__
     if classname.find("Conv") != -1:
         m.weight.data.normal_(mean, std)
+        m.bias.data.normal_(mean, std)


 def get_padding(kernel_size, dilation=1):
     return int((kernel_size * dilation - dilation) / 2)

+
 class ResBlock1(torch.nn.Module):
     def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
         super(ResBlock1, self).__init__()
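
For context (a sketch, not part of the diff): init_weights is applied to whole modules via Module.apply, e.g. self.ups.apply(init_weights) and the new self.source_conv.apply(init_weights) below, so with the added line both the weight and the bias of every Conv* layer are drawn from N(mean, std^2). A self-contained illustration, assuming only standard torch imports:

    # Sketch: behavior of init_weights including the newly added bias line.
    import torch
    from torch.nn import Conv1d

    def init_weights(m, mean=0.0, std=0.01):
        classname = m.__class__.__name__
        if classname.find("Conv") != -1:
            m.weight.data.normal_(mean, std)
            m.bias.data.normal_(mean, std)  # added in this commit

    conv = Conv1d(1, 8, kernel_size=1)
    conv.apply(init_weights)  # Module.apply visits the module and its children
    print(conv.weight.std(), conv.bias.std())  # both near 0.01, up to sampling noise
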
@@ -199,46 +201,74 @@ def __init__(self, h):
         self.h = h
         self.num_kernels = len(h.resblock_kernel_sizes)
         self.num_upsamples = len(h.upsample_rates)
-        self.m_source = SourceModuleHnNSF(
-            sampling_rate=h.sampling_rate,
-            harmonic_num=8
-        )
-        self.noise_convs = nn.ModuleList()
-        self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
-        resblock = ResBlock1 if h.resblock == '1' else ResBlock2
-
+        self.mini_nsf = h.mini_nsf
+
+        if h.mini_nsf:
+            self.source_sr = h.sampling_rate / int(np.prod(h.upsample_rates[2: ]))
+            self.upp = int(np.prod(h.upsample_rates[: 2]))
+        else:
+            self.source_sr = h.sampling_rate
+            self.upp = int(np.prod(h.upsample_rates))
+            self.m_source = SourceModuleHnNSF(
+                sampling_rate=h.sampling_rate,
+                harmonic_num=8
+            )
+            self.noise_convs = nn.ModuleList()
+
+        self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
+
         self.ups = nn.ModuleList()
-        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
-            c_cur = h.upsample_initial_channel // (2 ** (i + 1))
-            self.ups.append(weight_norm(
-                ConvTranspose1d(h.upsample_initial_channel // (2 ** i), h.upsample_initial_channel // (2 ** (i + 1)),
-                                k, u, padding=(k - u) // 2)))
-            if i + 1 < len(h.upsample_rates): #
-                stride_f0 = int(np.prod(h.upsample_rates[i + 1:]))
-                self.noise_convs.append(Conv1d(
-                    1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2))
-            else:
-                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
         self.resblocks = nn.ModuleList()
+        resblock = ResBlock1 if h.resblock == '1' else ResBlock2
         ch = h.upsample_initial_channel
-        for i in range(len(self.ups)):
+        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
             ch //= 2
+            self.ups.append(weight_norm(ConvTranspose1d(2 * ch, ch, k, u, padding=(k - u) // 2)))
             for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
                 self.resblocks.append(resblock(h, ch, k, d))
+            if not h.mini_nsf:
+                if i + 1 < len(h.upsample_rates): #
+                    stride_f0 = int(np.prod(h.upsample_rates[i + 1:]))
+                    self.noise_convs.append(Conv1d(
+                        1, ch, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2))
+                else:
+                    self.noise_convs.append(Conv1d(1, ch, kernel_size=1))
+            elif i == 1:
+                self.source_conv = Conv1d(1, ch, 1)
+                self.source_conv.apply(init_weights)

         self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+
         self.ups.apply(init_weights)
         self.conv_post.apply(init_weights)
-        self.upp = int(np.prod(h.upsample_rates))
-
+
+    def fastsinegen(self, f0):
+        n = torch.arange(1, self.upp + 1, device=f0.device)
+        s0 = f0.unsqueeze(-1) / self.source_sr
+        ds0 = F.pad(s0[:, 1:, :] - s0[:, :-1, :], (0, 0, 0, 1))
+        rad = s0 * n + 0.5 * ds0 * n * (n - 1) / self.upp
+        rad2 = torch.fmod(rad[..., -1:].float() + 0.5, 1.0) - 0.5
+        rad_acc = rad2.cumsum(dim=1).fmod(1.0).to(f0)
+        rad += F.pad(rad_acc, (0, 0, 1, -1))
+        rad = rad.reshape(f0.shape[0], 1, -1)
+        sines = torch.sin(2 * np.pi * rad)
+        return sines
+
     def forward(self, x, f0):
-        har_source = self.m_source(f0, self.upp).transpose(1, 2)
+        if self.mini_nsf:
+            har_source = self.fastsinegen(f0)
+        else:
+            har_source = self.m_source(f0, self.upp).transpose(1, 2)
         x = self.conv_pre(x)
         for i in range(self.num_upsamples):
             x = F.leaky_relu(x, LRELU_SLOPE)
             x = self.ups[i](x)
-            x_source = self.noise_convs[i](har_source)
-            x = x + x_source
+            if not self.mini_nsf:
+                x_source = self.noise_convs[i](har_source)
+                x = x + x_source
+            elif i == 1:
+                x_source = self.source_conv(har_source)
+                x = x + x_source
             xs = None
             for j in range(self.num_kernels):
                 if xs is None:
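
To make the new bookkeeping concrete: with mini_nsf enabled, the sine excitation is generated at a reduced rate (the audio rate divided by the product of the remaining upsample factors) and spans only the first two upsampling stages; it is mixed in once through source_conv (the elif i == 1 branch) instead of at every scale through noise_convs, and the harmonic SourceModuleHnNSF is bypassed entirely. A worked sketch of the arithmetic, where the upsample_rates value is an assumption chosen for illustration, not taken from this commit:

    # Sketch: the source_sr / upp arithmetic from __init__ above.
    import numpy as np

    sampling_rate = 44100
    upsample_rates = [8, 8, 2, 2, 2]   # assumed example; product = 512 = hop_size

    # mini_nsf path: short excitation at a reduced sample rate
    upp_mini = int(np.prod(upsample_rates[: 2]))                        # 64 samples per frame
    source_sr_mini = sampling_rate / int(np.prod(upsample_rates[2: ]))  # 5512.5 Hz

    # original NSF path: excitation at the full audio rate
    upp_full = int(np.prod(upsample_rates))                             # 512 samples per frame
    source_sr_full = sampling_rate                                      # 44100 Hz

    print(upp_mini, source_sr_mini, upp_full, source_sr_full)

Under this assumed configuration the excitation tensor is eight times shorter than in the original NSF path, which is where the "mini" saving comes from.
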
@@ -249,7 +279,6 @@ def forward(self, x, f0):
         x = F.leaky_relu(x)
         x = self.conv_post(x)
         x = torch.tanh(x)
-
         return x

     def remove_weight_norm(self):
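
On the mini_nsf path, fastsinegen replaces SourceModuleHnNSF: it turns frame-level f0 into a sine by accumulating phase, ramping the per-frame frequency linearly toward the next frame and carrying the wrapped frame-end phase into the following frame (the fmod keeps the accumulator small for float precision). The following is a simplified reference of the same idea, for orientation only; shapes and the interpolation scheme are read off the diff, and this is not the repository's code:

    # Sketch: a straightforward phase-accumulation sine source.
    # f0: (batch, frames) in Hz -> sine of shape (batch, 1, frames * upp).
    import torch

    def naive_sine_source(f0, source_sr, upp):
        b, frames = f0.shape
        # hold the last frame's f0; otherwise interpolate toward the next frame
        f0_next = torch.cat([f0[:, 1:], f0[:, -1:]], dim=1)
        t = torch.arange(upp, device=f0.device) / upp                   # (upp,)
        f_inst = f0.unsqueeze(-1) + (f0_next - f0).unsqueeze(-1) * t    # (b, frames, upp)
        phase = torch.cumsum(f_inst.reshape(b, -1) / source_sr, dim=1)  # phase in cycles
        return torch.sin(2 * torch.pi * phase).unsqueeze(1)

    sines = naive_sine_source(torch.full((1, 10), 220.0), source_sr=5512.5, upp=64)
    print(sines.shape)  # torch.Size([1, 1, 640])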

models/nsf_HiFigan_chroma/__init__.py

Whitespace-only changes.
