@@ -129,27 +129,34 @@ def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
129
129
(nnj .Sequential (nnj .Conv3d (_features , 3 , 3 ), nnj .BatchNorm3d (3 )), _3d_conv_input_shape ),
130
130
],
131
131
)
132
+ @pytest .mark .parametrize ("device" , ["cpu" , "cuda:0" ])
132
133
class TestJacobian :
133
134
@pytest .mark .parametrize ("dtype" , [torch .float , torch .double ])
134
- def test_jacobians (self , model , input_shape , dtype ):
135
+ def test_jacobians (self , model , input_shape , device , dtype ):
135
136
"""Test that the analytical jacobian of the model is consistent with finite
136
137
order approximation
137
138
"""
138
- model = deepcopy (model ).to (dtype ).eval ()
139
- input = torch .randn (* input_shape , dtype = dtype )
139
+ if device == "cuda" and not torch .cuda .is_available ():
140
+ pytest .skip ("Test requires cuda support" )
141
+
142
+ model = deepcopy (model ).to (device = device , dtype = dtype ).eval ()
143
+ input = torch .randn (* input_shape , device = device , dtype = dtype )
140
144
_ , jac = model (input , jacobian = True )
141
- jacnum = _compare_jacobian (model , input )
142
- assert torch .isclose (jac , jacnum , atol = 1e-7 ).all (), "jacobians did not match"
145
+ jacnum = _compare_jacobian (model , input ). to ( device )
146
+ assert torch .isclose (jac , jacnum , atol = 1e-5 ).all (), "jacobians did not match"
143
147
144
148
@pytest .mark .parametrize ("return_jac" , [True , False ])
145
- def test_jac_return (self , model , input_shape , return_jac ):
149
+ def test_jac_return (self , model , input_shape , device , return_jac ):
146
150
""" Test that all models returns the jacobian output if asked for it """
147
-
148
- output = model (torch .randn (* input_shape ), jacobian = return_jac )
151
+ input = torch .randn (* input_shape , device = device )
152
+ model = deepcopy (model ).to (device )
153
+ output = model (input , jacobian = return_jac )
149
154
if return_jac :
150
155
assert len (output ) == 2 , "expected two outputs when jacobian=True"
151
156
assert all (
152
157
isinstance (o , torch .Tensor ) for o in output
153
158
), "expected all outputs to be torch tensors"
159
+ assert all (str (o .device ) == device for o in output )
154
160
else :
155
161
assert isinstance (output , torch .Tensor )
162
+ assert str (output .device ) == device
0 commit comments