-
Notifications
You must be signed in to change notification settings - Fork 13
Expand file tree
/
Copy pathutils.py
More file actions
259 lines (211 loc) · 8.66 KB
/
utils.py
File metadata and controls
259 lines (211 loc) · 8.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
from collections import namedtuple
import torch
import torch.nn as nn
from torchvision import models
class LPIPS(nn.Module):
    """Learned perceptual metric (LPIPS) built on frozen VGG16 features.

    Both images go through ``ScalingLayer`` and the five-stage ``vgg16``
    extractor. At every stage the features are unit-normalised along dim 1,
    their squared difference is weighted by a 1x1 conv (``lin0`` .. ``lin4``)
    and averaged over dims 2 and 3; the five per-stage scores are summed.
    All parameters are frozen once the checkpoint has been loaded.
    """

    # Checkpoint loaded (non-strictly) into this module; read from the current
    # working directory and downloaded from ``_CKPT_URL`` when absent.
    # Presumably it holds the learned ``lin*`` weights — TODO confirm.
    # NOTE(review): this is a Seafile ``seafhttp`` link; confirm it is still
    # reachable, such links may be short-lived.
    _CKPT_PATH = "vgg.pth"
    _CKPT_URL = (
        "https://heibox.uni-heidelberg.de/seafhttp/files/"
        "9535cbee-6558-4c0c-8743-78f5e56ea75e/vgg.pth"
    )

    def __init__(self, use_dropout=True):
        super().__init__()
        self.scaling_layer = ScalingLayer()
        self.chns = [64, 128, 256, 512, 512]  # channel count of each VGG16 stage output
        self.net = vgg16(pretrained=True, requires_grad=False)
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        self.load_from_pretrained()
        for param in self.parameters():
            param.requires_grad = False

    def load_from_pretrained(self, name="vgg_lpips"):
        """Load ``vgg.pth`` into this module, downloading it first if missing.

        Args:
            name: unused; kept for interface compatibility.

        The state dict is applied with ``strict=False``, so missing or
        unexpected keys are silently ignored.
        """
        try:
            data = torch.load(self._CKPT_PATH, map_location=torch.device("cpu"))
        except FileNotFoundError:
            # Only a missing file triggers a download; a corrupt checkpoint
            # surfaces its own error instead of being masked here.
            print(f"Failed to load {self._CKPT_PATH}, downloading...")
            torch.hub.download_url_to_file(self._CKPT_URL, self._CKPT_PATH)
            data = torch.load(self._CKPT_PATH, map_location=torch.device("cpu"))
        self.load_state_dict(data, strict=False)

    def forward(self, input, target):
        """Return the LPIPS distance between ``input`` and ``target``.

        Assuming 4-D NCHW inputs with 3 channels (matching ``ScalingLayer``'s
        buffers), the result has shape ``(B, 1, 1, 1)``: each ``lin*`` conv
        outputs one channel, which is then averaged over dims 2 and 3.
        """
        outs0 = self.net(self.scaling_layer(input))
        outs1 = self.net(self.scaling_layer(target))
        lins = (self.lin0, self.lin1, self.lin2, self.lin3, self.lin4)
        res = [
            spatial_average(
                lin.model((normalize_tensor(f0) - normalize_tensor(f1)) ** 2),
                keepdim=True,
            )
            for lin, f0, f1 in zip(lins, outs0, outs1)
        ]
        # Out-of-place sum in stage order: res[0] + res[1] + ... + res[4].
        return sum(res[1:], res[0])
class ScalingLayer(nn.Module):
    """Normalise 3-channel inputs with a fixed per-channel shift and scale.

    Computes ``(inp - shift) / scale``. ``shift`` and ``scale`` are registered
    as buffers of shape ``(1, 3, 1, 1)`` so they broadcast over NCHW input and
    follow the module across devices without being trainable.
    """

    _SHIFT = (-0.030, -0.088, -0.188)
    _SCALE = (0.458, 0.448, 0.450)

    def __init__(self):
        super().__init__()
        for buffer_name, values in (("shift", self._SHIFT), ("scale", self._SCALE)):
            self.register_buffer(buffer_name, torch.tensor(values).view(1, -1, 1, 1))

    def forward(self, inp):
        centred = inp - self.shift
        return centred / self.scale
class NetLinLayer(nn.Module):
    """A single linear layer implemented as a bias-free 1x1 conv.

    Args:
        chn_in: number of input channels.
        chn_out: number of output channels (default 1).
        use_dropout: if True, an ``nn.Dropout`` is applied before the conv.
    """

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()
        # Dropout (if any) must stay at index 0 and the conv after it: these
        # Sequential indices become state_dict keys such as ``model.1.weight``.
        layers = [nn.Dropout()] if use_dropout else []
        layers.append(nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the optional dropout and the 1x1 convolution to ``x``."""
        return self.model(x)
class vgg16(torch.nn.Module):
    """Split torchvision's VGG16 ``features`` into five sequential stages.

    ``forward`` returns a ``VggOutputs`` namedtuple with the output of each
    stage (fields ``relu1_2`` .. ``relu5_3``). Sub-modules keep their
    original ``features`` index as their name (e.g. ``slice2.4``), so
    state_dict keys are stable.

    Args:
        requires_grad: if False, all parameters are frozen.
        pretrained: forwarded to ``torchvision.models.vgg16``.
    """

    # Output record type, created once rather than on every forward call.
    VggOutputs = namedtuple(
        "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
    )

    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        vgg_pretrained_features = models.vgg16(pretrained=pretrained).features

        def make_slice(start, stop):
            # Copy features[start:stop] into a Sequential, naming each layer by
            # its index in ``features``.
            stage = torch.nn.Sequential()
            for idx in range(start, stop):
                stage.add_module(str(idx), vgg_pretrained_features[idx])
            return stage

        self.slice1 = make_slice(0, 4)
        self.slice2 = make_slice(4, 9)
        self.slice3 = make_slice(9, 16)
        self.slice4 = make_slice(16, 23)
        self.slice5 = make_slice(23, 30)
        self.N_slices = 5
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Run ``X`` through the five stages and return every stage's output."""
        h_relu1_2 = self.slice1(X)
        h_relu2_2 = self.slice2(h_relu1_2)
        h_relu3_3 = self.slice3(h_relu2_2)
        h_relu4_3 = self.slice4(h_relu3_3)
        h_relu5_3 = self.slice5(h_relu4_3)
        return self.VggOutputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
def normalize_tensor(x, eps=1e-10):
    """Scale ``x`` to unit L2 norm along dim 1.

    ``eps`` is added to the norm itself (not to the sum of squares) so that an
    all-zero slice yields zeros instead of dividing by zero.
    """
    channel_norm = x.pow(2).sum(dim=1, keepdim=True).sqrt()
    return x / (channel_norm + eps)
def spatial_average(x, keepdim=True):
    """Average ``x`` over dims 2 and 3 (height and width for NCHW tensors).

    With ``keepdim=True`` the reduced dims are kept with size 1.
    """
    return torch.mean(x, dim=(2, 3), keepdim=keepdim)
class PatchDiscriminator(nn.Module):
    """Discriminator that scores images from five VGG16 feature stages.

    The input is normalised by ``ScalingLayer`` and passed through the VGG16
    ``features`` slices [0:4], [4:9], [9:16], [16:23] and [23:30]. Each stage
    output feeds its own convolutional head (``binary_classifier1`` ..
    ``binary_classifier5``) whose final conv weight starts at zero, so each
    head initially outputs only its bias. The heads' flattened outputs are
    added together. The VGG weights are left trainable here.

    For a (1, 3, 256, 256) input the stage outputs are 256x256, 128x128,
    64x64, 32x32 and 16x16, and every head reduces its stage to a 1x1x16x16
    map, so the result is (1, 256). Other input sizes may give heads of
    different sizes that cannot be summed.
    """

    def __init__(self):
        super(PatchDiscriminator, self).__init__()
        self.scaling_layer = ScalingLayer()
        backbone = models.vgg16(pretrained=True).features
        self.slice1 = nn.Sequential(backbone[:4])
        self.slice2 = nn.Sequential(backbone[4:9])
        self.slice3 = nn.Sequential(backbone[9:16])
        self.slice4 = nn.Sequential(backbone[16:23])
        self.slice5 = nn.Sequential(backbone[23:30])

        def zero_last(*layers):
            # Wrap the layers and zero the weight of the final conv.
            head = nn.Sequential(*layers)
            nn.init.zeros_(head[-1].weight)
            return head

        self.binary_classifier1 = zero_last(
            nn.Conv2d(64, 32, kernel_size=4, stride=4, padding=0, bias=True),
            nn.ReLU(),
            nn.Conv2d(32, 1, kernel_size=4, stride=4, padding=0, bias=True),
        )
        self.binary_classifier2 = zero_last(
            nn.Conv2d(128, 64, kernel_size=4, stride=4, padding=0, bias=True),
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size=2, stride=2, padding=0, bias=True),
        )
        self.binary_classifier3 = zero_last(
            nn.Conv2d(256, 128, kernel_size=2, stride=2, padding=0, bias=True),
            nn.ReLU(),
            nn.Conv2d(128, 1, kernel_size=2, stride=2, padding=0, bias=True),
        )
        self.binary_classifier4 = zero_last(
            nn.Conv2d(512, 1, kernel_size=2, stride=2, padding=0, bias=True),
        )
        self.binary_classifier5 = zero_last(
            nn.Conv2d(512, 1, kernel_size=1, stride=1, padding=0, bias=True),
        )

    def forward(self, x):
        h = self.scaling_layer(x)
        stage_outputs = []
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            h = stage(h)
            stage_outputs.append(h)
        heads = (
            self.binary_classifier1,
            self.binary_classifier2,
            self.binary_classifier3,
            self.binary_classifier4,
            self.binary_classifier5,
        )
        scores = [head(feat).flatten(1) for head, feat in zip(heads, stage_outputs)]
        # Left-to-right sum: scores[0] + scores[1] + ... + scores[4].
        return sum(scores[1:], scores[0])
# Two 6-tap 1-D analysis filters: a low-pass (dec_lo, taps sum to ~1.414) and
# a high-pass (dec_hi, taps sum to ~0).
dec_lo, dec_hi = (
    torch.tensor([-0.1768, 0.3536, 1.0607, 0.3536, -0.1768, 0.0000]),
    torch.tensor([0.0000, -0.0000, 0.3536, -0.7071, 0.3536, -0.0000]),
)
# Separable 2-D filter bank of four 6x6 outer products; entry [i, j] of each
# filter is (tap i of the first vector, along height) * (tap j of the second
# vector, along width).
filters = torch.stack(
    [
        torch.outer(vertical, horizontal)
        for vertical, horizontal in (
            (dec_lo, dec_lo),
            (dec_hi, dec_lo),
            (dec_lo, dec_hi),
            (dec_hi, dec_hi),
        )
    ],
    dim=0,
)
# conv2d weight layout (out_channels=4, in_channels=1, 6, 6).
filters_expanded = filters.unsqueeze(1)
def prepare_filter(device):
    """Move the module-level ``filters_expanded`` filter bank to ``device``.

    The global name is rebound to the moved tensor; nothing is returned.
    """
    global filters_expanded
    moved = filters_expanded.to(device)
    filters_expanded = moved
def wavelet_transform_multi_channel(x, levels=4, filters=None):
    """Apply one level of a 2-D wavelet analysis to every channel of ``x``.

    Each channel is zero-padded by 2 pixels per side and cross-correlated,
    with stride 2, with every filter of the bank. The output stacks all
    sub-bands of channel 0, then all sub-bands of channel 1, and so on.

    Args:
        x: 4-D tensor of shape ``(B, C, H, W)``.
        levels: accepted for compatibility but currently ignored; only a
            single decomposition level is computed.
        filters: optional filter bank of shape ``(K, 1, kH, kW)``. Defaults to
            the module-level ``filters_expanded`` (K=4, 6x6).

    Returns:
        Tensor of shape ``(B, K * C, H_out, W_out)`` with
        ``H_out = (H + 4 - kH) // 2 + 1``, i.e. ``H // 2`` for the default
        6x6 bank when ``H >= 2`` (likewise for width).
    """
    B, C, H, W = x.shape
    bank = filters_expanded if filters is None else filters
    # Follow the input's device/dtype so callers need not call prepare_filter
    # first (a no-op when they already match).
    bank = bank.to(device=x.device, dtype=x.dtype)
    padded = torch.nn.functional.pad(x, (2, 2, 2, 2))
    # Fold channels into the batch so a single conv call filters every
    # channel; batch row b*C + c holds channel c of image b, which gives the
    # same output channel order as a per-channel loop + concatenation.
    res = torch.nn.functional.conv2d(
        padded.reshape(B * C, 1, H + 4, W + 4), bank, stride=2
    )
    n_bands = bank.shape[0]
    return res.reshape(B, n_bands * C, res.shape[2], res.shape[3])
def test_patch_discriminator():
    """Smoke-test ``PatchDiscriminator`` on one random 256x256 RGB image.

    Prints the output shape. Runs on CUDA when available and falls back to
    CPU otherwise; no gradients are tracked.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vggDiscriminator = PatchDiscriminator().to(device)
    with torch.no_grad():
        x = vggDiscriminator(torch.randn(1, 3, 256, 256).to(device))
    print(x.shape)
if __name__ == "__main__":
    # Running the module directly performs the discriminator smoke test.
    test_patch_discriminator()