Skip to content

Commit c93940a

Browse files
authored
Add files via upload
1 parent f1730d4 commit c93940a

File tree

2 files changed

+250
-0
lines changed

2 files changed

+250
-0
lines changed

uvr5_pack/lib_v5/layers_new.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
import torch
2+
from torch import nn
3+
import torch.nn.functional as F
4+
5+
from uvr5_pack.lib_v5 import spec_utils
6+
7+
class Conv2DBNActiv(nn.Module):
8+
9+
def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
10+
super(Conv2DBNActiv, self).__init__()
11+
self.conv = nn.Sequential(
12+
nn.Conv2d(
13+
nin, nout,
14+
kernel_size=ksize,
15+
stride=stride,
16+
padding=pad,
17+
dilation=dilation,
18+
bias=False),
19+
nn.BatchNorm2d(nout),
20+
activ()
21+
)
22+
23+
def __call__(self, x):
24+
return self.conv(x)
25+
26+
class Encoder(nn.Module):
27+
28+
def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU):
29+
super(Encoder, self).__init__()
30+
self.conv1 = Conv2DBNActiv(nin, nout, ksize, stride, pad, activ=activ)
31+
self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ)
32+
33+
def __call__(self, x):
34+
h = self.conv1(x)
35+
h = self.conv2(h)
36+
37+
return h
38+
39+
40+
class Decoder(nn.Module):
41+
42+
def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False):
43+
super(Decoder, self).__init__()
44+
self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
45+
# self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ)
46+
self.dropout = nn.Dropout2d(0.1) if dropout else None
47+
48+
def __call__(self, x, skip=None):
49+
x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
50+
51+
if skip is not None:
52+
skip = spec_utils.crop_center(skip, x)
53+
x = torch.cat([x, skip], dim=1)
54+
55+
h = self.conv1(x)
56+
# h = self.conv2(h)
57+
58+
if self.dropout is not None:
59+
h = self.dropout(h)
60+
61+
return h
62+
63+
64+
class ASPPModule(nn.Module):
65+
66+
def __init__(self, nin, nout, dilations=(4, 8, 12), activ=nn.ReLU, dropout=False):
67+
super(ASPPModule, self).__init__()
68+
self.conv1 = nn.Sequential(
69+
nn.AdaptiveAvgPool2d((1, None)),
70+
Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ)
71+
)
72+
self.conv2 = Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ)
73+
self.conv3 = Conv2DBNActiv(
74+
nin, nout, 3, 1, dilations[0], dilations[0], activ=activ
75+
)
76+
self.conv4 = Conv2DBNActiv(
77+
nin, nout, 3, 1, dilations[1], dilations[1], activ=activ
78+
)
79+
self.conv5 = Conv2DBNActiv(
80+
nin, nout, 3, 1, dilations[2], dilations[2], activ=activ
81+
)
82+
self.bottleneck = Conv2DBNActiv(nout * 5, nout, 1, 1, 0, activ=activ)
83+
self.dropout = nn.Dropout2d(0.1) if dropout else None
84+
85+
def forward(self, x):
86+
_, _, h, w = x.size()
87+
feat1 = F.interpolate(self.conv1(x), size=(h, w), mode='bilinear', align_corners=True)
88+
feat2 = self.conv2(x)
89+
feat3 = self.conv3(x)
90+
feat4 = self.conv4(x)
91+
feat5 = self.conv5(x)
92+
out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1)
93+
out = self.bottleneck(out)
94+
95+
if self.dropout is not None:
96+
out = self.dropout(out)
97+
98+
return out
99+
100+
101+
class LSTMModule(nn.Module):
102+
103+
def __init__(self, nin_conv, nin_lstm, nout_lstm):
104+
super(LSTMModule, self).__init__()
105+
self.conv = Conv2DBNActiv(nin_conv, 1, 1, 1, 0)
106+
self.lstm = nn.LSTM(
107+
input_size=nin_lstm,
108+
hidden_size=nout_lstm // 2,
109+
bidirectional=True
110+
)
111+
self.dense = nn.Sequential(
112+
nn.Linear(nout_lstm, nin_lstm),
113+
nn.BatchNorm1d(nin_lstm),
114+
nn.ReLU()
115+
)
116+
117+
def forward(self, x):
118+
N, _, nbins, nframes = x.size()
119+
h = self.conv(x)[:, 0] # N, nbins, nframes
120+
h = h.permute(2, 0, 1) # nframes, N, nbins
121+
h, _ = self.lstm(h)
122+
h = self.dense(h.reshape(-1, h.size()[-1])) # nframes * N, nbins
123+
h = h.reshape(nframes, N, 1, nbins)
124+
h = h.permute(1, 2, 3, 0)
125+
126+
return h

uvr5_pack/lib_v5/nets_new.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
import torch
2+
from torch import nn
3+
import torch.nn.functional as F
4+
from uvr5_pack.lib_v5 import layers_new as layers
5+
6+
class BaseNet(nn.Module):
7+
8+
def __init__(self, nin, nout, nin_lstm, nout_lstm, dilations=((4, 2), (8, 4), (12, 6))):
9+
super(BaseNet, self).__init__()
10+
self.enc1 = layers.Conv2DBNActiv(nin, nout, 3, 1, 1)
11+
self.enc2 = layers.Encoder(nout, nout * 2, 3, 2, 1)
12+
self.enc3 = layers.Encoder(nout * 2, nout * 4, 3, 2, 1)
13+
self.enc4 = layers.Encoder(nout * 4, nout * 6, 3, 2, 1)
14+
self.enc5 = layers.Encoder(nout * 6, nout * 8, 3, 2, 1)
15+
16+
self.aspp = layers.ASPPModule(nout * 8, nout * 8, dilations, dropout=True)
17+
18+
self.dec4 = layers.Decoder(nout * (6 + 8), nout * 6, 3, 1, 1)
19+
self.dec3 = layers.Decoder(nout * (4 + 6), nout * 4, 3, 1, 1)
20+
self.dec2 = layers.Decoder(nout * (2 + 4), nout * 2, 3, 1, 1)
21+
self.lstm_dec2 = layers.LSTMModule(nout * 2, nin_lstm, nout_lstm)
22+
self.dec1 = layers.Decoder(nout * (1 + 2) + 1, nout * 1, 3, 1, 1)
23+
24+
def __call__(self, x):
25+
e1 = self.enc1(x)
26+
e2 = self.enc2(e1)
27+
e3 = self.enc3(e2)
28+
e4 = self.enc4(e3)
29+
e5 = self.enc5(e4)
30+
31+
h = self.aspp(e5)
32+
33+
h = self.dec4(h, e4)
34+
h = self.dec3(h, e3)
35+
h = self.dec2(h, e2)
36+
h = torch.cat([h, self.lstm_dec2(h)], dim=1)
37+
h = self.dec1(h, e1)
38+
39+
return h
40+
41+
class CascadedNet(nn.Module):
42+
43+
def __init__(self, n_fft, nout=32, nout_lstm=128):
44+
super(CascadedNet, self).__init__()
45+
46+
self.max_bin = n_fft // 2
47+
self.output_bin = n_fft // 2 + 1
48+
self.nin_lstm = self.max_bin // 2
49+
self.offset = 64
50+
51+
self.stg1_low_band_net = nn.Sequential(
52+
BaseNet(2, nout // 2, self.nin_lstm // 2, nout_lstm),
53+
layers.Conv2DBNActiv(nout // 2, nout // 4, 1, 1, 0)
54+
)
55+
56+
self.stg1_high_band_net = BaseNet(2, nout // 4, self.nin_lstm // 2, nout_lstm // 2)
57+
58+
self.stg2_low_band_net = nn.Sequential(
59+
BaseNet(nout // 4 + 2, nout, self.nin_lstm // 2, nout_lstm),
60+
layers.Conv2DBNActiv(nout, nout // 2, 1, 1, 0)
61+
)
62+
self.stg2_high_band_net = BaseNet(nout // 4 + 2, nout // 2, self.nin_lstm // 2, nout_lstm // 2)
63+
64+
self.stg3_full_band_net = BaseNet(3 * nout // 4 + 2, nout, self.nin_lstm, nout_lstm)
65+
66+
self.out = nn.Conv2d(nout, 2, 1, bias=False)
67+
self.aux_out = nn.Conv2d(3 * nout // 4, 2, 1, bias=False)
68+
69+
def forward(self, x):
70+
x = x[:, :, :self.max_bin]
71+
72+
bandw = x.size()[2] // 2
73+
l1_in = x[:, :, :bandw]
74+
h1_in = x[:, :, bandw:]
75+
l1 = self.stg1_low_band_net(l1_in)
76+
h1 = self.stg1_high_band_net(h1_in)
77+
aux1 = torch.cat([l1, h1], dim=2)
78+
79+
l2_in = torch.cat([l1_in, l1], dim=1)
80+
h2_in = torch.cat([h1_in, h1], dim=1)
81+
l2 = self.stg2_low_band_net(l2_in)
82+
h2 = self.stg2_high_band_net(h2_in)
83+
aux2 = torch.cat([l2, h2], dim=2)
84+
85+
f3_in = torch.cat([x, aux1, aux2], dim=1)
86+
f3 = self.stg3_full_band_net(f3_in)
87+
88+
mask = torch.sigmoid(self.out(f3))
89+
mask = F.pad(
90+
input=mask,
91+
pad=(0, 0, 0, self.output_bin - mask.size()[2]),
92+
mode='replicate'
93+
)
94+
95+
if self.training:
96+
aux = torch.cat([aux1, aux2], dim=1)
97+
aux = torch.sigmoid(self.aux_out(aux))
98+
aux = F.pad(
99+
input=aux,
100+
pad=(0, 0, 0, self.output_bin - aux.size()[2]),
101+
mode='replicate'
102+
)
103+
return mask, aux
104+
else:
105+
return mask
106+
107+
def predict_mask(self, x):
108+
mask = self.forward(x)
109+
110+
if self.offset > 0:
111+
mask = mask[:, :, :, self.offset:-self.offset]
112+
assert mask.size()[3] > 0
113+
114+
return mask
115+
116+
def predict(self, x,aggressiveness=None):
117+
mask = self.forward(x)
118+
pred_mag = x * mask
119+
120+
if self.offset > 0:
121+
pred_mag = pred_mag[:, :, :, self.offset:-self.offset]
122+
assert pred_mag.size()[3] > 0
123+
124+
return pred_mag

0 commit comments

Comments
 (0)