Commit 99cf692

improve testing

1 parent 846820a commit 99cf692
2 files changed (+31, -86 lines)

stochman/nnj.py

Lines changed: 17 additions & 40 deletions
@@ -73,10 +73,8 @@ def _jacobian_mult(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
 
 class PosLinear(AbstractJacobian, nn.Linear):
     def forward(self, x: Tensor):
-        if self.bias is None:
-            val = F.linear(x, F.softplus(self.weight))
-        else:
-            val = F.linear(x, F.softplus(self.weight), F.softplus(self.bias))
+        bias = F.softplus(self.bias) if self.bias is not None else self.bias
+        val = F.linear(x, F.softplus(self.weight), bias)
         return val
 
     def _jacobian_mult(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
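
A quick standalone sanity check (plain torch, with toy shapes chosen here for illustration) that the branch-free bias handling above matches the removed if/else, since F.linear accepts bias=None:

import torch
import torch.nn.functional as F

w = torch.randn(4, 3)   # stand-ins for self.weight, self.bias and x
b = torch.randn(4)
x = torch.randn(2, 3)

old = F.linear(x, F.softplus(w), F.softplus(b))   # old else-branch
bias = F.softplus(b) if b is not None else b      # new branch-free form
new = F.linear(x, F.softplus(w), bias)
assert torch.allclose(old, new)
# With b = None, both forms reduce to F.linear(x, F.softplus(w), None).
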
@@ -87,42 +85,21 @@ class Upsample(AbstractJacobian, nn.Upsample):
     def _jacobian_mult(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
         xs = x.shape
         vs = val.shape
-        if x.ndim == 3:
-            return (
-                F.interpolate(
-                    jac_in.movedim((1, 2), (-2, -1)).reshape(-1, *xs[1:]),
-                    self.size,
-                    self.scale_factor,
-                    self.mode,
-                    self.align_corners,
-                )
-                .reshape(xs[0], *jac_in.shape[3:], *vs[1:])
-                .movedim((-2, -1), (1, 2))
-            )
-        if x.ndim == 4:
-            return (
-                F.interpolate(
-                    jac_in.movedim((1, 2, 3), (-3, -2, -1)).reshape(-1, *xs[1:]),
-                    self.size,
-                    self.scale_factor,
-                    self.mode,
-                    self.align_corners,
-                )
-                .reshape(xs[0], *jac_in.shape[4:], *vs[1:])
-                .movedim((-3, -2, -1), (1, 2, 3))
-            )
-        if x.ndim == 5:
-            return (
-                F.interpolate(
-                    jac_in.movedim((1, 2, 3, 4), (-4, -3, -2, -1)).reshape(-1, *xs[1:]),
-                    self.size,
-                    self.scale_factor,
-                    self.mode,
-                    self.align_corners,
-                )
-                .reshape(xs[0], *jac_in.shape[5:], *vs[1:])
-                .movedim((-4, -3, -2, -1), (1, 2, 3, 4))
+
+        dims1 = tuple(range(1, x.ndim))
+        dims2 = tuple(range(-x.ndim + 1, 0))
+
+        return (
+            F.interpolate(
+                jac_in.movedim(dims1, dims2).reshape(-1, *xs[1:]),
+                self.size,
+                self.scale_factor,
+                self.mode,
+                self.align_corners,
             )
+            .reshape(xs[0], *jac_in.shape[x.ndim :], *vs[1:])
+            .movedim(dims2, dims1)
+        )
 
 
 class Conv1d(AbstractJacobian, nn.Conv1d):
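
A small shape sketch (shapes are assumptions for illustration, not taken from the repo) of what the dims1/dims2 bookkeeping in the new _jacobian_mult does: the input dimensions of jac_in are moved to the back so the batch and Jacobian axes can be flattened into a single leading dimension for F.interpolate, and movedim(dims2, dims1) undoes the move afterwards:

import torch

x = torch.randn(2, 3, 8, 8)            # 4D conv-style input
jac_in = torch.randn(2, 3, 8, 8, 5)    # Jacobian with one extra trailing axis

dims1 = tuple(range(1, x.ndim))        # (1, 2, 3)
dims2 = tuple(range(-x.ndim + 1, 0))   # (-3, -2, -1)

moved = jac_in.movedim(dims1, dims2)   # (2, 5, 3, 8, 8): x's dims now at the end
flat = moved.reshape(-1, *x.shape[1:]) # (10, 3, 8, 8): ready for F.interpolate
assert torch.equal(moved.movedim(dims2, dims1), jac_in)  # the move is invertible

Because the two tuples are built from x.ndim, the same code covers the 3D, 4D and 5D cases that the removed if-branches handled separately.
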
@@ -308,7 +285,7 @@ def _jacobian(self, x: Tensor, val: Tensor) -> Tensor:
 class Hardshrink(AbstractActivationJacobian, nn.Hardshrink):
     def _jacobian(self, x: Tensor, val: Tensor) -> Tensor:
         jac = torch.ones_like(val)
-        jac[-self.lambd < x < self.lambd] = 0.0
+        jac[torch.logical_and(-self.lambd < x, x < self.lambd)] = 0.0
         return jac
 
 
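
The removed chained comparison only works for Python scalars; on a tensor, -self.lambd < x < self.lambd implicitly calls bool() on a multi-element result and raises. A minimal standalone reproduction (not repo code):

import torch

x = torch.tensor([-1.0, 0.1, 1.0])
lambd = 0.5

try:
    mask = -lambd < x < lambd                        # chained comparison: calls bool() on a tensor
except RuntimeError as err:
    print(err)                                       # "Boolean value of Tensor with more than one element is ambiguous..."

mask = torch.logical_and(-lambd < x, x < lambd)      # element-wise conjunction instead
print(mask)                                          # tensor([False,  True, False])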

tests/test_nnj.py

Lines changed: 14 additions & 46 deletions
@@ -23,57 +23,22 @@ def _compare_jacobian(f, x):
     return res
 
 
-"""
-_models = [
-    nnj.Sequential(
-        nnj.Linear(_in_features, 2), nnj.Softplus(beta=100, threshold=5), nnj.Linear(2, 4), nnj.Tanh()
-    ),
-    nnj.Sequential(nnj.RBF(_in_features, 30), nnj.Linear(30, 2)),
-    nnj.Sequential(nnj.Linear(_in_features, 4), nnj.Norm2()),
-    nnj.Sequential(nnj.Linear(_in_features, 50), nnj.ReLU(), nnj.Linear(50, 100), nnj.Softplus()),
-    nnj.Sequential(nnj.Linear(_in_features, 256)),
-    nnj.Sequential(nnj.Softplus(), nnj.Linear(_in_features, 3), nnj.Softplus()),
-    nnj.Sequential(nnj.Softplus(), nnj.Sigmoid(), nnj.Linear(_in_features, 3)),
-    nnj.Sequential(nnj.Softplus(), nnj.Sigmoid()),
-    nnj.Sequential(nnj.Linear(_in_features, 3), nnj.OneMinusX()),
-    nnj.Sequential(
-        nnj.PosLinear(_in_features, 2), nnj.Softplus(beta=100, threshold=5), nnj.PosLinear(2, 4), nnj.Tanh()
-    ),
-    nnj.Sequential(nnj.PosLinear(_in_features, 5), nnj.Reciprocal(b=1.0)),
-    nnj.Sequential(nnj.ReLU(), nnj.ELU(), nnj.LeakyReLU(), nnj.Sigmoid(), nnj.Softplus(), nnj.Tanh()),
-    nnj.Sequential(nnj.ReLU()),
-    nnj.Sequential(nnj.ELU()),
-    nnj.Sequential(nnj.LeakyReLU()),
-    nnj.Sequential(nnj.Sigmoid()),
-    nnj.Sequential(nnj.Softplus()),
-    nnj.Sequential(nnj.Tanh()),
-    nnj.Sequential(nnj.Hardshrink()),
-    nnj.Sequential(nnj.Hardtanh()),
-    nnj.Sequential(nnj.ResidualBlock(nnj.Linear(_in_features, 50), nnj.ReLU())),
-    nnj.Sequential(nnj.BatchNorm1d(_in_features)),
-    nnj.Sequential(
-        nnj.BatchNorm1d(_in_features),
-        nnj.ResidualBlock(nnj.Linear(_in_features, 25), nnj.Softplus()),
-        nnj.BatchNorm1d(25),
-        nnj.ResidualBlock(nnj.Linear(25, 25), nnj.Softplus()),
-    ),
-]
-"""
-
-
-
 @pytest.mark.parametrize("model, input",
     [
-        (nnj.Sequential(nnj.Linear(_features, 2)), _linear_input),
-        (nnj.Sequential(nnj.Linear(_features, 2), nnj.Sigmoid()), _linear_input),
+        (nnj.Sequential(nnj.Identity(), nnj.Identity()), _linear_input),
+        (nnj.Linear(_features, 2), _linear_input),
+        (nnj.PosLinear(_features, 2), _linear_input),
+        (nnj.Sequential(nnj.Linear(_features, 2), nnj.Sigmoid(), nnj.ArcTanh()), _linear_input),
         (nnj.Sequential(nnj.Linear(_features, 5), nnj.Sigmoid(), nnj.Linear(5, 2)), _linear_input),
         (nnj.Sequential(
             nnj.Linear(_features, 2), nnj.Softplus(beta=100, threshold=5), nnj.Linear(2, 4), nnj.Tanh()
         ), _linear_input),
-        (nnj.Sequential(nnj.Linear(_features, 2), nnj.Sigmoid(), nnj.ReLU()), _linear_input),
-        (nnj.Sequential(nnj.Conv1d(_features, 2, 5)), _1d_conv_input),
-        (nnj.Sequential(nnj.Conv2d(_features, 2, 5)), _2d_conv_input),
-        (nnj.Sequential(nnj.Conv3d(_features, 2, 5)), _3d_conv_input),
+        (nnj.Sequential(
+            nnj.ELU(), nnj.Linear(_features, 2), nnj.Sigmoid(), nnj.ReLU(), nnj.Hardshrink(), nnj.LeakyReLU()
+        ), _linear_input),
+        (nnj.Sequential(nnj.Conv1d(_features, 2, 5), nnj.ConvTranspose1d(2, _features, 5)), _1d_conv_input),
+        (nnj.Sequential(nnj.Conv2d(_features, 2, 5), nnj.ConvTranspose2d(2, _features, 5)), _2d_conv_input),
+        (nnj.Sequential(nnj.Conv3d(_features, 2, 5), nnj.ConvTranspose3d(2, _features, 5)), _3d_conv_input),
        (nnj.Sequential(
            nnj.Linear(_features, 8), nnj.Sigmoid(), nnj.Reshape(2, 4), nnj.Conv1d(2, 1, 2),
        ),_linear_input),
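
The convolutional cases now pair each ConvNd with a matching ConvTransposeNd, which (with kernel size 5 and default stride/padding) maps the output back to the input's spatial size. A quick shape check with assumed sizes, since _features and the *_conv_input fixtures are not shown in this diff:

import torch
from torch import nn

x = torch.randn(2, 3, 16, 16)                                           # assumed 2D conv input
model = nn.Sequential(nn.Conv2d(3, 2, 5), nn.ConvTranspose2d(2, 3, 5))
print(model(x).shape)                                                   # torch.Size([2, 3, 16, 16])
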
@@ -91,7 +56,10 @@ def _compare_jacobian(f, x):
         ),_2d_conv_input),
         (nnj.Sequential(
             nnj.Conv3d(_features, 2, 3), nnj.Flatten(), nnj.Linear(8*8*8*2, 5), nnj.ReLU(),
-        ),_3d_conv_input)
+        ),_3d_conv_input),
+        (nnj.Sequential(
+            nnj.Conv2d(_features, 2, 3), nnj.Hardtanh(), nnj.Upsample(scale_factor=2)
+        ), _2d_conv_input)
     ]
 )
 class TestJacobian:
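
For context, a minimal sketch of the kind of autograd reference such parametrized cases can be compared against; it uses plain torch modules and is only an assumption about what _compare_jacobian (whose body is outside this diff) does:

import torch
from torch import nn

def autograd_jacobian(f, x):
    # One autograd Jacobian per batch element: shape (batch, *out_shape, *in_shape).
    return torch.stack([torch.autograd.functional.jacobian(f, xi) for xi in x])

model = nn.Sequential(nn.Linear(5, 2), nn.Sigmoid())    # plain nn stand-in for an nnj model
x = torch.randn(7, 5)
print(autograd_jacobian(model, x).shape)                # torch.Size([7, 2, 5])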
