@@ -54,6 +54,7 @@ def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
54
54
(nnj .Sequential (nnj .Linear (_features , 2 ), nnj .Tanh ()), _linear_input_shape ),
55
55
(nnj .Sequential (nnj .Linear (_features , 2 ), nnj .OneMinusX ()), _linear_input_shape ),
56
56
(nnj .Sequential (nnj .Linear (_features , 2 ), nnj .PReLU ()), _linear_input_shape ),
57
+ (nnj .Sequential (nnj .Linear (_features , 2 ), nnj .Softmax (dim = - 1 )), _linear_input_shape ),
57
58
(
58
59
nnj .Sequential (nnj .Conv1d (_features , 2 , 5 ), nnj .ConvTranspose1d (2 , _features , 5 )),
59
60
_1d_conv_input_shape ,
@@ -98,7 +99,7 @@ def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
98
99
nnj .Conv1d (_features , 2 , 3 ),
99
100
nnj .Flatten (),
100
101
nnj .Linear (4 * 2 , 5 ),
101
- nnj .ReLU ( ),
102
+ nnj .Softmax ( dim = - 1 ),
102
103
),
103
104
_1d_conv_input_shape ,
104
105
),
@@ -107,7 +108,7 @@ def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
107
108
nnj .Conv2d (_features , 2 , 3 ),
108
109
nnj .Flatten (),
109
110
nnj .Linear (4 * 4 * 2 , 5 ),
110
- nnj .ReLU ( ),
111
+ nnj .Softmax ( dim = - 1 ),
111
112
),
112
113
_2d_conv_input_shape ,
113
114
),
@@ -116,7 +117,7 @@ def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
116
117
nnj .Conv3d (_features , 2 , 3 ),
117
118
nnj .Flatten (),
118
119
nnj .Linear (4 * 4 * 4 * 2 , 5 ),
119
- nnj .ReLU ( ),
120
+ nnj .Softmax ( dim = - 1 ),
120
121
),
121
122
_3d_conv_input_shape ,
122
123
),
0 commit comments