
from stochman import nnj

+_batch_size = 2
+_features = 5
+_dims = 10

-def _fd_jacobian(function, x, h=1e-4):
-    """Compute finite difference Jacobian of given function
-    at a single location x. This function is mainly considered
-    useful for debugging."""
+_linear_input = torch.randn(_batch_size, _features)
+_1d_conv_input = torch.randn(_batch_size, _features, _dims)
+_2d_conv_input = torch.randn(_batch_size, _features, _dims, _dims)
+_3d_conv_input = torch.randn(_batch_size, _features, _dims, _dims, _dims)

-    no_batch = x.dim() == 1
-    if no_batch:
-        x = x.unsqueeze(0)
-    elif x.dim() > 2:
-        raise Exception("The input should be a D-vector or a BxD matrix")
-    B, D = x.shape

-    # Compute finite differences
-    E = h * torch.eye(D)
-    try:
-        function.eval()
-        # Disable "training" in the function (relevant eg. for batch normalization)
-        Jnum = torch.cat(
-            [((function(x[b] + E) - function(x[b].unsqueeze(0))).t() / h).unsqueeze(0) for b in range(B)]
-        )
-    finally:
-        function.train()
+def _compare_jacobian(f, x):
+    out = f(x)
+    output = torch.autograd.functional.jacobian(f, x)  # shape: (batch, *out_shape, batch, *in_shape)
+    m = out.ndim
+    output = output.movedim(m, 1)  # bring the second batch dimension next to the first
+    res = torch.stack([output[i, i] for i in range(_batch_size)], dim=0)  # keep each sample's own (block-diagonal) jacobian
+    return res

-    if no_batch:
-        Jnum = Jnum.squeeze(0)

-    return Jnum
-
-
-def _jacobian_check(function, in_dim=None):
-    """Accepts an nnj module and checks the
-    Jacobian via the finite differences method.
-
-    Args:
-        function: An nnj module object. The
-            function to be tested.
-
-    Returns a tuple of the following form:
-        (Jacobian_analytical, Jacobian_finite_differences)
-    """
-
-    with torch.no_grad():
-        batch_size = 5
-        if in_dim is None:
-            in_dim, _ = function.dimensions()
-            if in_dim is None:
-                in_dim = 10
-        x = torch.randn(batch_size, in_dim)
-        try:
-            function.eval()
-            y, J = function(x, jacobian=True)
-        finally:
-            function.train()
-
-        if J.jactype == nnj.JacType.DIAG:
-            J = J.diag_embed()
-
-        Jnum = _fd_jacobian(function, x)
-
-        return J, Jnum
-
-
-_in_features = 10
+"""
_models = [
    nnj.Sequential(
        nnj.Linear(_in_features, 2), nnj.Softplus(beta=100, threshold=5), nnj.Linear(2, 4), nnj.Tanh()
@@ -102,54 +58,58 @@ def _jacobian_check(function, in_dim=None):
        nnj.ResidualBlock(nnj.Linear(25, 25), nnj.Softplus()),
    ),
]
+"""
+
+
+
+@pytest.mark.parametrize("model, input",
+    [
+        (nnj.Sequential(nnj.Linear(_features, 2)), _linear_input),
+        (nnj.Sequential(nnj.Linear(_features, 2), nnj.Sigmoid()), _linear_input),
+        (nnj.Sequential(nnj.Linear(_features, 5), nnj.Sigmoid(), nnj.Linear(5, 2)), _linear_input),
+        (nnj.Sequential(
+            nnj.Linear(_features, 2), nnj.Softplus(beta=100, threshold=5), nnj.Linear(2, 4), nnj.Tanh()
+        ), _linear_input),
+        (nnj.Sequential(nnj.Linear(_features, 2), nnj.Sigmoid(), nnj.ReLU()), _linear_input),
+        (nnj.Sequential(nnj.Conv1d(_features, 2, 5)), _1d_conv_input),
+        (nnj.Sequential(nnj.Conv2d(_features, 2, 5)), _2d_conv_input),
+        (nnj.Sequential(nnj.Conv3d(_features, 2, 5)), _3d_conv_input),
+        (nnj.Sequential(
+            nnj.Linear(_features, 8), nnj.Sigmoid(), nnj.Reshape(2, 4), nnj.Conv1d(2, 1, 2),
+        ), _linear_input),
+        (nnj.Sequential(
+            nnj.Linear(_features, 32), nnj.Sigmoid(), nnj.Reshape(2, 4, 4), nnj.Conv2d(2, 1, 2),
+        ), _linear_input),
+        (nnj.Sequential(
+            nnj.Linear(_features, 128), nnj.Sigmoid(), nnj.Reshape(2, 4, 4, 4), nnj.Conv3d(2, 1, 2),
+        ), _linear_input),
+        (nnj.Sequential(
+            nnj.Conv1d(_features, 2, 3), nnj.Flatten(), nnj.Linear(8 * 2, 5), nnj.ReLU(),
+        ), _1d_conv_input),
+        (nnj.Sequential(
+            nnj.Conv2d(_features, 2, 3), nnj.Flatten(), nnj.Linear(8 * 8 * 2, 5), nnj.ReLU(),
+        ), _2d_conv_input),
+        (nnj.Sequential(
+            nnj.Conv3d(_features, 2, 3), nnj.Flatten(), nnj.Linear(8 * 8 * 8 * 2, 5), nnj.ReLU(),
+        ), _3d_conv_input),
+    ]
+)
+class TestJacobian:
+    def test_jacobians(self, model, input):
+        """Test that the analytical jacobian of the model matches the jacobian
+        computed with torch.autograd
+        """
+        _, jac = model(input, jacobian=True)
+        jacnum = _compare_jacobian(model, input)
+        assert torch.isclose(jac, jacnum, atol=1e-7).all(), "jacobians did not match"
+
+    @pytest.mark.parametrize("return_jac", [True, False])
+    def test_jac_return(self, model, input, return_jac):
+        """Test that all models return the jacobian output if asked for it"""
+        output = model(input, jacobian=return_jac)
+        if return_jac:
+            assert len(output) == 2, "expected two outputs when jacobian=True"
+            assert all(isinstance(o, torch.Tensor) for o in output), "expected all outputs to be torch tensors"
+        else:
+            assert isinstance(output, torch.Tensor)

-
-@pytest.mark.parametrize("model", _models)
-def test_jacobians(model):
-    """Test that the analytical jacobian of the model is consistent with finite
-    order approximation
-    """
-    J, Jnum = _jacobian_check(model, _in_features)
-    numpy.testing.assert_allclose(J, Jnum, rtol=1, atol=1e-2)
-
-
-@pytest.mark.parametrize("model", _models)
-@pytest.mark.parametrize("return_jac", [True, False])
-def test_jac_return(model, return_jac):
-    x = torch.randn(5, 10)
-    output = model(x, jacobian=return_jac)
-    if return_jac:
-        assert len(output) == 2, "expected two outputs when jacobian=True"
-        assert all(isinstance(o, torch.Tensor) for o in output), "expected all outputs to be torch tensors"
-    else:
-        assert isinstance(output, torch.Tensor)
-
-
-_testcases = [
-    (nnj.jacobian(torch.ones(5, 10, 10), "full"), nnj.jacobian(torch.ones(5, 10, 10), "full")),
-    (nnj.jacobian(torch.ones(5, 10, 10), "full"), nnj.jacobian(torch.ones(5, 10), "diag")),
-    (nnj.jacobian(torch.ones(5, 10), "diag"), nnj.jacobian(torch.ones(5, 10, 10), "full")),
-    (nnj.jacobian(torch.ones(5, 10), "diag"), nnj.jacobian(torch.ones(5, 10), "diag")),
-]
-
-
-@pytest.mark.parametrize("cases", _testcases)
-def test_add(cases):
-    """ test generic add for different combinations of jacobians"""
-    j_out = cases[0] + cases[1]
-    assert isinstance(j_out, nnj.Jacobian)
-
-    if cases[0].jactype == cases[1].jactype:
-        # if same type, all elements should be 2
-        j_out = j_out.flatten()
-        assert all(j_out == 2 * torch.ones_like(j_out))
-    else:
-        # if not same type, only diag should be 2
-        j_out_diag = torch.stack([jo.diag() for jo in j_out]).flatten()
-        assert all(j_out_diag == 2 * torch.ones_like(j_out_diag))
-
-
-@pytest.mark.parametrize("cases", _testcases)
-def test_matmul(cases):
-    j_out = cases[0] @ cases[1]
-    assert isinstance(j_out, nnj.Jacobian)
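For reference, the check performed by the new TestJacobian.test_jacobians can be reproduced standalone. The sketch below assumes only the nnj API visible in this diff (calling a module with jacobian=True returns the output together with its analytical jacobian); the model and shapes are illustrative, not part of the commit.

import torch
from stochman import nnj

model = nnj.Sequential(nnj.Linear(5, 2), nnj.Sigmoid())
x = torch.randn(2, 5)  # (batch, features)

# analytical jacobian from nnj
out, jac = model(x, jacobian=True)

# reference jacobian from autograd: shape (batch, *out_shape, batch, *in_shape);
# keep each sample's own block on the batch diagonal
full = torch.autograd.functional.jacobian(model, x)
full = full.movedim(out.ndim, 1)
ref = torch.stack([full[i, i] for i in range(x.shape[0])], dim=0)

assert torch.isclose(jac, ref, atol=1e-7).all()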