
Commit f13ab8b

ayutaz and claude committed
refactor: Simple code cleanups (from upstream PR yl4579#219)
- Remove unused imports and variables
- Replace wildcard imports with specific imports
- Fix Flake8 errors
- Remove IPython tracing in scripts

Original PR: yl4579#219
Author: Arvind Suresh <arvind@free-speech.ai>

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent fc9eed4 commit f13ab8b

22 files changed, +156 -300 lines changed

Modules/diffusion/diffusion.py

Lines changed: 2 additions & 9 deletions
@@ -1,14 +1,7 @@
-from math import pi
-from random import randint
-from typing import Any, Optional, Sequence, Tuple, Union
-
-import torch
-from einops import rearrange
 from torch import Tensor, nn
-from tqdm import tqdm
 
-from .utils import *
-from .sampler import *
+from .utils import groupby
+from .sampler import UniformDistribution, LinearSchedule, VSampler
 
 """
 Diffusion Classes (generic for 1d data)
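A note on the wildcard-import replacements here (the same pattern recurs in modules.py and sampler.py below): star imports defeat flake8's name tracking, so F403 flags the import itself, F405 flags every name that might come from it, and unused imports can never be caught as F401. A minimal before/after sketch, using a hypothetical helpers module:

# Before: F403 on the star import; `groupby` below would be F405
# ("may be undefined, or defined from star imports").
# from helpers import *

# After: the dependency is explicit, and a name that falls out of use
# is reported as F401 ("imported but unused").
from helpers import groupby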

Modules/diffusion/modules.py

Lines changed: 4 additions & 4 deletions
@@ -1,16 +1,16 @@
-from math import floor, log, pi
-from typing import Any, List, Optional, Sequence, Tuple, Union
+from math import log, pi
+from typing import Optional
 
-from .utils import *
+from .utils import default, exists, rand_bool
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from einops import rearrange, reduce, repeat
 from einops.layers.torch import Rearrange
 from einops_exts import rearrange_many
 from torch import Tensor, einsum
 
-
 """
 Utils
 """

Modules/diffusion/sampler.py

Lines changed: 1 addition & 2 deletions
@@ -7,7 +7,7 @@
 from einops import rearrange, reduce
 from torch import Tensor
 
-from .utils import *
+from .utils import default, exists
 
 """
 Diffusion Training
@@ -213,7 +213,6 @@ def loss_weight(self, sigmas: Tensor) -> Tensor:
 
     def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
         batch_size, device = x.shape[0], x.device
-        from einops import rearrange, reduce
 
         # Sample amount of noise to add for each batch element
         sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
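The second hunk deletes a function-local `from einops import rearrange, reduce`: the same names are already imported at module level (visible in the first hunk), so the local import only re-bound them redundantly on every call. A minimal sketch of why the deletion is safe, with a hypothetical `pool` function:

from einops import rearrange, reduce  # module level: visible in every function below

def pool(x):
    # no local re-import needed; the module-level names resolve here
    return reduce(rearrange(x, 'b c t -> b t c'), 'b t c -> b c', 'mean')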

Modules/diffusion/utils.py

Lines changed: 1 addition & 4 deletions
@@ -1,12 +1,9 @@
 from functools import reduce
 from inspect import isfunction
-from math import ceil, floor, log2, pi
+from math import ceil, floor, log2
 from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
 
 import torch
-import torch.nn.functional as F
-from einops import rearrange
-from torch import Generator, Tensor
 from typing_extensions import TypeGuard
 
 T = TypeVar("T")

Modules/discriminators.py

Lines changed: 9 additions & 11 deletions
@@ -1,7 +1,7 @@
 import torch
 import torch.nn.functional as F
 import torch.nn as nn
-from torch.nn import Conv1d, AvgPool1d, Conv2d
+from torch.nn import Conv1d, Conv2d
 from torch.nn.utils import weight_norm, spectral_norm
 
 from .utils import get_padding
@@ -21,8 +21,6 @@ def stft(x, fft_size, hop_size, win_length, window):
     """
     x_stft = torch.stft(x, fft_size, hop_size, win_length, window,
                         return_complex=True)
-    real = x_stft[..., 0]
-    imag = x_stft[..., 1]
 
     return torch.abs(x_stft).transpose(2, 1)
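The two deleted lines were dead stores, and with `return_complex=True` they were misleading as well: `torch.stft` returns a complex tensor, so real and imaginary parts live in `.real`/`.imag`, not in the last dimension. A standalone sketch of what the helper computes, using the SpecDiscriminator defaults from this file:

import torch

x = torch.randn(1, 16000)
window = torch.hann_window(600)
spec = torch.stft(x, 1024, 120, 600, window, return_complex=True)  # complex output
magnitude = torch.abs(spec).transpose(2, 1)  # (batch, frames, fft_size // 2 + 1)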

@@ -31,7 +29,7 @@ class SpecDiscriminator(nn.Module):
 
     def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window", use_spectral_norm=False):
         super(SpecDiscriminator, self).__init__()
-        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
         self.fft_size = fft_size
         self.shift_size = shift_size
         self.win_length = win_length
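`== False` is what flake8 reports as E712 ("comparison to False should be 'if cond is False:' or 'if not cond:'"); the identity form lints clean and behaves identically for the bool default used here. A minimal sketch:

from torch.nn.utils import weight_norm, spectral_norm

use_spectral_norm = False
# E712 flags `use_spectral_norm == False`; `is False` passes the linter,
# though `spectral_norm if use_spectral_norm else weight_norm` would be the
# most idiomatic spelling of the same choice.
norm_f = weight_norm if use_spectral_norm is False else spectral_norm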
@@ -97,7 +95,7 @@ class DiscriminatorP(torch.nn.Module):
     def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
         super(DiscriminatorP, self).__init__()
         self.period = period
-        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
         self.convs = nn.ModuleList([
             norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
             norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
@@ -118,8 +116,8 @@ def forward(self, x):
             t = t + n_pad
         x = x.view(b, c, t // self.period, self.period)
 
-        for l in self.convs:
-            x = l(x)
+        for layer in self.convs:
+            x = layer(x)
             x = F.leaky_relu(x, LRELU_SLOPE)
             fmap.append(x)
         x = self.conv_post(x)
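Renaming the loop variable clears flake8 E741 ("ambiguous variable name 'l'", easily misread as 1 or I); the same rename recurs in hifigan.py and istftnet.py below. A minimal sketch:

import torch.nn as nn

convs = nn.ModuleList([nn.Conv1d(1, 1, 3) for _ in range(2)])
# E741 would flag `for l in convs:`; a descriptive name lints clean
for layer in convs:
    print(layer)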
@@ -163,7 +161,7 @@ def __init__(self, slm_hidden=768,
                  initial_channel=64,
                  use_spectral_norm=False):
         super(WavLMDiscriminator, self).__init__()
-        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
         self.pre = norm_f(Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0))
 
         self.convs = nn.ModuleList([
@@ -178,11 +176,11 @@ def forward(self, x):
         x = self.pre(x)
 
         fmap = []
-        for l in self.convs:
-            x = l(x)
+        for layer in self.convs:
+            x = layer(x)
             x = F.leaky_relu(x, LRELU_SLOPE)
             fmap.append(x)
         x = self.conv_post(x)
         x = torch.flatten(x, 1, -1)
 
-        return x
+        return x

Modules/hifigan.py

Lines changed: 12 additions & 13 deletions
@@ -1,8 +1,8 @@
 import torch
 import torch.nn.functional as F
 import torch.nn as nn
-from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
-from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
+from torch.nn import Conv1d, ConvTranspose1d
+from torch.nn.utils import weight_norm, remove_weight_norm
 from .utils import init_weights, get_padding
 
 import math
@@ -74,10 +74,10 @@ def forward(self, x, s):
         return x
 
     def remove_weight_norm(self):
-        for l in self.convs1:
-            remove_weight_norm(l)
-        for l in self.convs2:
-            remove_weight_norm(l)
+        for layer in self.convs1:
+            remove_weight_norm(layer)
+        for layer in self.convs2:
+            remove_weight_norm(layer)
 
 class SineGen(torch.nn.Module):
     """ Definition of sine generator
@@ -193,8 +193,7 @@ def forward(self, f0):
         output sine_tensor: tensor(batchsize=1, length, dim)
         output uv: tensor(batchsize=1, length, 1)
         """
-        f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
-                             device=f0.device)
+        torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
         # fundamental component
         fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
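Dropping the `f0_buf` binding clears flake8 F841 ("local variable is assigned to but never used"); note that the bare `torch.zeros(...)` expression kept here still allocates a tensor at call time. A minimal sketch of the rule, with a hypothetical function:

import torch

def sketch(n):
    # F841: `buf = torch.zeros(n)` would be flagged, since `buf` is never read.
    # Keeping only the expression silences the warning, but the allocation
    # still runs on every call:
    torch.zeros(n)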

@@ -348,10 +347,10 @@ def forward(self, x, s, f0):
 
     def remove_weight_norm(self):
         print('Removing weight norm...')
-        for l in self.ups:
-            remove_weight_norm(l)
-        for l in self.resblocks:
-            l.remove_weight_norm()
+        for layer in self.ups:
+            remove_weight_norm(layer)
+        for layer in self.resblocks:
+            layer.remove_weight_norm()
         remove_weight_norm(self.conv_pre)
         remove_weight_norm(self.conv_post)
 
@@ -474,4 +473,4 @@ def forward(self, asr, F0_curve, N, s):
         x = self.generator(x, s, F0_curve)
         return x
-
+

Modules/istftnet.py

Lines changed: 12 additions & 13 deletions
@@ -1,8 +1,8 @@
 import torch
 import torch.nn.functional as F
 import torch.nn as nn
-from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
-from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
+from torch.nn import Conv1d, ConvTranspose1d
+from torch.nn.utils import weight_norm, remove_weight_norm
 from .utils import init_weights, get_padding
 
 import math
@@ -75,10 +75,10 @@ def forward(self, x, s):
         return x
 
     def remove_weight_norm(self):
-        for l in self.convs1:
-            remove_weight_norm(l)
-        for l in self.convs2:
-            remove_weight_norm(l)
+        for layer in self.convs1:
+            remove_weight_norm(layer)
+        for layer in self.convs2:
+            remove_weight_norm(layer)
 
 class TorchSTFT(torch.nn.Module):
     def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'):
@@ -222,8 +222,7 @@ def forward(self, f0):
         output sine_tensor: tensor(batchsize=1, length, dim)
         output uv: tensor(batchsize=1, length, 1)
         """
-        f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
-                             device=f0.device)
+        torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
         # fundamental component
         fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
 
@@ -399,10 +398,10 @@ def fw_phase(self, x, s):
 
     def remove_weight_norm(self):
         print('Removing weight norm...')
-        for l in self.ups:
-            remove_weight_norm(l)
-        for l in self.resblocks:
-            l.remove_weight_norm()
+        for layer in self.ups:
+            remove_weight_norm(layer)
+        for layer in self.resblocks:
+            layer.remove_weight_norm()
         remove_weight_norm(self.conv_pre)
         remove_weight_norm(self.conv_post)
 
@@ -527,4 +526,4 @@ def forward(self, asr, F0_curve, N, s):
         x = self.generator(x, s, F0_curve)
         return x
-
+

Modules/slmadv.py

Lines changed: 7 additions & 12 deletions
@@ -41,7 +41,7 @@ def forward(self, iters, y_rec_gt, y_rec_gt_pred, waves, mel_input_length, ref_t
                                      num_steps=num_steps).squeeze(1)
 
         s_dur = s_preds[:, 128:]
-        s = s_preds[:, :128]
+        s_preds[:, :128]
 
         d, _ = self.model.predictor(d_en, s_dur,
                                     ref_lengths,
@@ -61,20 +61,20 @@ def forward(self, iters, y_rec_gt, y_rec_gt_pred, waves, mel_input_length, ref_t
             _s2s_pred = torch.sigmoid(_s2s_pred_org)
             _dur_pred = _s2s_pred.sum(axis=-1)
 
-            l = int(torch.round(_s2s_pred.sum()).item())
-            t = torch.arange(0, l).expand(l)
+            length = int(torch.round(_s2s_pred.sum()).item())
+            t = torch.arange(0, length).expand(length)
 
-            t = torch.arange(0, l).unsqueeze(0).expand((len(_s2s_pred), l)).to(ref_text.device)
+            t = torch.arange(0, length).unsqueeze(0).expand((len(_s2s_pred), length)).to(ref_text.device)
             loc = torch.cumsum(_dur_pred, dim=0) - _dur_pred / 2
 
-            h = torch.exp(-0.5 * torch.square(t - (l - loc.unsqueeze(-1))) / (self.sig)**2)
+            h = torch.exp(-0.5 * torch.square(t - (length - loc.unsqueeze(-1))) / (self.sig)**2)
 
             out = torch.nn.functional.conv1d(_s2s_pred_org.unsqueeze(0),
                                              h.unsqueeze(1),
-                                             padding=h.shape[-1] - 1, groups=int(_text_length))[..., :l]
+                                             padding=h.shape[-1] - 1, groups=int(_text_length))[..., :length]
             attn_preds.append(F.softmax(out.squeeze(), dim=0))
 
-            output_lengths.append(l)
+            output_lengths.append(length)
 
         max_len = max(output_lengths)
 
@@ -96,14 +96,9 @@ def forward(self, iters, y_rec_gt, y_rec_gt_pred, waves, mel_input_length, ref_t
             mel_len = min(mel_len, self.max_len // 2)
 
         # get clips
-
         en = []
         p_en = []
         sp = []
-
-        F0_fakes = []
-        N_fakes = []
-
         wav = []
 
         for bib in range(len(output_lengths)):

Modules/utils.py

Lines changed: 1 addition & 8 deletions
@@ -3,12 +3,5 @@ def init_weights(m, mean=0.0, std=0.01):
     if classname.find("Conv") != -1:
         m.weight.data.normal_(mean, std)
 
-
-def apply_weight_norm(m):
-    classname = m.__class__.__name__
-    if classname.find("Conv") != -1:
-        weight_norm(m)
-
-
 def get_padding(kernel_size, dilation=1):
-    return int((kernel_size*dilation - dilation)/2)
+    return int((kernel_size*dilation - dilation)/2)

Utils/ASR/layers.py

Lines changed: 0 additions & 4 deletions
@@ -1,10 +1,6 @@
-import math
 import torch
 from torch import nn
-from typing import Optional, Any
-from torch import Tensor
 import torch.nn.functional as F
-import torchaudio
 import torchaudio.functional as audio_F
 
 import random
