
Commit 14b7a32

support softmax
1 parent fd02ab2 commit 14b7a32

2 files changed: +24 −3 lines changed


stochman/nnj.py

Lines changed: 20 additions & 0 deletions
@@ -266,6 +266,26 @@ def __call__(self, x: Tensor, jacobian: bool = False) -> Union[Tensor, Tuple[Tensor, Tensor]]:
         return val
 
 
+class Softmax(AbstractActivationJacobian, nn.Softmax):
+    def _jacobian(self, x: Tensor, val: Tensor) -> Tensor:
+        if self.dim == 0:
+            raise ValueError("Jacobian computation not supported for `dim=0`")
+        jac = torch.diag_embed(val) - torch.matmul(val.unsqueeze(-1), val.unsqueeze(-2))
+        return jac
+
+    def _jacobian_mult(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
+        jac = self._jacobian(x, val)
+        n = jac_in.ndim - jac.ndim
+        jac = jac.reshape((1,) * n + jac.shape)
+        if jac_in.ndim == 4:
+            return (jac @ jac_in.permute(3, 0, 1, 2)).permute(1, 2, 3, 0)
+        if jac_in.ndim == 5:
+            return (jac @ jac_in.permute(3, 4, 0, 1, 2)).permute(2, 3, 4, 0, 1)
+        if jac_in.ndim == 6:
+            return (jac @ jac_in.permute(3, 4, 5, 0, 1, 2)).permute(3, 4, 5, 0, 1, 2)
+        return jac @ jac_in
+
+
 class BatchNorm1d(AbstractActivationJacobian, nn.BatchNorm1d):
     # only implements jacobian during testing
     def _jacobian(self, x: Tensor, val: Tensor) -> Tensor:
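
The _jacobian added above uses the standard closed form of the softmax Jacobian, J = diag(s) − s sᵀ, where s is the softmax output. As a plausibility check, this can be compared against autograd on a single input vector; the snippet below is a minimal sketch in plain PyTorch and is not part of the commit.

# Sanity-check sketch (not from the commit): verify that diag(s) - s s^T
# matches the autograd Jacobian of softmax for a single input vector.
import torch

x = torch.randn(5, dtype=torch.float64)
s = torch.softmax(x, dim=-1)

# Closed-form Jacobian, mirroring Softmax._jacobian above.
jac_closed = torch.diag_embed(s) - s.unsqueeze(-1) @ s.unsqueeze(-2)

# Reference Jacobian from autograd.
jac_auto = torch.autograd.functional.jacobian(lambda t: torch.softmax(t, dim=-1), x)

assert torch.allclose(jac_closed, jac_auto)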

tests/test_nnj.py

Lines changed: 4 additions & 3 deletions
@@ -54,6 +54,7 @@ def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
     (nnj.Sequential(nnj.Linear(_features, 2), nnj.Tanh()), _linear_input_shape),
     (nnj.Sequential(nnj.Linear(_features, 2), nnj.OneMinusX()), _linear_input_shape),
     (nnj.Sequential(nnj.Linear(_features, 2), nnj.PReLU()), _linear_input_shape),
+    (nnj.Sequential(nnj.Linear(_features, 2), nnj.Softmax(dim=-1)), _linear_input_shape),
     (
         nnj.Sequential(nnj.Conv1d(_features, 2, 5), nnj.ConvTranspose1d(2, _features, 5)),
         _1d_conv_input_shape,
@@ -98,7 +99,7 @@ def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
             nnj.Conv1d(_features, 2, 3),
             nnj.Flatten(),
             nnj.Linear(4 * 2, 5),
-            nnj.ReLU(),
+            nnj.Softmax(dim=-1),
         ),
         _1d_conv_input_shape,
     ),
@@ -107,7 +108,7 @@ def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
             nnj.Conv2d(_features, 2, 3),
             nnj.Flatten(),
            nnj.Linear(4 * 4 * 2, 5),
-            nnj.ReLU(),
+            nnj.Softmax(dim=-1),
         ),
         _2d_conv_input_shape,
     ),
@@ -116,7 +117,7 @@ def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
            nnj.Conv3d(_features, 2, 3),
            nnj.Flatten(),
            nnj.Linear(4 * 4 * 4 * 2, 5),
-            nnj.ReLU(),
+            nnj.Softmax(dim=-1),
         ),
         _3d_conv_input_shape,
     ),
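
For context, the models parametrized above are checked by comparing their analytic Jacobian against autograd. The snippet below is a hypothetical usage sketch of the new layer: the jacobian=True return convention follows the __call__ signature shown in the stochman/nnj.py hunk, while the import path and the exact output shapes are assumptions rather than something stated in this commit.

# Hypothetical usage sketch (import path and shapes are assumptions).
import torch
from stochman import nnj

# A small head ending in the newly supported Softmax layer.
model = nnj.Sequential(nnj.Linear(5, 2), nnj.Softmax(dim=-1))

x = torch.randn(7, 5)  # batch of 7 inputs with 5 features
val, jac = model(x, jacobian=True)  # forward value and batched Jacobian

# Expected (assumed) shapes: val -> (7, 2), jac -> (7, 2, 5)
print(val.shape, jac.shape)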
