Commit 734162e

improve testing

1 parent 03bc6f5 commit 734162e

4 files changed: +126 -47 lines changed

stochman/curves.py

Lines changed: 3 additions & 1 deletion
@@ -200,7 +200,9 @@ def forward(self, t: torch.Tensor) -> torch.Tensor:
             (torch.floor(tt * num_edges).clamp(min=0, max=num_edges - 1).long())  # Bx|t|
             .unsqueeze(2)
             .repeat(1, 1, D)
-        ).to(self.device)  # Bx|t|xD, this assumes that nodes are equi-distant
+        ).to(
+            self.device
+        )  # Bx|t|xD, this assumes that nodes are equi-distant
         result = torch.gather(a, 1, idx) * tt.unsqueeze(2) + torch.gather(b, 1, idx)  # Bx|t|xD
         if B == 1:
             result = result.squeeze(0)  # |t|xD
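Note on the hunk above: `idx` maps each curve parameter t in [0, 1] to the polyline segment it falls in, which is why the comment stresses equidistant nodes. A standalone sketch of that lookup (illustrative only, not part of the commit):

    import torch

    # floor(t * num_edges) picks the segment index; clamping keeps t == 1.0 in range
    num_edges = 4
    t = torch.tensor([0.0, 0.24, 0.5, 0.99, 1.0])
    idx = torch.floor(t * num_edges).clamp(min=0, max=num_edges - 1).long()
    print(idx)  # tensor([0, 0, 2, 3, 3]); t == 1.0 is clamped into the last segment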

stochman/nnj.py

Lines changed: 7 additions & 4 deletions
@@ -1,11 +1,10 @@
+from math import prod
 from typing import Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
 from torch import nn, Tensor

-from math import prod
-

 class Identity(nn.Module):
     """ Identity module that will return the same input as it receives. """
@@ -18,7 +17,11 @@ def forward(self, x: Tensor, jacobian: bool = False) -> Union[Tensor, Tuple[Tens

         if jacobian:
             xs = x.shape
-            jac = torch.eye(prod(xs[1:]), prod(xs[1:]), dtype=x.dtype).repeat(xs[0], 1, 1).reshape(xs[0], *xs[1:], *xs[1:])
+            jac = (
+                torch.eye(prod(xs[1:]), prod(xs[1:]), dtype=x.dtype)
+                .repeat(xs[0], 1, 1)
+                .reshape(xs[0], *xs[1:], *xs[1:])
+            )
             return val, jac
         return val

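The reformatted expression is the Jacobian of the identity map: an identity matrix over the flattened feature dimensions, repeated per batch element and reshaped back to the input shape twice. A minimal sketch of what it produces, assuming an input of shape (2, 3, 4):

    import torch
    from math import prod

    xs = (2, 3, 4)  # batch of 2, features of shape 3x4
    jac = (
        torch.eye(prod(xs[1:]), prod(xs[1:]))
        .repeat(xs[0], 1, 1)
        .reshape(xs[0], *xs[1:], *xs[1:])
    )
    assert jac.shape == (2, 3, 4, 3, 4)
    # each per-sample block, flattened, is an exact identity matrix
    assert torch.equal(jac.reshape(2, 12, 12)[0], torch.eye(12))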
@@ -366,5 +369,5 @@ def forward(self, x: Tensor) -> Tensor:
         return val

     def _jacobian(self, x: Tensor, val: Tensor) -> Tensor:
-        jac = -0.5 / val
+        jac = 0.5 / val
         return jac
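This one-character diff is a sign fix: for val = sqrt(x), the derivative is d val/dx = 1/(2 * sqrt(x)) = 0.5 / val, so the old -0.5 / val had the wrong sign. A quick autograd check (illustrative, not part of the commit):

    import torch

    x = torch.rand(4) + 0.1  # stay away from 0, where sqrt is not differentiable
    val = x.sqrt()
    auto = torch.autograd.functional.jacobian(torch.sqrt, x).diagonal()
    assert torch.allclose(auto, 0.5 / val)  # the old -0.5 / val fails this check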

tests/test_curves.py

Lines changed: 21 additions & 1 deletion
@@ -10,7 +10,7 @@ class TestCurves:
     @pytest.mark.parametrize("batch_dim", [1, 5])
     @pytest.mark.parametrize("device", ["cpu", "cuda:0"])
     def test_curve_evaluation(self, curve_class, requires_grad, batch_dim, device):
-        if not torch.cuda.is_available() and device == "cuda":
+        if not torch.cuda.is_available() and device == "cuda:0":
             pytest.skip("test requires cuda")

         dim = 2
@@ -92,3 +92,23 @@ def test_to_other(self, curve_class):
         elif curve_class == curves.CubicSpline:
             new_c = c.todiscrete()
         assert isinstance(new_c, curves.DiscreteCurve)
+
+    def test_euclidean_length(self, curve_class):
+        begin = torch.zeros(1, 2).float()
+        end = torch.ones(1, 2).float()
+        c = curve_class(begin, end, 20)
+        el = c.euclidean_length()
+        assert torch.isclose(el, torch.tensor([2.0]).sqrt())
+
+    def test_constant_speed(self, curve_class):
+        batch_size = 5
+        dim = 2
+        timesteps = 50
+        begin = torch.randn(batch_size, dim)
+        end = torch.randn(batch_size, dim)
+        c = curve_class(begin, end, 20)
+        new_t, Ct = c.constant_speed(t=torch.linspace(0, 1, timesteps))
+        assert isinstance(new_t, torch.Tensor)
+        assert isinstance(Ct, torch.Tensor)
+        assert new_t.shape == (batch_size, timesteps)
+        assert Ct.shape == (batch_size, timesteps, dim)
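The expected value in `test_euclidean_length` is geometric: both curve classes here connect (0, 0) to (1, 1) in the plane, and that straight line has length sqrt(2) no matter how many nodes discretize it. A minimal sketch with a hypothetical `polyline_length` helper (not stochman's implementation):

    import torch

    def polyline_length(nodes: torch.Tensor) -> torch.Tensor:
        # nodes: (N, D); sum the Euclidean norms of the N - 1 segments
        return (nodes[1:] - nodes[:-1]).norm(dim=-1).sum()

    # 20 equidistant nodes on the straight line from (0, 0) to (1, 1)
    nodes = torch.stack([torch.linspace(0, 1, 20), torch.linspace(0, 1, 20)], dim=1)
    assert torch.isclose(polyline_length(nodes), torch.tensor(2.0).sqrt())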

tests/test_nnj.py

Lines changed: 95 additions & 41 deletions
@@ -1,7 +1,10 @@
+from copy import deepcopy
+from typing import Callable
+
 import numpy
 import pytest
 import torch
-from copy import deepcopy
+
 from stochman import nnj

 _batch_size = 2
@@ -14,75 +17,126 @@
 _3d_conv_input = torch.randn(_batch_size, _features, _dims, _dims, _dims)


-def _compare_jacobian(f, x):
+def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
+    """ Use pytorch built-in jacobian function to compare for correctness of computations"""
     out = f(x)
-    output = torch.autograd.functional.jacobian(f, x)
+    output = torch.autograd.functional.jacobian(f, x)
     m = out.ndim
-    output = output.movedim(m,1)
-    res = torch.stack([output[i,i] for i in range(_batch_size)], dim=0)
+    output = output.movedim(m, 1)
+    res = torch.stack([output[i, i] for i in range(_batch_size)], dim=0)
     return res


-@pytest.mark.parametrize("model, input",
+@pytest.mark.parametrize(
+    "model, input",
     [
         (nnj.Sequential(nnj.Identity(), nnj.Identity()), _linear_input),
         (nnj.Linear(_features, 2), _linear_input),
-        (nnj.PosLinear(_features, 2), _linear_input),
+        (nnj.Sequential(nnj.PosLinear(_features, 2), nnj.Reciprocal()), _linear_input),
         (nnj.Sequential(nnj.Linear(_features, 2), nnj.Sigmoid(), nnj.ArcTanh()), _linear_input),
         (nnj.Sequential(nnj.Linear(_features, 5), nnj.Sigmoid(), nnj.Linear(5, 2)), _linear_input),
-        (nnj.Sequential(
-            nnj.Linear(_features, 2), nnj.Softplus(beta=100, threshold=5), nnj.Linear(2, 4), nnj.Tanh()
-        ), _linear_input),
-        (nnj.Sequential(
-            nnj.ELU(), nnj.Linear(_features, 2), nnj.Sigmoid(), nnj.ReLU(), nnj.Hardshrink(), nnj.LeakyReLU()
-        ), _linear_input),
+        (
+            nnj.Sequential(
+                nnj.Linear(_features, 2), nnj.Softplus(beta=100, threshold=5), nnj.Linear(2, 4), nnj.Tanh()
+            ),
+            _linear_input,
+        ),
+        (
+            nnj.Sequential(
+                nnj.ELU(),
+                nnj.Linear(_features, 2),
+                nnj.Sigmoid(),
+                nnj.ReLU(),
+                nnj.Sqrt(),
+                nnj.Hardshrink(),
+                nnj.LeakyReLU(),
+            ),
+            _linear_input,
+        ),
+        (nnj.Sequential(nnj.Linear(_features, 2), nnj.OneMinusX()), _linear_input),
         (nnj.Sequential(nnj.Conv1d(_features, 2, 5), nnj.ConvTranspose1d(2, _features, 5)), _1d_conv_input),
         (nnj.Sequential(nnj.Conv2d(_features, 2, 5), nnj.ConvTranspose2d(2, _features, 5)), _2d_conv_input),
         (nnj.Sequential(nnj.Conv3d(_features, 2, 5), nnj.ConvTranspose3d(2, _features, 5)), _3d_conv_input),
-        (nnj.Sequential(
-            nnj.Linear(_features, 8), nnj.Sigmoid(), nnj.Reshape(2, 4), nnj.Conv1d(2, 1, 2),
-        ),_linear_input),
-        (nnj.Sequential(
-            nnj.Linear(_features, 32), nnj.Sigmoid(), nnj.Reshape(2, 4, 4), nnj.Conv2d(2, 1, 2),
-        ),_linear_input),
-        (nnj.Sequential(
-            nnj.Linear(_features, 128), nnj.Sigmoid(), nnj.Reshape(2, 4, 4, 4), nnj.Conv3d(2, 1, 2),
-        ),_linear_input),
-        (nnj.Sequential(
-            nnj.Conv1d(_features, 2, 3), nnj.Flatten(), nnj.Linear(8*2, 5), nnj.ReLU(),
-        ),_1d_conv_input),
-        (nnj.Sequential(
-            nnj.Conv2d(_features, 2, 3), nnj.Flatten(), nnj.Linear(8*8*2, 5), nnj.ReLU(),
-        ),_2d_conv_input),
-        (nnj.Sequential(
-            nnj.Conv3d(_features, 2, 3), nnj.Flatten(), nnj.Linear(8*8*8*2, 5), nnj.ReLU(),
-        ),_3d_conv_input),
-        (nnj.Sequential(
-            nnj.Conv2d(_features, 2, 3), nnj.Hardtanh(), nnj.Upsample(scale_factor=2)
-        ), _2d_conv_input)
-    ]
+        (
+            nnj.Sequential(
+                nnj.Linear(_features, 8),
+                nnj.Sigmoid(),
+                nnj.Reshape(2, 4),
+                nnj.Conv1d(2, 1, 2),
+            ),
+            _linear_input,
+        ),
+        (
+            nnj.Sequential(
+                nnj.Linear(_features, 32),
+                nnj.Sigmoid(),
+                nnj.Reshape(2, 4, 4),
+                nnj.Conv2d(2, 1, 2),
+            ),
+            _linear_input,
+        ),
+        (
+            nnj.Sequential(
+                nnj.Linear(_features, 128),
+                nnj.Sigmoid(),
+                nnj.Reshape(2, 4, 4, 4),
+                nnj.Conv3d(2, 1, 2),
+            ),
+            _linear_input,
+        ),
+        (
+            nnj.Sequential(
+                nnj.Conv1d(_features, 2, 3),
+                nnj.Flatten(),
+                nnj.Linear(8 * 2, 5),
+                nnj.ReLU(),
+            ),
+            _1d_conv_input,
+        ),
+        (
+            nnj.Sequential(
+                nnj.Conv2d(_features, 2, 3),
+                nnj.Flatten(),
+                nnj.Linear(8 * 8 * 2, 5),
+                nnj.ReLU(),
+            ),
+            _2d_conv_input,
+        ),
+        (
+            nnj.Sequential(
+                nnj.Conv3d(_features, 2, 3),
+                nnj.Flatten(),
+                nnj.Linear(8 * 8 * 8 * 2, 5),
+                nnj.ReLU(),
+            ),
+            _3d_conv_input,
+        ),
+        (
+            nnj.Sequential(nnj.Conv2d(_features, 2, 3), nnj.Hardtanh(), nnj.Upsample(scale_factor=2)),
+            _2d_conv_input,
+        ),
+    ],
 )
 class TestJacobian:
     @pytest.mark.parametrize("dtype", [torch.float, torch.double])
     def test_jacobians(self, model, input, dtype):
         """Test that the analytical jacobian of the model is consistent with finite
         order approximation
         """
-        model=deepcopy(model).to(dtype)
-        input=deepcopy(input).to(dtype)
+        model = deepcopy(model).to(dtype)
+        input = deepcopy(input).to(dtype)
         _, jac = model(input, jacobian=True)
         jacnum = _compare_jacobian(model, input)
         assert torch.isclose(jac, jacnum, atol=1e-7).all(), "jacobians did not match"

-
-
     @pytest.mark.parametrize("return_jac", [True, False])
     def test_jac_return(self, model, input, return_jac):
         """ Test that all models return the jacobian output if asked for it """
         output = model(input, jacobian=return_jac)
         if return_jac:
             assert len(output) == 2, "expected two outputs when jacobian=True"
-            assert all(isinstance(o, torch.Tensor) for o in output), "expected all outputs to be torch tensors"
+            assert all(
+                isinstance(o, torch.Tensor) for o in output
+            ), "expected all outputs to be torch tensors"
         else:
             assert isinstance(output, torch.Tensor)
-
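A note on the `_compare_jacobian` helper above: `torch.autograd.functional.jacobian` applied to a batched function returns a tensor of shape (batch, out..., batch, in...) whose cross-batch blocks are zero, so the `movedim` plus the stacked `[i, i]` entries recover the per-sample Jacobians that the `jacobian=True` path of the models returns. A self-contained illustration (names are illustrative, not from the test file):

    import torch

    def f(x):
        return x ** 2  # elementwise, so samples in the batch stay independent

    x = torch.randn(2, 3)                            # batch of 2 samples
    full = torch.autograd.functional.jacobian(f, x)  # shape (2, 3, 2, 3)
    full = full.movedim(f(x).ndim, 1)                # shape (2, 2, 3, 3)
    per_sample = torch.stack([full[i, i] for i in range(2)], dim=0)  # (2, 3, 3)
    assert torch.allclose(per_sample, torch.diag_embed(2 * x))  # d(x**2)/dx = 2x on the diagonal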