Skip to content

Commit 8bca98e

Browse files
committed
fix leakyrelu
1 parent 8639ee2 commit 8bca98e

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

stochman/nnj.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -301,8 +301,8 @@ def _jacobian(self, x: Tensor, val: Tensor) -> Tensor:
301301

302302
class LeakyReLU(AbstractActivationJacobian, nn.LeakyReLU):
303303
def _jacobian(self, x: Tensor, val: Tensor) -> Tensor:
304-
jac = torch.zeros_like(val)
305-
jac[val.abs() < 1.0] = 1.0
304+
jac = torch.ones_like(val)
305+
jac[val < 0.0] = self.negative_slope
306306
return jac
307307

308308

tests/test_nnj.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
_batch_size = 2
1111
_features = 5
12-
_dims = 10
12+
_dims = 6
1313

1414
_linear_input = torch.randn(_batch_size, _features)
1515
_1d_conv_input = torch.randn(_batch_size, _features, _dims)
@@ -49,10 +49,10 @@ def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
4949
nnj.ReLU(),
5050
nnj.Sqrt(),
5151
nnj.Hardshrink(),
52-
nnj.LeakyReLU(),
5352
),
5453
_linear_input,
5554
),
55+
(nnj.Sequential(nnj.Linear(_features, 2), nnj.LeakyReLU()), _linear_input),
5656
(nnj.Sequential(nnj.Linear(_features, 2), nnj.OneMinusX()), _linear_input),
5757
(nnj.Sequential(nnj.Conv1d(_features, 2, 5), nnj.ConvTranspose1d(2, _features, 5)), _1d_conv_input),
5858
(nnj.Sequential(nnj.Conv2d(_features, 2, 5), nnj.ConvTranspose2d(2, _features, 5)), _2d_conv_input),
@@ -88,7 +88,7 @@ def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
8888
nnj.Sequential(
8989
nnj.Conv1d(_features, 2, 3),
9090
nnj.Flatten(),
91-
nnj.Linear(8 * 2, 5),
91+
nnj.Linear(4 * 2, 5),
9292
nnj.ReLU(),
9393
),
9494
_1d_conv_input,
@@ -97,7 +97,7 @@ def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
9797
nnj.Sequential(
9898
nnj.Conv2d(_features, 2, 3),
9999
nnj.Flatten(),
100-
nnj.Linear(8 * 8 * 2, 5),
100+
nnj.Linear(4 * 4 * 2, 5),
101101
nnj.ReLU(),
102102
),
103103
_2d_conv_input,
@@ -106,7 +106,7 @@ def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
106106
nnj.Sequential(
107107
nnj.Conv3d(_features, 2, 3),
108108
nnj.Flatten(),
109-
nnj.Linear(8 * 8 * 8 * 2, 5),
109+
nnj.Linear(4 * 4 * 4 * 2, 5),
110110
nnj.ReLU(),
111111
),
112112
_3d_conv_input,

0 commit comments

Comments
 (0)