Skip to content

Commit 051772c

Browse files
committed
adapt BatchNorm1d to 1d-CNN
1 parent eec8c14 commit 051772c

File tree

4 files changed

+205
-11
lines changed

4 files changed

+205
-11
lines changed

src/INN/INN.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,22 @@ def __init__(self, dim, requires_grad=True):
139139
INNAbstract.INNModule.__init__(self)
140140
nn.BatchNorm1d.__init__(self, num_features=dim, affine=False)
141141
self.requires_grad = requires_grad
142+
143+
def _scale(self, x):
144+
'''The scale factor of x to compute Jacobian'''
145+
if len(x.shape) == 2:
146+
return 1
147+
148+
s = 1
149+
for dim in x.shape[2:]:
150+
s *= dim
151+
return s
142152

143153
def forward(self, x, log_p=0, log_det_J=0):
144-
154+
'''
155+
Apply batch normalization to x
156+
x.shape = [batch_size, dim, *]
157+
'''
145158
if self.compute_p:
146159
if not self.training:
147160
# if in self.eval()
@@ -155,15 +168,24 @@ def forward(self, x, log_p=0, log_det_J=0):
155168
x = super(BatchNorm1d, self).forward(x)
156169

157170
log_det = -0.5 * torch.log(var + self.eps)
158-
log_det = torch.sum(log_det, dim=-1)
171+
log_det = torch.sum(log_det, dim=-1) * self._scale(x)
159172

160173
return x, log_p, log_det_J + log_det
161174
else:
162175
return super(BatchNorm1d, self).forward(x)
163176

164177
def inverse(self, y, **args):
178+
'''
179+
inverse y to un-batch-normed numbers
180+
The shape of y can be:
181+
a. Linear: [batch_size, dim]
182+
b. n-d: [batch_size, dim, *]
183+
'''
184+
batch_size, dim = y.shape[0], y.shape[1]
165185
var = self.running_var + self.eps
166186
mean = self.running_mean
187+
var = var.reshape(1, dim, *([1]*(len(y.shape) - 2)))
188+
mean = mean.reshape(1, dim, *([1]*(len(y.shape) - 2)))
167189
x = y * torch.sqrt(var) + mean
168190
return x
169191

tests/quick_tests.ipynb

Lines changed: 105 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1256,27 +1256,126 @@
12561256
},
12571257
{
12581258
"cell_type": "code",
1259-
"execution_count": 25,
1259+
"execution_count": 28,
12601260
"metadata": {
12611261
"ExecuteTime": {
1262-
"end_time": "2021-04-28T00:06:52.232788Z",
1263-
"start_time": "2021-04-28T00:06:52.224166Z"
1262+
"end_time": "2021-04-28T00:57:22.891307Z",
1263+
"start_time": "2021-04-28T00:57:22.885167Z"
1264+
}
1265+
},
1266+
"outputs": [],
1267+
"source": [
1268+
"x = torch.randn((3, 5, 9))"
1269+
]
1270+
},
1271+
{
1272+
"cell_type": "code",
1273+
"execution_count": 45,
1274+
"metadata": {
1275+
"ExecuteTime": {
1276+
"end_time": "2021-04-28T01:18:20.763437Z",
1277+
"start_time": "2021-04-28T01:18:20.758441Z"
1278+
}
1279+
},
1280+
"outputs": [],
1281+
"source": [
1282+
"bnINN = INN.BatchNorm1d(5)\n",
1283+
"bnINN.computing_p(False)\n",
1284+
"bn = nn.BatchNorm1d(5, affine=False)\n",
1285+
"\n",
1286+
"#bnINN.eval()\n",
1287+
"#bn.eval()"
1288+
]
1289+
},
1290+
{
1291+
"cell_type": "code",
1292+
"execution_count": 50,
1293+
"metadata": {
1294+
"ExecuteTime": {
1295+
"end_time": "2021-04-28T01:18:37.776361Z",
1296+
"start_time": "2021-04-28T01:18:37.765331Z"
12641297
}
12651298
},
12661299
"outputs": [
12671300
{
12681301
"data": {
12691302
"text/plain": [
1270-
"torch.Size([3, 45])"
1303+
"tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
1304+
" [0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
1305+
" [0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
1306+
" [0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
1307+
" [0., 0., 0., 0., 0., 0., 0., 0., 0.]],\n",
1308+
"\n",
1309+
" [[0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
1310+
" [0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
1311+
" [0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
1312+
" [0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
1313+
" [0., 0., 0., 0., 0., 0., 0., 0., 0.]],\n",
1314+
"\n",
1315+
" [[0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
1316+
" [0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
1317+
" [0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
1318+
" [0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
1319+
" [0., 0., 0., 0., 0., 0., 0., 0., 0.]]])"
12711320
]
12721321
},
1273-
"execution_count": 25,
1322+
"execution_count": 50,
1323+
"metadata": {},
1324+
"output_type": "execute_result"
1325+
}
1326+
],
1327+
"source": [
1328+
"bnINN(x) - bn(x)"
1329+
]
1330+
},
1331+
{
1332+
"cell_type": "code",
1333+
"execution_count": 40,
1334+
"metadata": {
1335+
"ExecuteTime": {
1336+
"end_time": "2021-04-28T01:03:10.360320Z",
1337+
"start_time": "2021-04-28T01:03:10.353978Z"
1338+
}
1339+
},
1340+
"outputs": [
1341+
{
1342+
"data": {
1343+
"text/plain": [
1344+
"tensor([1., 1., 1., 1., 1.])"
1345+
]
1346+
},
1347+
"execution_count": 40,
1348+
"metadata": {},
1349+
"output_type": "execute_result"
1350+
}
1351+
],
1352+
"source": [
1353+
"bnINN.running_var"
1354+
]
1355+
},
1356+
{
1357+
"cell_type": "code",
1358+
"execution_count": 42,
1359+
"metadata": {
1360+
"ExecuteTime": {
1361+
"end_time": "2021-04-28T01:11:49.029348Z",
1362+
"start_time": "2021-04-28T01:11:49.025402Z"
1363+
}
1364+
},
1365+
"outputs": [
1366+
{
1367+
"data": {
1368+
"text/plain": [
1369+
"[3]"
1370+
]
1371+
},
1372+
"execution_count": 42,
12741373
"metadata": {},
12751374
"output_type": "execute_result"
12761375
}
12771376
],
12781377
"source": [
1279-
"model(x).shape"
1378+
"[1,2,3][2:]"
12801379
]
12811380
},
12821381
{

tests/test_advanced.py

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import INN
2+
import torch
3+
import torch.nn as nn
24

35
'''
46
TODO: directly compute log|det(J)| by a strict way, compare to the output
@@ -33,11 +35,66 @@ def Jacobian_matrix(model, x):
3335
repeats.append(1)
3436

3537
x_hat = x.unsqueeze(0).repeat(tuple(repeats))
38+
#print(f'x.shape={x.shape}, x_hat.shape={x_hat.shape}')
3639
x_hat.requires_grad = True
3740
model.computing_p(True)
3841
y, log_p, log_det = model(x_hat)
3942

40-
v = torch.diag(torch.ones(dim)).reshape((dim, *x.shape))
43+
v = torch.diag(torch.ones(dim)).reshape((y.shape))
4144
grad = INN.utilities.vjp(y, x_hat, v)[0]
4245

43-
return grad.detach(), log_det.detach()
46+
return grad.detach().reshape(y.shape), log_det.detach()
47+
48+
def TestJacobian(model, shape, th=1e-6):
49+
model.eval()
50+
J, logdet = Jacobian_matrix(model, x=torch.randn(shape))
51+
#print(J.shape)
52+
log_det_J = torch.log(torch.abs(torch.det(J)))
53+
print(f'J={log_det_J:.10f}, estimated={torch.mean(logdet):.10f}', end=' , ')
54+
diff = nn.L1Loss()(log_det_J, torch.mean(logdet))
55+
if abs(diff / log_det_J) <= th:
56+
print('pass')
57+
else:
58+
print(f'estimation error is too big (relative loss={abs(diff / log_det_J):.8f})')
59+
60+
'''
61+
########################################################################
62+
Start Tests
63+
########################################################################
64+
'''
65+
66+
print('#'*8 + ' Nonlinear (RealNVP) ' + '#'*8)
67+
model = INN.Nonlinear(5)
68+
TestJacobian(model, shape=5)
69+
70+
print('#'*8 + ' Nonlinear (NICE) ' + '#'*8)
71+
model = INN.Sequential(INN.Nonlinear(5), INN.Nonlinear(5, method='NICE'))
72+
TestJacobian(model, shape=5)
73+
74+
print('#'*8 + ' Nonlinear (iResNet) ' + '#'*8)
75+
model = INN.Sequential(INN.Nonlinear(5), INN.Nonlinear(5, method='iResNet'))
76+
TestJacobian(model, shape=5)
77+
78+
print('#'*8 + ' Conv1d (RealNVP, NICE) ' + '#'*8)
79+
model = INN.Sequential(INN.Conv1d(5, kernel_size=1, method='RealNVP'),
80+
INN.Conv1d(5, kernel_size=1, method='NICE'),
81+
INN.Reshape(shape_in=(5,8), shape_out=(40,)))
82+
TestJacobian(model, shape=(5, 8))
83+
84+
print('#'*8 + ' 1x1 Conv1d ' + '#'*8)
85+
model = INN.Sequential(INN.Conv1d(5, kernel_size=1, method='RealNVP'),
86+
INN.Linear1d(5),
87+
INN.Reshape(shape_in=(5,8), shape_out=(40,)))
88+
TestJacobian(model, shape=(5, 8))
89+
90+
print('#'*8 + ' BatchNorm1d (Linear) ' + '#'*8)
91+
model = INN.BatchNorm1d(5)
92+
model.running_var *= torch.exp(torch.randn(1))
93+
TestJacobian(model, shape=(5,))
94+
95+
print('#'*8 + ' BatchNorm1d (1d) ' + '#'*8)
96+
model = INN.Sequential(#INN.Conv1d(5, kernel_size=1, method='RealNVP'),
97+
INN.BatchNorm1d(5),
98+
INN.Reshape(shape_in=(5,8), shape_out=(40,)))
99+
model[0].running_var *= torch.exp(torch.randn(1))
100+
TestJacobian(model, shape=(5, 8))

tests/test_basic.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,4 +129,20 @@ def BasicTest(model, dim, requires_grad=False, batch_size=8):
129129
print('Sequential:')
130130
model = INN.Sequential(INN.Reshape(shape_in=(8,8), shape_out=(64,)),
131131
INN.Reshape(shape_in=(64,), shape_out=(32, 2)))
132-
BasicTest(model, [8, 8], requires_grad=False)
132+
BasicTest(model, [8, 8], requires_grad=False)
133+
134+
print('#'*32 + ' Testing BatchNorm1d (Linear inputs)' + '#'*32)
135+
model = INN.BatchNorm1d(5).eval()
136+
BasicTest(model, [5], requires_grad=False)
137+
print('Sequential:')
138+
model = INN.Sequential(model,
139+
model)
140+
BasicTest(model, [5], requires_grad=False)
141+
142+
print('#'*32 + ' Testing BatchNorm1d (1d CNN)' + '#'*32)
143+
model = INN.BatchNorm1d(5).eval()
144+
BasicTest(model, [5, 8], requires_grad=False)
145+
print('Sequential:')
146+
model = INN.Sequential(model,
147+
model)
148+
BasicTest(model, [5, 8], requires_grad=False)

0 commit comments

Comments (0)