
Commit bfa99f7: add batchnormalization
Parent: babc1f4

2 files changed: 66 additions and 33 deletions

stochman/nnj.py (22 additions, 1 deletion)
```diff
@@ -266,6 +266,27 @@ def __call__(self, x: Tensor, jacobian: bool = False) -> Union[Tensor, Tuple[Ten
         return val
 
 
+class BatchNorm1d(AbstractActivationJacobian, nn.BatchNorm1d):
+    # only implements jacobian during testing
+    def _jacobian(self, x: Tensor, val: Tensor) -> Tensor:
+        jac = (self.weight / (self.running_var + self.eps).sqrt()).unsqueeze(0)
+        return jac
+
+
+class BatchNorm2d(AbstractActivationJacobian, nn.BatchNorm2d):
+    # only implements jacobian during testing
+    def _jacobian(self, x: Tensor, val: Tensor) -> Tensor:
+        jac = (self.weight / (self.running_var + self.eps).sqrt()).unsqueeze(0)
+        return jac
+
+
+class BatchNorm3d(AbstractActivationJacobian, nn.BatchNorm3d):
+    # only implements jacobian during testing
+    def _jacobian(self, x: Tensor, val: Tensor) -> Tensor:
+        jac = (self.weight / (self.running_var + self.eps).sqrt()).unsqueeze(0)
+        return jac
+
+
 class Sigmoid(AbstractActivationJacobian, nn.Sigmoid):
     def _jacobian(self, x: Tensor, val: Tensor) -> Tensor:
         jac = val * (1.0 - val)
```
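In eval mode batch norm is an affine map of its input, val = weight * (x - running_mean) / sqrt(running_var + eps) + bias, so its Jacobian with respect to x is a constant diagonal with exactly the entries computed above. The "only implements jacobian during testing" comment reflects that in train mode the batch statistics depend on x itself, so this closed form no longer holds. A minimal sketch in plain PyTorch (not the stochman API) checking the closed form against autograd:

```python
# Minimal sketch (plain PyTorch, not the stochman API): in eval mode
# BatchNorm1d is affine in x, so its Jacobian is diagonal with entries
# weight / sqrt(running_var + eps), independent of the input.
import torch
from torch import nn

bn = nn.BatchNorm1d(3).eval()        # eval mode: uses running statistics
bn.running_mean = torch.randn(3)     # pretend stats from earlier training
bn.running_var = torch.rand(3) + 0.5

x = torch.randn(1, 3)
closed_form = (bn.weight / (bn.running_var + bn.eps).sqrt()).detach()

# autograd Jacobian of the (1, 3) output w.r.t. the (1, 3) input
auto = torch.autograd.functional.jacobian(bn, x)[0, :, 0, :]
assert torch.allclose(auto.diagonal(), closed_form, atol=1e-6)
# off-diagonal entries vanish: no coupling between features in eval mode
assert torch.allclose(auto - torch.diag(auto.diagonal()), torch.zeros(3, 3))
```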
```diff
@@ -302,7 +323,7 @@ def _jacobian(self, x: Tensor, val: Tensor) -> Tensor:
 class LeakyReLU(AbstractActivationJacobian, nn.LeakyReLU):
     def _jacobian(self, x: Tensor, val: Tensor) -> Tensor:
         jac = torch.ones_like(val)
-        jac[val < 0.0] = self.negative_slope
+        jac[x < 0.0] = self.negative_slope
         return jac
 
 
```
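The LeakyReLU change masks on the input x instead of the output val. With the default positive slope the two masks agree, since the activation preserves sign, but for negative_slope <= 0 the output no longer tracks the sign of the input and the old mask selects the wrong entries. A small standalone demonstration of the degenerate case (plain PyTorch, independent of stochman):

```python
# Why the mask must test the input x and not the output val: for
# negative_slope <= 0 the output does not preserve the sign of the input.
import torch

x = torch.tensor([-2.0, -0.5, 1.0])
act = torch.nn.LeakyReLU(negative_slope=0.0)  # degenerate case: plain ReLU
val = act(x)                                  # tensor([0., 0., 1.])

old_jac = torch.ones_like(val)
old_jac[val < 0.0] = act.negative_slope       # mask never fires: val >= 0
new_jac = torch.ones_like(val)
new_jac[x < 0.0] = act.negative_slope         # zero slope on the negative side

print(old_jac)  # tensor([1., 1., 1.])  wrong for x < 0
print(new_jac)  # tensor([0., 0., 1.])  matches d/dx ReLU(x)
```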
tests/test_nnj.py (44 additions, 32 deletions)
```diff
@@ -11,10 +11,10 @@
 _features = 5
 _dims = 6
 
-_linear_input = torch.randn(_batch_size, _features)
-_1d_conv_input = torch.randn(_batch_size, _features, _dims)
-_2d_conv_input = torch.randn(_batch_size, _features, _dims, _dims)
-_3d_conv_input = torch.randn(_batch_size, _features, _dims, _dims, _dims)
+_linear_input_shape = (_batch_size, _features)
+_1d_conv_input_shape = (_batch_size, _features, _dims)
+_2d_conv_input_shape = (_batch_size, _features, _dims, _dims)
+_3d_conv_input_shape = (_batch_size, _features, _dims, _dims, _dims)
 
 
 def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
```
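The fixtures now store input shapes instead of pre-built tensors, so each parametrized test draws a fresh random input in the requested dtype rather than reusing a shared module-level tensor. The body of `_compare_jacobian` lies outside this diff; purely for orientation, a hypothetical reference with the same signature could build the dense autograd Jacobian and keep its per-sample blocks (the test docstring mentions a finite-order approximation, but the block structure of the comparison is the same either way):

```python
# Hypothetical reference with _compare_jacobian's signature; the real body
# is outside this diff. For each sample i it keeps the Jacobian of output i
# with respect to input i, i.e. the block-diagonal part over the batch.
from typing import Callable

import torch


def _reference_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
    jac = torch.autograd.functional.jacobian(f, x)            # dense Jacobian
    b = x.shape[0]
    out_numel = f(x)[0].numel()
    jac = jac.reshape(b, out_numel, b, x[0].numel())          # flatten feature dims
    return torch.stack([jac[i, :, i, :] for i in range(b)])  # diagonal blocks
```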
```diff
@@ -28,18 +28,16 @@ def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
 
 
 @pytest.mark.parametrize(
-    "model, input",
+    "model, input_shape",
     [
-        (nnj.Sequential(nnj.Identity(), nnj.Identity()), _linear_input),
-        (nnj.Linear(_features, 2), _linear_input),
-        (nnj.Sequential(nnj.PosLinear(_features, 2), nnj.Reciprocal()), _linear_input),
-        (nnj.Sequential(nnj.Linear(_features, 2), nnj.Sigmoid(), nnj.ArcTanh()), _linear_input),
-        (nnj.Sequential(nnj.Linear(_features, 5), nnj.Sigmoid(), nnj.Linear(5, 2)), _linear_input),
+        (nnj.Sequential(nnj.Identity(), nnj.Identity()), _linear_input_shape),
+        (nnj.Linear(_features, 2), _linear_input_shape),
+        (nnj.Sequential(nnj.PosLinear(_features, 2), nnj.Reciprocal()), _linear_input_shape),
+        (nnj.Sequential(nnj.Linear(_features, 2), nnj.Sigmoid(), nnj.ArcTanh()), _linear_input_shape),
+        (nnj.Sequential(nnj.Linear(_features, 5), nnj.Sigmoid(), nnj.Linear(5, 2)), _linear_input_shape),
         (
-            nnj.Sequential(
-                nnj.Linear(_features, 2), nnj.Softplus(beta=100, threshold=5), nnj.Linear(2, 4), nnj.Tanh()
-            ),
-            _linear_input,
+            nnj.Sequential(nnj.Linear(_features, 2), nnj.Softplus(beta=100, threshold=5), nnj.Linear(2, 4)),
+            _linear_input_shape,
         ),
         (
             nnj.Sequential(
@@ -50,21 +48,31 @@ def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
                 nnj.Sqrt(),
                 nnj.Hardshrink(),
             ),
-            _linear_input,
+            _linear_input_shape,
+        ),
+        (nnj.Sequential(nnj.Linear(_features, 2), nnj.LeakyReLU()), _linear_input_shape),
+        (nnj.Sequential(nnj.Linear(_features, 2), nnj.Tanh()), _linear_input_shape),
+        (nnj.Sequential(nnj.Linear(_features, 2), nnj.OneMinusX()), _linear_input_shape),
+        (
+            nnj.Sequential(nnj.Conv1d(_features, 2, 5), nnj.ConvTranspose1d(2, _features, 5)),
+            _1d_conv_input_shape,
+        ),
+        (
+            nnj.Sequential(nnj.Conv2d(_features, 2, 5), nnj.ConvTranspose2d(2, _features, 5)),
+            _2d_conv_input_shape,
+        ),
+        (
+            nnj.Sequential(nnj.Conv3d(_features, 2, 5), nnj.ConvTranspose3d(2, _features, 5)),
+            _3d_conv_input_shape,
         ),
-        (nnj.Sequential(nnj.Linear(_features, 2), nnj.LeakyReLU()), _linear_input),
-        (nnj.Sequential(nnj.Linear(_features, 2), nnj.OneMinusX()), _linear_input),
-        (nnj.Sequential(nnj.Conv1d(_features, 2, 5), nnj.ConvTranspose1d(2, _features, 5)), _1d_conv_input),
-        (nnj.Sequential(nnj.Conv2d(_features, 2, 5), nnj.ConvTranspose2d(2, _features, 5)), _2d_conv_input),
-        (nnj.Sequential(nnj.Conv3d(_features, 2, 5), nnj.ConvTranspose3d(2, _features, 5)), _3d_conv_input),
         (
             nnj.Sequential(
                 nnj.Linear(_features, 8),
                 nnj.Sigmoid(),
                 nnj.Reshape(2, 4),
                 nnj.Conv1d(2, 1, 2),
             ),
-            _linear_input,
+            _linear_input_shape,
         ),
         (
             nnj.Sequential(
@@ -73,7 +81,7 @@ def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
                 nnj.Reshape(2, 4, 4),
                 nnj.Conv2d(2, 1, 2),
             ),
-            _linear_input,
+            _linear_input_shape,
         ),
         (
             nnj.Sequential(
@@ -82,7 +90,7 @@ def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
                 nnj.Reshape(2, 4, 4, 4),
                 nnj.Conv3d(2, 1, 2),
             ),
-            _linear_input,
+            _linear_input_shape,
         ),
         (
             nnj.Sequential(
@@ -91,7 +99,7 @@ def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
                 nnj.Linear(4 * 2, 5),
                 nnj.ReLU(),
             ),
-            _1d_conv_input,
+            _1d_conv_input_shape,
         ),
         (
             nnj.Sequential(
@@ -100,7 +108,7 @@ def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
                 nnj.Linear(4 * 4 * 2, 5),
                 nnj.ReLU(),
             ),
-            _2d_conv_input,
+            _2d_conv_input_shape,
         ),
         (
             nnj.Sequential(
@@ -109,30 +117,34 @@ def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
                 nnj.Linear(4 * 4 * 4 * 2, 5),
                 nnj.ReLU(),
             ),
-            _3d_conv_input,
+            _3d_conv_input_shape,
         ),
         (
             nnj.Sequential(nnj.Conv2d(_features, 2, 3), nnj.Hardtanh(), nnj.Upsample(scale_factor=2)),
-            _2d_conv_input,
+            _2d_conv_input_shape,
         ),
+        (nnj.Sequential(nnj.Conv1d(_features, 3, 3), nnj.BatchNorm1d(3)), _1d_conv_input_shape),
+        (nnj.Sequential(nnj.Conv2d(_features, 3, 3), nnj.BatchNorm2d(3)), _2d_conv_input_shape),
+        (nnj.Sequential(nnj.Conv3d(_features, 3, 3), nnj.BatchNorm3d(3)), _3d_conv_input_shape),
     ],
 )
 class TestJacobian:
     @pytest.mark.parametrize("dtype", [torch.float, torch.double])
-    def test_jacobians(self, model, input, dtype):
+    def test_jacobians(self, model, input_shape, dtype):
         """Test that the analytical jacobian of the model is consistent with finite
         order approximation
         """
-        model = deepcopy(model).to(dtype)
-        input = deepcopy(input).to(dtype)
+        model = deepcopy(model).to(dtype).eval()
+        input = torch.randn(*input_shape, dtype=dtype)
         _, jac = model(input, jacobian=True)
         jacnum = _compare_jacobian(model, input)
         assert torch.isclose(jac, jacnum, atol=1e-7).all(), "jacobians did not match"
 
     @pytest.mark.parametrize("return_jac", [True, False])
-    def test_jac_return(self, model, input, return_jac):
+    def test_jac_return(self, model, input_shape, return_jac):
         """ Test that all models returns the jacobian output if asked for it """
-        output = model(input, jacobian=return_jac)
+
+        output = model(torch.randn(*input_shape), jacobian=return_jac)
         if return_jac:
             assert len(output) == 2, "expected two outputs when jacobian=True"
             assert all(
```
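One detail worth noting: `test_jacobians` now calls `.eval()` before requesting a Jacobian, because the BatchNorm Jacobian added in this commit is computed from running statistics and is only the true derivative in eval mode (in train mode the batch statistics couple the samples within a batch). For reference, a minimal usage sketch of the interface these tests exercise; the import path is an assumption based on the repository layout:

```python
# Minimal usage sketch of the nnj interface exercised by the tests above;
# the import path `from stochman import nnj` is an assumed convention.
import torch
from stochman import nnj

model = nnj.Sequential(nnj.Conv1d(5, 3, 3), nnj.BatchNorm1d(3)).eval()
x = torch.randn(2, 5, 6)

val = model(x)                      # behaves like the torch.nn counterpart
val, jac = model(x, jacobian=True)  # additionally returns the analytic Jacobian
```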