11
11
_features = 5
12
12
_dims = 6
13
13
14
- _linear_input = torch . randn (_batch_size , _features )
15
- _1d_conv_input = torch . randn (_batch_size , _features , _dims )
16
- _2d_conv_input = torch . randn (_batch_size , _features , _dims , _dims )
17
- _3d_conv_input = torch . randn (_batch_size , _features , _dims , _dims , _dims )
14
+ _linear_input_shape = (_batch_size , _features )
15
+ _1d_conv_input_shape = (_batch_size , _features , _dims )
16
+ _2d_conv_input_shape = (_batch_size , _features , _dims , _dims )
17
+ _3d_conv_input_shape = (_batch_size , _features , _dims , _dims , _dims )
18
18
19
19
20
20
def _compare_jacobian (f : Callable , x : torch .Tensor ) -> torch .Tensor :
@@ -28,18 +28,16 @@ def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
28
28
29
29
30
30
@pytest .mark .parametrize (
31
- "model, input " ,
31
+ "model, input_shape " ,
32
32
[
33
- (nnj .Sequential (nnj .Identity (), nnj .Identity ()), _linear_input ),
34
- (nnj .Linear (_features , 2 ), _linear_input ),
35
- (nnj .Sequential (nnj .PosLinear (_features , 2 ), nnj .Reciprocal ()), _linear_input ),
36
- (nnj .Sequential (nnj .Linear (_features , 2 ), nnj .Sigmoid (), nnj .ArcTanh ()), _linear_input ),
37
- (nnj .Sequential (nnj .Linear (_features , 5 ), nnj .Sigmoid (), nnj .Linear (5 , 2 )), _linear_input ),
33
+ (nnj .Sequential (nnj .Identity (), nnj .Identity ()), _linear_input_shape ),
34
+ (nnj .Linear (_features , 2 ), _linear_input_shape ),
35
+ (nnj .Sequential (nnj .PosLinear (_features , 2 ), nnj .Reciprocal ()), _linear_input_shape ),
36
+ (nnj .Sequential (nnj .Linear (_features , 2 ), nnj .Sigmoid (), nnj .ArcTanh ()), _linear_input_shape ),
37
+ (nnj .Sequential (nnj .Linear (_features , 5 ), nnj .Sigmoid (), nnj .Linear (5 , 2 )), _linear_input_shape ),
38
38
(
39
- nnj .Sequential (
40
- nnj .Linear (_features , 2 ), nnj .Softplus (beta = 100 , threshold = 5 ), nnj .Linear (2 , 4 ), nnj .Tanh ()
41
- ),
42
- _linear_input ,
39
+ nnj .Sequential (nnj .Linear (_features , 2 ), nnj .Softplus (beta = 100 , threshold = 5 ), nnj .Linear (2 , 4 )),
40
+ _linear_input_shape ,
43
41
),
44
42
(
45
43
nnj .Sequential (
@@ -50,21 +48,31 @@ def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
50
48
nnj .Sqrt (),
51
49
nnj .Hardshrink (),
52
50
),
53
- _linear_input ,
51
+ _linear_input_shape ,
52
+ ),
53
+ (nnj .Sequential (nnj .Linear (_features , 2 ), nnj .LeakyReLU ()), _linear_input_shape ),
54
+ (nnj .Sequential (nnj .Linear (_features , 2 ), nnj .Tanh ()), _linear_input_shape ),
55
+ (nnj .Sequential (nnj .Linear (_features , 2 ), nnj .OneMinusX ()), _linear_input_shape ),
56
+ (
57
+ nnj .Sequential (nnj .Conv1d (_features , 2 , 5 ), nnj .ConvTranspose1d (2 , _features , 5 )),
58
+ _1d_conv_input_shape ,
59
+ ),
60
+ (
61
+ nnj .Sequential (nnj .Conv2d (_features , 2 , 5 ), nnj .ConvTranspose2d (2 , _features , 5 )),
62
+ _2d_conv_input_shape ,
63
+ ),
64
+ (
65
+ nnj .Sequential (nnj .Conv3d (_features , 2 , 5 ), nnj .ConvTranspose3d (2 , _features , 5 )),
66
+ _3d_conv_input_shape ,
54
67
),
55
- (nnj .Sequential (nnj .Linear (_features , 2 ), nnj .LeakyReLU ()), _linear_input ),
56
- (nnj .Sequential (nnj .Linear (_features , 2 ), nnj .OneMinusX ()), _linear_input ),
57
- (nnj .Sequential (nnj .Conv1d (_features , 2 , 5 ), nnj .ConvTranspose1d (2 , _features , 5 )), _1d_conv_input ),
58
- (nnj .Sequential (nnj .Conv2d (_features , 2 , 5 ), nnj .ConvTranspose2d (2 , _features , 5 )), _2d_conv_input ),
59
- (nnj .Sequential (nnj .Conv3d (_features , 2 , 5 ), nnj .ConvTranspose3d (2 , _features , 5 )), _3d_conv_input ),
60
68
(
61
69
nnj .Sequential (
62
70
nnj .Linear (_features , 8 ),
63
71
nnj .Sigmoid (),
64
72
nnj .Reshape (2 , 4 ),
65
73
nnj .Conv1d (2 , 1 , 2 ),
66
74
),
67
- _linear_input ,
75
+ _linear_input_shape ,
68
76
),
69
77
(
70
78
nnj .Sequential (
@@ -73,7 +81,7 @@ def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
73
81
nnj .Reshape (2 , 4 , 4 ),
74
82
nnj .Conv2d (2 , 1 , 2 ),
75
83
),
76
- _linear_input ,
84
+ _linear_input_shape ,
77
85
),
78
86
(
79
87
nnj .Sequential (
@@ -82,7 +90,7 @@ def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
82
90
nnj .Reshape (2 , 4 , 4 , 4 ),
83
91
nnj .Conv3d (2 , 1 , 2 ),
84
92
),
85
- _linear_input ,
93
+ _linear_input_shape ,
86
94
),
87
95
(
88
96
nnj .Sequential (
@@ -91,7 +99,7 @@ def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
91
99
nnj .Linear (4 * 2 , 5 ),
92
100
nnj .ReLU (),
93
101
),
94
- _1d_conv_input ,
102
+ _1d_conv_input_shape ,
95
103
),
96
104
(
97
105
nnj .Sequential (
@@ -100,7 +108,7 @@ def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
100
108
nnj .Linear (4 * 4 * 2 , 5 ),
101
109
nnj .ReLU (),
102
110
),
103
- _2d_conv_input ,
111
+ _2d_conv_input_shape ,
104
112
),
105
113
(
106
114
nnj .Sequential (
@@ -109,30 +117,34 @@ def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
109
117
nnj .Linear (4 * 4 * 4 * 2 , 5 ),
110
118
nnj .ReLU (),
111
119
),
112
- _3d_conv_input ,
120
+ _3d_conv_input_shape ,
113
121
),
114
122
(
115
123
nnj .Sequential (nnj .Conv2d (_features , 2 , 3 ), nnj .Hardtanh (), nnj .Upsample (scale_factor = 2 )),
116
- _2d_conv_input ,
124
+ _2d_conv_input_shape ,
117
125
),
126
+ (nnj .Sequential (nnj .Conv1d (_features , 3 , 3 ), nnj .BatchNorm1d (3 )), _1d_conv_input_shape ),
127
+ (nnj .Sequential (nnj .Conv2d (_features , 3 , 3 ), nnj .BatchNorm2d (3 )), _2d_conv_input_shape ),
128
+ (nnj .Sequential (nnj .Conv3d (_features , 3 , 3 ), nnj .BatchNorm3d (3 )), _3d_conv_input_shape ),
118
129
],
119
130
)
120
131
class TestJacobian :
121
132
@pytest .mark .parametrize ("dtype" , [torch .float , torch .double ])
122
- def test_jacobians (self , model , input , dtype ):
133
+ def test_jacobians (self , model , input_shape , dtype ):
123
134
"""Test that the analytical jacobian of the model is consistent with finite
124
135
order approximation
125
136
"""
126
- model = deepcopy (model ).to (dtype )
127
- input = deepcopy ( input ). to ( dtype )
137
+ model = deepcopy (model ).to (dtype ). eval ()
138
+ input = torch . randn ( * input_shape , dtype = dtype )
128
139
_ , jac = model (input , jacobian = True )
129
140
jacnum = _compare_jacobian (model , input )
130
141
assert torch .isclose (jac , jacnum , atol = 1e-7 ).all (), "jacobians did not match"
131
142
132
143
@pytest .mark .parametrize ("return_jac" , [True , False ])
133
- def test_jac_return (self , model , input , return_jac ):
144
+ def test_jac_return (self , model , input_shape , return_jac ):
134
145
""" Test that all models returns the jacobian output if asked for it """
135
- output = model (input , jacobian = return_jac )
146
+
147
+ output = model (torch .randn (* input_shape ), jacobian = return_jac )
136
148
if return_jac :
137
149
assert len (output ) == 2 , "expected two outputs when jacobian=True"
138
150
assert all (
0 commit comments