import torch
import torch.nn.functional as F
import torch.nn as nn
-from torch.nn import Conv1d, AvgPool1d, Conv2d
+from torch.nn import Conv1d, Conv2d
from torch.nn.utils import weight_norm, spectral_norm

from .utils import get_padding
@@ -21,8 +21,6 @@ def stft(x, fft_size, hop_size, win_length, window):
2121 """
2222 x_stft = torch .stft (x , fft_size , hop_size , win_length , window ,
2323 return_complex = True )
24- real = x_stft [..., 0 ]
25- imag = x_stft [..., 1 ]
2624
2725 return torch .abs (x_stft ).transpose (2 , 1 )
2826
@@ -31,7 +29,7 @@ class SpecDiscriminator(nn.Module):

    def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window", use_spectral_norm=False):
        super(SpecDiscriminator, self).__init__()
-        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
        self.fft_size = fft_size
        self.shift_size = shift_size
        self.win_length = win_length
@@ -97,7 +95,7 @@ class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
-        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
        self.convs = nn.ModuleList([
            norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
@@ -118,8 +116,8 @@ def forward(self, x):
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

-        for l in self.convs:
-            x = l(x)
+        for layer in self.convs:
+            x = layer(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
@@ -163,7 +161,7 @@ def __init__(self, slm_hidden=768,
                 initial_channel=64,
                 use_spectral_norm=False):
        super(WavLMDiscriminator, self).__init__()
-        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
        self.pre = norm_f(Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0))

        self.convs = nn.ModuleList([
@@ -178,11 +176,11 @@ def forward(self, x):
        x = self.pre(x)

        fmap = []
-        for l in self.convs:
-            x = l(x)
+        for layer in self.convs:
+            x = layer(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        x = torch.flatten(x, 1, -1)

-        return x
+        return x
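
For reference, a minimal standalone sketch of what the patched stft() helper relies on: with return_complex=True, torch.stft already returns a complex tensor, so the removed real/imag indexing was dead code and torch.abs gives the magnitude directly. The batch size and 16000-sample toy signal below are illustrative assumptions, not values from this repo; the fft/hop/window settings mirror the SpecDiscriminator defaults shown above.

import torch

x = torch.randn(2, 16000)            # (batch, samples) -- assumed toy input
window = torch.hann_window(600)      # matches win_length=600 from the defaults above
x_stft = torch.stft(x, n_fft=1024, hop_length=120, win_length=600,
                    window=window, return_complex=True)
mag = torch.abs(x_stft).transpose(2, 1)   # (batch, frames, n_fft // 2 + 1) magnitude spectrogram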