Skip to content

Commit 1de8315

Browse files
committed
faster nsf
1 parent 009dfd4 commit 1de8315

File tree

1 file changed

+13
-35
lines changed

1 file changed

+13
-35
lines changed

models/nsf_HiFigan/models.py

Lines changed: 13 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -120,41 +120,20 @@ def _f02uv(self, f0):
120120
uv = uv * (f0 > self.voiced_threshold)
121121
return uv
122122

123-
def _f02sine(self, f0_values, upp):
124-
""" f0_values: (batchsize, length, dim)
123+
def _f02sine(self, f0, upp):
124+
""" f0: (batchsize, length, dim)
125125
where dim indicates fundamental tone and overtones
126126
"""
127-
rad_values = (f0_values / self.sampling_rate).fmod(1.) # %1意味着n_har的乘积无法后处理优化
128-
rand_ini = torch.rand(1, self.dim, device=f0_values.device)
129-
rand_ini[:, 0] = 0
130-
rad_values[:, 0, :] += rand_ini
131-
is_half = rad_values.dtype is not torch.float32
132-
tmp_over_one = torch.cumsum(rad_values.double(), 1) # % 1 #####%1意味着后面的cumsum无法再优化
133-
if is_half:
134-
tmp_over_one = tmp_over_one.half()
135-
else:
136-
tmp_over_one = tmp_over_one.float()
137-
tmp_over_one *= upp
138-
tmp_over_one = F.interpolate(
139-
tmp_over_one.transpose(2, 1), scale_factor=upp,
140-
mode='linear', align_corners=True
141-
).transpose(2, 1)
142-
rad_values = F.interpolate(rad_values.transpose(2, 1), scale_factor=upp, mode='nearest').transpose(2, 1)
143-
tmp_over_one = tmp_over_one.fmod(1.)
144-
diff = F.conv2d(
145-
tmp_over_one.unsqueeze(1), torch.FloatTensor([[[[-1.], [1.]]]]).to(tmp_over_one.device),
146-
stride=(1, 1), padding=0, dilation=(1, 1)
147-
).squeeze(1) # Equivalent to torch.diff, but able to export ONNX
148-
cumsum_shift = (diff < 0).double()
149-
cumsum_shift = torch.cat((
150-
torch.zeros((f0_values.size()[0], 1, self.dim), dtype=torch.double).to(f0_values.device),
151-
cumsum_shift
152-
), dim=1)
153-
sines = torch.sin(torch.cumsum(rad_values.double() + cumsum_shift, dim=1) * 2 * np.pi)
154-
if is_half:
155-
sines = sines.half()
156-
else:
157-
sines = sines.float()
127+
rad = f0 / self.sampling_rate * torch.arange(1, upp + 1, device=f0.device)
128+
rad2 = torch.fmod(rad[..., -1:].float() + 0.5, 1.0) - 0.5
129+
rad_acc = rad2.cumsum(dim=1).fmod(1.0).to(f0)
130+
rad += F.pad(rad_acc, (0, 0, 1, -1))
131+
rad = rad.reshape(f0.shape[0], -1, 1)
132+
rad = torch.multiply(rad, torch.arange(1, self.dim + 1, device=f0.device).reshape(1, 1, -1))
133+
rand_ini = torch.rand(1, 1, self.dim, device=f0.device)
134+
rand_ini[..., 0] = 0
135+
rad += rand_ini
136+
sines = torch.sin(2 * np.pi * rad)
158137
return sines
159138

160139
@torch.no_grad()
@@ -166,8 +145,7 @@ def forward(self, f0, upp):
166145
output uv: tensor(batchsize=1, length, 1)
167146
"""
168147
f0 = f0.unsqueeze(-1)
169-
fn = torch.multiply(f0, torch.arange(1, self.dim + 1, device=f0.device).reshape((1, 1, -1)))
170-
sine_waves = self._f02sine(fn, upp) * self.sine_amp
148+
sine_waves = self._f02sine(f0, upp) * self.sine_amp
171149
uv = (f0 > self.voiced_threshold).float()
172150
uv = F.interpolate(uv.transpose(2, 1), scale_factor=upp, mode='nearest').transpose(2, 1)
173151
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3

0 commit comments

Comments
 (0)