Skip to content

Commit a4f4069

Browse files
committed
improve testing
1 parent 1371f58 commit a4f4069

File tree

1 file changed

+69
-109
lines changed

1 file changed

+69
-109
lines changed

tests/test_nnj.py

Lines changed: 69 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -4,70 +4,26 @@
44

55
from stochman import nnj
66

7+
_batch_size = 2
8+
_features = 5
9+
_dims = 10
710

8-
def _fd_jacobian(function, x, h=1e-4):
9-
"""Compute finite difference Jacobian of given function
10-
at a single location x. This function is mainly considered
11-
useful for debugging."""
11+
_linear_input = torch.randn(_batch_size, _features)
12+
_1d_conv_input = torch.randn(_batch_size, _features, _dims)
13+
_2d_conv_input = torch.randn(_batch_size, _features, _dims, _dims)
14+
_3d_conv_input = torch.randn(_batch_size, _features, _dims, _dims, _dims)
1215

13-
no_batch = x.dim() == 1
14-
if no_batch:
15-
x = x.unsqueeze(0)
16-
elif x.dim() > 2:
17-
raise Exception("The input should be a D-vector or a BxD matrix")
18-
B, D = x.shape
1916

20-
# Compute finite differences
21-
E = h * torch.eye(D)
22-
try:
23-
function.eval()
24-
# Disable "training" in the function (relevant eg. for batch normalization)
25-
Jnum = torch.cat(
26-
[((function(x[b] + E) - function(x[b].unsqueeze(0))).t() / h).unsqueeze(0) for b in range(B)]
27-
)
28-
finally:
29-
function.train()
17+
def _compare_jacobian(f, x):
18+
out = f(x)
19+
output = torch.autograd.functional.jacobian(f, x)
20+
m = out.ndim
21+
output = output.movedim(m,1)
22+
res = torch.stack([output[i,i] for i in range(_batch_size)], dim=0)
23+
return res
3024

31-
if no_batch:
32-
Jnum = Jnum.squeeze(0)
3325

34-
return Jnum
35-
36-
37-
def _jacobian_check(function, in_dim=None):
38-
"""Accepts an nnj module and checks the
39-
Jacobian via the finite differences method.
40-
41-
Args:
42-
function: An nnj module object. The
43-
function to be tested.
44-
45-
Returns a tuple of the following form:
46-
(Jacobian_analytical, Jacobian_finite_differences)
47-
"""
48-
49-
with torch.no_grad():
50-
batch_size = 5
51-
if in_dim is None:
52-
in_dim, _ = function.dimensions()
53-
if in_dim is None:
54-
in_dim = 10
55-
x = torch.randn(batch_size, in_dim)
56-
try:
57-
function.eval()
58-
y, J = function(x, jacobian=True)
59-
finally:
60-
function.train()
61-
62-
if J.jactype == nnj.JacType.DIAG:
63-
J = J.diag_embed()
64-
65-
Jnum = _fd_jacobian(function, x)
66-
67-
return J, Jnum
68-
69-
70-
_in_features = 10
26+
"""
7127
_models = [
7228
nnj.Sequential(
7329
nnj.Linear(_in_features, 2), nnj.Softplus(beta=100, threshold=5), nnj.Linear(2, 4), nnj.Tanh()
@@ -102,54 +58,58 @@ def _jacobian_check(function, in_dim=None):
10258
nnj.ResidualBlock(nnj.Linear(25, 25), nnj.Softplus()),
10359
),
10460
]
61+
"""
62+
63+
64+
65+
@pytest.mark.parametrize("model, input",
66+
[
67+
(nnj.Sequential(nnj.Linear(_features, 2)), _linear_input),
68+
(nnj.Sequential(nnj.Linear(_features, 2), nnj.Sigmoid()), _linear_input),
69+
(nnj.Sequential(nnj.Linear(_features, 5), nnj.Sigmoid(), nnj.Linear(5, 2)), _linear_input),
70+
(nnj.Sequential(
71+
nnj.Linear(_features, 2), nnj.Softplus(beta=100, threshold=5), nnj.Linear(2, 4), nnj.Tanh()
72+
), _linear_input),
73+
(nnj.Sequential(nnj.Linear(_features, 2), nnj.Sigmoid(), nnj.ReLU()), _linear_input),
74+
(nnj.Sequential(nnj.Conv1d(_features, 2, 5)), _1d_conv_input),
75+
(nnj.Sequential(nnj.Conv2d(_features, 2, 5)), _2d_conv_input),
76+
(nnj.Sequential(nnj.Conv3d(_features, 2, 5)), _3d_conv_input),
77+
(nnj.Sequential(
78+
nnj.Linear(_features, 8), nnj.Sigmoid(), nnj.Reshape(2, 4), nnj.Conv1d(2, 1, 2),
79+
),_linear_input),
80+
(nnj.Sequential(
81+
nnj.Linear(_features, 32), nnj.Sigmoid(), nnj.Reshape(2, 4, 4), nnj.Conv2d(2, 1, 2),
82+
),_linear_input),
83+
(nnj.Sequential(
84+
nnj.Linear(_features, 128), nnj.Sigmoid(), nnj.Reshape(2, 4, 4, 4), nnj.Conv3d(2, 1, 2),
85+
),_linear_input),
86+
(nnj.Sequential(
87+
nnj.Conv1d(_features, 2, 3), nnj.Flatten(), nnj.Linear(8*2, 5), nnj.ReLU(),
88+
),_1d_conv_input),
89+
(nnj.Sequential(
90+
nnj.Conv2d(_features, 2, 3), nnj.Flatten(), nnj.Linear(8*8*2, 5), nnj.ReLU(),
91+
),_2d_conv_input),
92+
(nnj.Sequential(
93+
nnj.Conv3d(_features, 2, 3), nnj.Flatten(), nnj.Linear(8*8*8*2, 5), nnj.ReLU(),
94+
),_3d_conv_input)
95+
]
96+
)
97+
class TestJacobian:
98+
def test_jacobians(self, model, input):
99+
"""Test that the analytical jacobian of the model is consistent with finite
100+
order approximation
101+
"""
102+
_, jac = model(input, jacobian=True)
103+
jacnum = _compare_jacobian(model, input)
104+
assert torch.isclose(jac, jacnum, atol=1e-7).all(), "jacobians did not match"
105+
106+
@pytest.mark.parametrize("return_jac", [True, False])
107+
def test_jac_return(self, model, input, return_jac):
108+
""" Test that all models returns the jacobian output if asked for it """
109+
output = model(input, jacobian=return_jac)
110+
if return_jac:
111+
assert len(output) == 2, "expected two outputs when jacobian=True"
112+
assert all(isinstance(o, torch.Tensor) for o in output), "expected all outputs to be torch tensors"
113+
else:
114+
assert isinstance(output, torch.Tensor)
105115

106-
107-
@pytest.mark.parametrize("model", _models)
108-
def test_jacobians(model):
109-
"""Test that the analytical jacobian of the model is consistent with finite
110-
order approximation
111-
"""
112-
J, Jnum = _jacobian_check(model, _in_features)
113-
numpy.testing.assert_allclose(J, Jnum, rtol=1, atol=1e-2)
114-
115-
116-
@pytest.mark.parametrize("model", _models)
117-
@pytest.mark.parametrize("return_jac", [True, False])
118-
def test_jac_return(model, return_jac):
119-
x = torch.randn(5, 10)
120-
output = model(x, jacobian=return_jac)
121-
if return_jac:
122-
assert len(output) == 2, "expected two outputs when jacobian=True"
123-
assert all(isinstance(o, torch.Tensor) for o in output), "expected all outputs to be torch tensors"
124-
else:
125-
assert isinstance(output, torch.Tensor)
126-
127-
128-
_testcases = [
129-
(nnj.jacobian(torch.ones(5, 10, 10), "full"), nnj.jacobian(torch.ones(5, 10, 10), "full")),
130-
(nnj.jacobian(torch.ones(5, 10, 10), "full"), nnj.jacobian(torch.ones(5, 10), "diag")),
131-
(nnj.jacobian(torch.ones(5, 10), "diag"), nnj.jacobian(torch.ones(5, 10, 10), "full")),
132-
(nnj.jacobian(torch.ones(5, 10), "diag"), nnj.jacobian(torch.ones(5, 10), "diag")),
133-
]
134-
135-
136-
@pytest.mark.parametrize("cases", _testcases)
137-
def test_add(cases):
138-
""" test generic add for different combinations of jacobians"""
139-
j_out = cases[0] + cases[1]
140-
assert isinstance(j_out, nnj.Jacobian)
141-
142-
if cases[0].jactype == cases[1].jactype:
143-
# if same type, all elements should be 2
144-
j_out = j_out.flatten()
145-
assert all(j_out == 2 * torch.ones_like(j_out))
146-
else:
147-
# if not same type, only diag should be 2
148-
j_out_diag = torch.stack([jo.diag() for jo in j_out]).flatten()
149-
assert all(j_out_diag == 2 * torch.ones_like(j_out_diag))
150-
151-
152-
@pytest.mark.parametrize("cases", _testcases)
153-
def test_matmul(cases):
154-
j_out = cases[0] @ cases[1]
155-
assert isinstance(j_out, nnj.Jacobian)

0 commit comments

Comments
 (0)