Skip to content

Commit 6e2887f

Browse files
committed
add PReLU
1 parent bfa99f7 commit 6e2887f

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

stochman/nnj.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,13 @@ def _jacobian(self, x: Tensor, val: Tensor) -> Tensor:
320320
return jac
321321

322322

323+
class PReLU(AbstractActivationJacobian, nn.PReLU):
324+
def _jacobian(self, x: Tensor, val: Tensor) -> Tensor:
325+
jac = torch.ones_like(val)
326+
jac[x < 0.0] = self.weight
327+
return jac
328+
329+
323330
class LeakyReLU(AbstractActivationJacobian, nn.LeakyReLU):
324331
def _jacobian(self, x: Tensor, val: Tensor) -> Tensor:
325332
jac = torch.ones_like(val)

tests/test_nnj.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
5353
(nnj.Sequential(nnj.Linear(_features, 2), nnj.LeakyReLU()), _linear_input_shape),
5454
(nnj.Sequential(nnj.Linear(_features, 2), nnj.Tanh()), _linear_input_shape),
5555
(nnj.Sequential(nnj.Linear(_features, 2), nnj.OneMinusX()), _linear_input_shape),
56+
(nnj.Sequential(nnj.Linear(_features, 2), nnj.PReLU()), _linear_input_shape),
5657
(
5758
nnj.Sequential(nnj.Conv1d(_features, 2, 5), nnj.ConvTranspose1d(2, _features, 5)),
5859
_1d_conv_input_shape,

0 commit comments

Comments
 (0)