Skip to content

Commit 03bc6f5

Browse files
committed
fix dtype for nnj
1 parent 4117a9b commit 03bc6f5

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

stochman/nnj.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def forward(self, x: Tensor, jacobian: bool = False) -> Union[Tensor, Tuple[Tens
1818

1919
if jacobian:
2020
xs = x.shape
21-
jac = torch.eye(prod(xs[1:]), prod(xs[1:])).repeat(xs[0], 1, 1).reshape(xs[0], *xs[1:], *xs[1:])
21+
jac = torch.eye(prod(xs[1:]), prod(xs[1:]), dtype=x.dtype).repeat(xs[0], 1, 1).reshape(xs[0], *xs[1:], *xs[1:])
2222
return val, jac
2323
return val
2424

tests/test_nnj.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy
22
import pytest
33
import torch
4-
4+
from copy import deepcopy
55
from stochman import nnj
66

77
_batch_size = 2
@@ -63,14 +63,19 @@ def _compare_jacobian(f, x):
6363
]
6464
)
6565
class TestJacobian:
66-
def test_jacobians(self, model, input):
66+
@pytest.mark.parametrize("dtype", [torch.float, torch.double])
67+
def test_jacobians(self, model, input, dtype):
6768
"""Test that the analytical jacobian of the model is consistent with finite
6869
order approximation
6970
"""
71+
model=deepcopy(model).to(dtype)
72+
input=deepcopy(input).to(dtype)
7073
_, jac = model(input, jacobian=True)
7174
jacnum = _compare_jacobian(model, input)
7275
assert torch.isclose(jac, jacnum, atol=1e-7).all(), "jacobians did not match"
7376

77+
78+
7479
@pytest.mark.parametrize("return_jac", [True, False])
7580
def test_jac_return(self, model, input, return_jac):
7681
""" Test that all models returns the jacobian output if asked for it """

0 commit comments

Comments
 (0)