_batch_size = 2
_features = 5
-_dims = 10
+_dims = 6

_linear_input = torch.randn(_batch_size, _features)
_1d_conv_input = torch.randn(_batch_size, _features, _dims)
@@ -49,10 +49,10 @@ def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
            nnj.ReLU(),
            nnj.Sqrt(),
            nnj.Hardshrink(),
-            nnj.LeakyReLU(),
        ),
        _linear_input,
    ),
+    (nnj.Sequential(nnj.Linear(_features, 2), nnj.LeakyReLU()), _linear_input),
    (nnj.Sequential(nnj.Linear(_features, 2), nnj.OneMinusX()), _linear_input),
    (nnj.Sequential(nnj.Conv1d(_features, 2, 5), nnj.ConvTranspose1d(2, _features, 5)), _1d_conv_input),
    (nnj.Sequential(nnj.Conv2d(_features, 2, 5), nnj.ConvTranspose2d(2, _features, 5)), _2d_conv_input),
@@ -88,7 +88,7 @@ def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
        nnj.Sequential(
            nnj.Conv1d(_features, 2, 3),
            nnj.Flatten(),
-            nnj.Linear(8 * 2, 5),
+            nnj.Linear(4 * 2, 5),
            nnj.ReLU(),
        ),
        _1d_conv_input,
@@ -97,7 +97,7 @@ def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
        nnj.Sequential(
            nnj.Conv2d(_features, 2, 3),
            nnj.Flatten(),
-            nnj.Linear(8 * 8 * 2, 5),
+            nnj.Linear(4 * 4 * 2, 5),
            nnj.ReLU(),
        ),
        _2d_conv_input,
@@ -106,7 +106,7 @@ def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
        nnj.Sequential(
            nnj.Conv3d(_features, 2, 3),
            nnj.Flatten(),
-            nnj.Linear(8 * 8 * 8 * 2, 5),
+            nnj.Linear(4 * 4 * 4 * 2, 5),
            nnj.ReLU(),
        ),
        _3d_conv_input,
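The arithmetic behind the updated in-features: with _dims = 6 and kernel size 3 (stride 1, no padding), each spatial dimension shrinks to 6 - 3 + 1 = 4, so the flattened Conv1d/Conv2d/Conv3d outputs carry 4 * 2, 4 * 4 * 2 and 4 * 4 * 4 * 2 features instead of the 8-based sizes that held for _dims = 10. Below is a minimal sketch of that shape check; it uses plain torch.nn layers in place of the nnj modules and assumes the 2D/3D fixtures follow the same randn(_batch_size, _features, _dims, ...) pattern as _1d_conv_input.

import torch
import torch.nn as nn

_batch_size, _features, _dims = 2, 5, 6

# Inputs mirroring the test fixtures; the 2D/3D shapes are an assumption.
x1 = torch.randn(_batch_size, _features, _dims)
x2 = torch.randn(_batch_size, _features, _dims, _dims)
x3 = torch.randn(_batch_size, _features, _dims, _dims, _dims)

# A valid convolution with kernel 3 shrinks each spatial size from 6 to 6 - 3 + 1 = 4.
flat1 = nn.Flatten()(nn.Conv1d(_features, 2, 3)(x1))
flat2 = nn.Flatten()(nn.Conv2d(_features, 2, 3)(x2))
flat3 = nn.Flatten()(nn.Conv3d(_features, 2, 3)(x3))

assert flat1.shape[1] == 4 * 2          # matches nnj.Linear(4 * 2, 5)
assert flat2.shape[1] == 4 * 4 * 2      # matches nnj.Linear(4 * 4 * 2, 5)
assert flat3.shape[1] == 4 * 4 * 4 * 2  # matches nnj.Linear(4 * 4 * 4 * 2, 5)

The Conv/ConvTranspose pairs need no such change: Conv1d with kernel 5 maps length 6 to 2, and ConvTranspose1d with kernel 5 maps it back to 6, so those cases still reconstruct the input shape.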