Skip to content

Commit fd02ab2

Browse files
committed
add gpu support
1 parent 6e2887f commit fd02ab2

File tree

2 files changed

+16
-9
lines changed

2 files changed

+16
-9
lines changed

stochman/nnj.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def forward(self, x: Tensor, jacobian: bool = False) -> Union[Tensor, Tuple[Tens
1818
if jacobian:
1919
xs = x.shape
2020
jac = (
21-
torch.eye(prod(xs[1:]), prod(xs[1:]), dtype=x.dtype)
21+
torch.eye(prod(xs[1:]), prod(xs[1:]), dtype=x.dtype, device=x.device)
2222
.repeat(xs[0], 1, 1)
2323
.reshape(xs[0], *xs[1:], *xs[1:])
2424
)

tests/test_nnj.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -129,27 +129,34 @@ def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
129129
(nnj.Sequential(nnj.Conv3d(_features, 3, 3), nnj.BatchNorm3d(3)), _3d_conv_input_shape),
130130
],
131131
)
132+
@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
132133
class TestJacobian:
133134
@pytest.mark.parametrize("dtype", [torch.float, torch.double])
134-
def test_jacobians(self, model, input_shape, dtype):
135+
def test_jacobians(self, model, input_shape, device, dtype):
135136
"""Test that the analytical jacobian of the model is consistent with finite
136137
order approximation
137138
"""
138-
model = deepcopy(model).to(dtype).eval()
139-
input = torch.randn(*input_shape, dtype=dtype)
139+
if device == "cuda" and not torch.cuda.is_available():
140+
pytest.skip("Test requires cuda support")
141+
142+
model = deepcopy(model).to(device=device, dtype=dtype).eval()
143+
input = torch.randn(*input_shape, device=device, dtype=dtype)
140144
_, jac = model(input, jacobian=True)
141-
jacnum = _compare_jacobian(model, input)
142-
assert torch.isclose(jac, jacnum, atol=1e-7).all(), "jacobians did not match"
145+
jacnum = _compare_jacobian(model, input).to(device)
146+
assert torch.isclose(jac, jacnum, atol=1e-5).all(), "jacobians did not match"
143147

144148
@pytest.mark.parametrize("return_jac", [True, False])
145-
def test_jac_return(self, model, input_shape, return_jac):
149+
def test_jac_return(self, model, input_shape, device, return_jac):
146150
""" Test that all models returns the jacobian output if asked for it """
147-
148-
output = model(torch.randn(*input_shape), jacobian=return_jac)
151+
input = torch.randn(*input_shape, device=device)
152+
model = deepcopy(model).to(device)
153+
output = model(input, jacobian=return_jac)
149154
if return_jac:
150155
assert len(output) == 2, "expected two outputs when jacobian=True"
151156
assert all(
152157
isinstance(o, torch.Tensor) for o in output
153158
), "expected all outputs to be torch tensors"
159+
assert all(str(o.device) == device for o in output)
154160
else:
155161
assert isinstance(output, torch.Tensor)
162+
assert str(output.device) == device

0 commit comments

Comments
 (0)