28 changes: 20 additions & 8 deletions models.py
@@ -3,7 +3,7 @@
 import torch.nn as nn
 from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
 from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
-from utils import init_weights, get_padding
+from utils import init_weights, get_padding, apply_mask

 LRELU_SLOPE = 0.1

@@ -32,14 +32,18 @@ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
         ])
         self.convs2.apply(init_weights)

-    def forward(self, x):
+    def forward(self, x, valid_lengths=None):
         for c1, c2 in zip(self.convs1, self.convs2):
             xt = F.leaky_relu(x, LRELU_SLOPE)
             xt = c1(xt)
+            if valid_lengths is not None:
+                xt, valid_lengths = apply_mask(c1, xt, valid_lengths)
             xt = F.leaky_relu(xt, LRELU_SLOPE)
             xt = c2(xt)
+            if valid_lengths is not None:
+                xt, valid_lengths = apply_mask(c2, xt, valid_lengths)
             x = xt + x
-        return x
+        return x, valid_lengths

     def remove_weight_norm(self):
         for l in self.convs1:
@@ -60,12 +64,14 @@ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
         ])
         self.convs.apply(init_weights)

-    def forward(self, x):
+    def forward(self, x, valid_lengths=None):
         for c in self.convs:
             xt = F.leaky_relu(x, LRELU_SLOPE)
             xt = c(xt)
+            if valid_lengths is not None:
+                xt, valid_lengths = apply_mask(c, xt, valid_lengths)
             x = xt + x
-        return x
+        return x, valid_lengths

     def remove_weight_norm(self):
         for l in self.convs:
@@ -97,18 +103,24 @@ def __init__(self, h):
         self.ups.apply(init_weights)
         self.conv_post.apply(init_weights)

-    def forward(self, x):
+    def forward(self, x, valid_lengths=None):
         x = self.conv_pre(x)
+        if valid_lengths is not None:
+            x, valid_lengths = apply_mask(self.conv_pre, x, valid_lengths)
         for i in range(self.num_upsamples):
             x = F.leaky_relu(x, LRELU_SLOPE)
             x = self.ups[i](x)
+            if valid_lengths is not None:
+                x, valid_lengths = apply_mask(self.ups[i], x, valid_lengths)
             xs = None
             for j in range(self.num_kernels):
                 if xs is None:
-                    xs = self.resblocks[i*self.num_kernels+j](x)
+                    xs, valid_lengths = self.resblocks[i*self.num_kernels+j](x, valid_lengths)
                 else:
-                    xs += self.resblocks[i*self.num_kernels+j](x)
+                    xs_cal, valid_lengths = self.resblocks[i*self.num_kernels+j](x, valid_lengths)
+                    xs = xs + xs_cal
             x = xs / self.num_kernels
+
         x = F.leaky_relu(x)
         x = self.conv_post(x)
         x = torch.tanh(x)
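Taken together, models.py now threads an optional valid_lengths argument through conv_pre, each upsampling layer, and every residual block, re-zeroing padded frames after each convolution; passing valid_lengths=None leaves the original forward path untouched. A minimal usage sketch under stated assumptions — h is the usual HiFi-GAN config object and mel_lengths is a hypothetical list of per-example valid frame counts:

    import torch

    generator = Generator(h)                # h: HiFi-GAN config, assumed in scope
    mel = torch.randn(4, h.num_mels, 120)   # zero-padded batch of mel frames
    mel_lengths = [120, 96, 80, 60]         # illustrative valid lengths per example

    # Positions past each valid length are zeroed after every conv layer,
    # so padding never bleeds into the synthesized waveform.
    audio = generator(mel, valid_lengths=mel_lengths)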
23 changes: 23 additions & 0 deletions utils.py
@@ -56,3 +56,26 @@ def scan_checkpoint(cp_dir, prefix):
         return None
     return sorted(cp_list)[-1]

+def apply_mask(op, x, valid_lengths):
+    if op._get_name() == 'Conv1d':
+        kernel_size = op.kernel_size[0]
+        stride = op.stride[0]
+        padding = op.padding[0]
+        dilation = op.dilation[0]
+        output_length = []
+        for i, length in enumerate(valid_lengths):
+            output_length.append((length + 2 * padding - (kernel_size - 1) * dilation - 1) // stride + 1)
+            x[i, :, output_length[i]:] = 0
+    elif op._get_name() == 'ConvTranspose1d':
+        kernel_size = op.kernel_size[0]
+        stride = op.stride[0]
+        padding = op.padding[0]
+        dilation = op.dilation[0]
+        output_length = []
+        for i, length in enumerate(valid_lengths):
+            output_length.append((length - 1) * stride - 2 * padding + (kernel_size - 1) * dilation + op.output_padding[0] + 1)
+            x[i, :, output_length[i]:] = 0
+    else:
+        return x, valid_lengths
+    return x, output_length