diff --git a/models.py b/models.py
index da233d02d..459ff619e 100644
--- a/models.py
+++ b/models.py
@@ -3,7 +3,7 @@ import torch.nn as nn
 from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
 from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
-from utils import init_weights, get_padding
+from utils import init_weights, get_padding, apply_mask
 
 LRELU_SLOPE = 0.1
 
@@ -32,14 +32,18 @@ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
         ])
         self.convs2.apply(init_weights)
 
-    def forward(self, x):
+    def forward(self, x, valid_lengths=None):
         for c1, c2 in zip(self.convs1, self.convs2):
             xt = F.leaky_relu(x, LRELU_SLOPE)
             xt = c1(xt)
+            if valid_lengths is not None:
+                xt, valid_lengths = apply_mask(c1, xt, valid_lengths)
             xt = F.leaky_relu(xt, LRELU_SLOPE)
             xt = c2(xt)
+            if valid_lengths is not None:
+                xt, valid_lengths = apply_mask(c2, xt, valid_lengths)
             x = xt + x
-        return x
+        return x, valid_lengths
 
     def remove_weight_norm(self):
         for l in self.convs1:
@@ -60,12 +64,14 @@ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
         ])
         self.convs.apply(init_weights)
 
-    def forward(self, x):
+    def forward(self, x, valid_lengths=None):
         for c in self.convs:
             xt = F.leaky_relu(x, LRELU_SLOPE)
             xt = c(xt)
+            if valid_lengths is not None:
+                xt, valid_lengths = apply_mask(c, xt, valid_lengths)
             x = xt + x
-        return x
+        return x, valid_lengths
 
     def remove_weight_norm(self):
         for l in self.convs:
@@ -97,18 +103,24 @@ def __init__(self, h):
         self.ups.apply(init_weights)
         self.conv_post.apply(init_weights)
 
-    def forward(self, x):
+    def forward(self, x, valid_lengths=None):
         x = self.conv_pre(x)
+        if valid_lengths is not None:
+            x, valid_lengths = apply_mask(self.conv_pre, x, valid_lengths)
         for i in range(self.num_upsamples):
             x = F.leaky_relu(x, LRELU_SLOPE)
             x = self.ups[i](x)
+            if valid_lengths is not None:
+                x, valid_lengths = apply_mask(self.ups[i], x, valid_lengths)
             xs = None
             for j in range(self.num_kernels):
                 if xs is None:
-                    xs = self.resblocks[i*self.num_kernels+j](x)
+                    xs, valid_lengths = self.resblocks[i*self.num_kernels+j](x, valid_lengths)
                 else:
-                    xs += self.resblocks[i*self.num_kernels+j](x)
+                    xs_cal, valid_lengths = self.resblocks[i*self.num_kernels+j](x, valid_lengths)
+                    xs = xs + xs_cal
             x = xs / self.num_kernels
+
         x = F.leaky_relu(x)
         x = self.conv_post(x)
         x = torch.tanh(x)
diff --git a/utils.py b/utils.py
index aa2a536e6..7084b0f00 100644
--- a/utils.py
+++ b/utils.py
@@ -56,3 +56,26 @@ def scan_checkpoint(cp_dir, prefix):
         return None
     return sorted(cp_list)[-1]
 
+def apply_mask(op, x, valid_lengths):
+    if op._get_name() == 'Conv1d':
+        kernel_size = op.kernel_size[0]
+        stride = op.stride[0]
+        padding = op.padding[0]
+        dilation = op.dilation[0]
+        output_length = []
+        for i, length in enumerate(valid_lengths):
+            output_length.append((length + 2 * padding - (kernel_size - 1) * dilation - 1) // stride + 1)
+            x[i, :, output_length[i]:] = 0
+    elif op._get_name() == 'ConvTranspose1d':
+        kernel_size = op.kernel_size[0]
+        stride = op.stride[0]
+        padding = op.padding[0]
+        dilation = op.dilation[0]
+        output_length = []
+        for i, length in enumerate(valid_lengths):
+            output_length.append((length - 1) * stride - 2 * padding + (kernel_size - 1) * dilation + op.output_padding[0] + 1)
+            x[i, :, output_length[i]:] = 0
+    else:
+        return x, valid_lengths
+    return x, output_length
+
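
The two length formulas in `apply_mask` are the standard output-size formulas from the PyTorch `Conv1d` and `ConvTranspose1d` documentation. A minimal standalone check (not part of the patch; the layer shapes below are arbitrary, except the 8x upsampler, which mirrors a HiFi-GAN-style `ups` layer):

```python
import torch
from torch.nn import Conv1d, ConvTranspose1d

def conv1d_out_len(L, k, s, p, d):
    # Same arithmetic as the Conv1d branch of apply_mask.
    return (L + 2 * p - d * (k - 1) - 1) // s + 1

def convtr1d_out_len(L, k, s, p, d, out_pad):
    # Same arithmetic as the ConvTranspose1d branch of apply_mask.
    return (L - 1) * s - 2 * p + d * (k - 1) + out_pad + 1

conv = Conv1d(1, 1, kernel_size=7, stride=2, padding=3)
up = ConvTranspose1d(1, 1, kernel_size=16, stride=8, padding=4)
x = torch.randn(1, 1, 100)

assert conv(x).shape[-1] == conv1d_out_len(100, 7, 2, 3, 1)      # 50
assert up(x).shape[-1] == convtr1d_out_len(100, 16, 8, 4, 1, 0)  # 800
```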
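
And a sketch of how the helper behaves on a padded batch, assuming the patched `utils.py` is importable (the channel count, frame counts, and valid lengths are made up for illustration). Note that `op._get_name()` still reports `'Conv1d'` for the weight-normed convs in `models.py`, since `weight_norm` hooks the module rather than wrapping it, so the dispatch in `apply_mask` covers those layers:

```python
import torch
from torch.nn import Conv1d
from utils import apply_mask  # the helper added by this patch

conv = Conv1d(80, 80, kernel_size=7, padding=3)  # "same"-length conv, as in the ResBlocks
x = torch.randn(2, 80, 50)
x[1, :, 30:] = 0                      # second item is only 30 frames; the rest is padding

y, out_lens = apply_mask(conv, conv(x), [50, 30])
assert out_lens == [50, 30]           # same-length conv: valid lengths pass through
assert y[1, :, 30:].abs().sum() == 0  # conv bleed into the padded tail is zeroed again
```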