Skip to content

Commit 846820a

Browse files
committed
formatting
1 parent a4f4069 commit 846820a

File tree

1 file changed

+128
-38
lines changed

1 file changed

+128
-38
lines changed

stochman/nnj.py

Lines changed: 128 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,16 @@
66

77
from math import prod
88

9+
910
class Identity(nn.Module):
1011
""" Identity module that will return the same input as it receives. """
12+
1113
def __init__(self):
1214
super().__init__()
1315

1416
def forward(self, x: Tensor, jacobian: bool = False) -> Union[Tensor, Tuple[Tensor, Tensor]]:
1517
val = x
16-
18+
1719
if jacobian:
1820
xs = x.shape
1921
jac = torch.eye(prod(xs[1:]), prod(xs[1:])).repeat(xs[0], 1, 1).reshape(xs[0], *xs[1:], *xs[1:])
@@ -32,7 +34,10 @@ def identity(x: Tensor) -> Tensor:
3234

3335
class Sequential(nn.Sequential):
3436
""" Subclass of sequential that also supports calculating the jacobian through an network """
35-
def forward(self, x: Tensor, jacobian: Union[Tensor, bool] = False) -> Union[Tensor, Tuple[Tensor, Tensor]]:
37+
38+
def forward(
39+
self, x: Tensor, jacobian: Union[Tensor, bool] = False
40+
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
3641
if jacobian:
3742
j = identity(x) if (not isinstance(jacobian, Tensor) and jacobian) else jacobian
3843
for module in self._modules.values():
@@ -46,9 +51,10 @@ def forward(self, x: Tensor, jacobian: Union[Tensor, bool] = False) -> Union[Ten
4651

4752

4853
class AbstractJacobian:
49-
""" Abstract class that will overwrite the default behaviour of the forward method such that it
50-
is also possible to return the jacobian
54+
"""Abstract class that will overwrite the default behaviour of the forward method such that it
55+
is also possible to return the jacobian
5156
"""
57+
5258
def _jacobian(self, x: Tensor, val: Tensor) -> Tensor:
5359
return self._jacobian_mult(x, val, identity(x))
5460

@@ -62,7 +68,7 @@ def __call__(self, x: Tensor, jacobian: bool = False) -> Union[Tensor, Tuple[Ten
6268

6369
class Linear(AbstractJacobian, nn.Linear):
6470
def _jacobian_mult(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
65-
return F.linear(jac_in.movedim(1,-1), self.weight, bias=None).movedim(-1,1)
71+
return F.linear(jac_in.movedim(1, -1), self.weight, bias=None).movedim(-1, 1)
6672

6773

6874
class PosLinear(AbstractJacobian, nn.Linear):
@@ -74,82 +80,166 @@ def forward(self, x: Tensor):
7480
return val
7581

7682
def _jacobian_mult(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
77-
return F.linear(jac_in.movedim(1,-1), F.softplus(self.weight), bias=None).movedim(-1,1)
83+
return F.linear(jac_in.movedim(1, -1), F.softplus(self.weight), bias=None).movedim(-1, 1)
7884

7985

8086
class Upsample(AbstractJacobian, nn.Upsample):
8187
def _jacobian_mult(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
8288
xs = x.shape
8389
vs = val.shape
8490
if x.ndim == 3:
85-
return F.interpolate(jac_in.movedim((1,2),(-2,-1)).reshape(-1, *xs[1:]),
86-
self.size, self.scale_factor, self.mode, self.align_corners
87-
).reshape(xs[0], *jac_in.shape[3:], *vs[1:]).movedim((-2, -1), (1, 2))
91+
return (
92+
F.interpolate(
93+
jac_in.movedim((1, 2), (-2, -1)).reshape(-1, *xs[1:]),
94+
self.size,
95+
self.scale_factor,
96+
self.mode,
97+
self.align_corners,
98+
)
99+
.reshape(xs[0], *jac_in.shape[3:], *vs[1:])
100+
.movedim((-2, -1), (1, 2))
101+
)
88102
if x.ndim == 4:
89-
return F.interpolate(jac_in.movedim((1,2,3),(-3,-2,-1)).reshape(-1, *xs[1:]),
90-
self.size, self.scale_factor, self.mode, self.align_corners
91-
).reshape(xs[0], *jac_in.shape[4:], *vs[1:]).movedim((-3, -2, -1), (1, 2, 3))
103+
return (
104+
F.interpolate(
105+
jac_in.movedim((1, 2, 3), (-3, -2, -1)).reshape(-1, *xs[1:]),
106+
self.size,
107+
self.scale_factor,
108+
self.mode,
109+
self.align_corners,
110+
)
111+
.reshape(xs[0], *jac_in.shape[4:], *vs[1:])
112+
.movedim((-3, -2, -1), (1, 2, 3))
113+
)
92114
if x.ndim == 5:
93-
return F.interpolate(jac_in.movedim((1,2,3,4),(-4,-3,-2,-1)).reshape(-1, *xs[1:]),
94-
self.size, self.scale_factor, self.mode, self.align_corners
95-
).reshape(xs[0], *jac_in.shape[5:], *vs[1:]).movedim((-4,-3,-2, -1), (1, 2, 3, 4))
115+
return (
116+
F.interpolate(
117+
jac_in.movedim((1, 2, 3, 4), (-4, -3, -2, -1)).reshape(-1, *xs[1:]),
118+
self.size,
119+
self.scale_factor,
120+
self.mode,
121+
self.align_corners,
122+
)
123+
.reshape(xs[0], *jac_in.shape[5:], *vs[1:])
124+
.movedim((-4, -3, -2, -1), (1, 2, 3, 4))
125+
)
96126

97127

98128
class Conv1d(AbstractJacobian, nn.Conv1d):
99129
def _jacobian_mult(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
100130
b, c1, l1 = x.shape
101131
c2, l2 = val.shape[1:]
102-
return F.conv1d(jac_in.movedim((1, 2), (-2, -1)).reshape(-1, c1, l1), weight=self.weight,
103-
bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups,
104-
).reshape(b, *jac_in.shape[3:], c2, l2).movedim((-2, -1), (1, 2))
132+
return (
133+
F.conv1d(
134+
jac_in.movedim((1, 2), (-2, -1)).reshape(-1, c1, l1),
135+
weight=self.weight,
136+
bias=None,
137+
stride=self.stride,
138+
padding=self.padding,
139+
dilation=self.dilation,
140+
groups=self.groups,
141+
)
142+
.reshape(b, *jac_in.shape[3:], c2, l2)
143+
.movedim((-2, -1), (1, 2))
144+
)
105145

106146

107147
class ConvTranspose1d(AbstractJacobian, nn.ConvTranspose1d):
108148
def _jacobian_mult(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
109149
b, c1, l1 = x.shape
110150
c2, l2 = val.shape[1:]
111-
return F.conv_transpose1d(jac_in.movedim((1, 2), (-2, -1)).reshape(-1, c1, l1), weight=self.weight,
112-
bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups,
113-
output_padding=self.output_padding
114-
).reshape(b, *jac_in.shape[3:], c2, l2).movedim((-2, -1), (1, 2))
151+
return (
152+
F.conv_transpose1d(
153+
jac_in.movedim((1, 2), (-2, -1)).reshape(-1, c1, l1),
154+
weight=self.weight,
155+
bias=None,
156+
stride=self.stride,
157+
padding=self.padding,
158+
dilation=self.dilation,
159+
groups=self.groups,
160+
output_padding=self.output_padding,
161+
)
162+
.reshape(b, *jac_in.shape[3:], c2, l2)
163+
.movedim((-2, -1), (1, 2))
164+
)
115165

116166

117167
class Conv2d(AbstractJacobian, nn.Conv2d):
118168
def _jacobian_mult(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
119169
b, c1, h1, w1 = x.shape
120170
c2, h2, w2 = val.shape[1:]
121-
return F.conv2d(jac_in.movedim((1, 2, 3), (-3, -2, -1)).reshape(-1, c1, h1, w1), weight=self.weight,
122-
bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups,
123-
).reshape(b, *jac_in.shape[4:], c2, h2, w2).movedim((-3, -2, -1), (1, 2, 3))
171+
return (
172+
F.conv2d(
173+
jac_in.movedim((1, 2, 3), (-3, -2, -1)).reshape(-1, c1, h1, w1),
174+
weight=self.weight,
175+
bias=None,
176+
stride=self.stride,
177+
padding=self.padding,
178+
dilation=self.dilation,
179+
groups=self.groups,
180+
)
181+
.reshape(b, *jac_in.shape[4:], c2, h2, w2)
182+
.movedim((-3, -2, -1), (1, 2, 3))
183+
)
124184

125185

126186
class ConvTranspose2d(AbstractJacobian, nn.ConvTranspose2d):
127187
def _jacobian_mult(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
128188
b, c1, h1, w1 = x.shape
129189
c2, h2, w2 = val.shape[1:]
130-
return F.conv_transpose2d(jac_in.movedim((1, 2, 3), (-3, -2, -1)).reshape(-1, c1, h1, w1), weight=self.weight,
131-
bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups,
132-
output_padding=self.output_padding,
133-
).reshape(b, *jac_in.shape[4:], c2, h2, w2).movedim((-3, -2, -1), (1, 2, 3))
190+
return (
191+
F.conv_transpose2d(
192+
jac_in.movedim((1, 2, 3), (-3, -2, -1)).reshape(-1, c1, h1, w1),
193+
weight=self.weight,
194+
bias=None,
195+
stride=self.stride,
196+
padding=self.padding,
197+
dilation=self.dilation,
198+
groups=self.groups,
199+
output_padding=self.output_padding,
200+
)
201+
.reshape(b, *jac_in.shape[4:], c2, h2, w2)
202+
.movedim((-3, -2, -1), (1, 2, 3))
203+
)
134204

135205

136206
class Conv3d(AbstractJacobian, nn.Conv3d):
137207
def _jacobian_mult(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
138208
b, c1, d1, h1, w1 = x.shape
139209
c2, d2, h2, w2 = val.shape[1:]
140-
return F.conv3d(jac_in.movedim((1, 2, 3, 4), (-4, -3, -2, -1)).reshape(-1, c1, d1, h1, w1), weight=self.weight,
141-
bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups,
142-
).reshape(b, *jac_in.shape[5:], c2, d2, h2, w2).movedim((-4, -3, -2, -1), (1, 2, 3, 4))
210+
return (
211+
F.conv3d(
212+
jac_in.movedim((1, 2, 3, 4), (-4, -3, -2, -1)).reshape(-1, c1, d1, h1, w1),
213+
weight=self.weight,
214+
bias=None,
215+
stride=self.stride,
216+
padding=self.padding,
217+
dilation=self.dilation,
218+
groups=self.groups,
219+
)
220+
.reshape(b, *jac_in.shape[5:], c2, d2, h2, w2)
221+
.movedim((-4, -3, -2, -1), (1, 2, 3, 4))
222+
)
143223

144224

145225
class ConvTranspose3d(AbstractJacobian, nn.ConvTranspose3d):
146226
def _jacobian_mult(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
147227
b, c1, d1, h1, w1 = x.shape
148228
c2, d2, h2, w2 = val.shape[1:]
149-
return F.conv_transpose3d(jac_in.movedim((1, 2, 3, 4), (-4, -3, -2, -1)).reshape(-1, c1, d1, h1, w1), weight=self.weight,
150-
bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups,
151-
output_padding=self.output_padding
152-
).reshape(b, *jac_in.shape[5:], c2, d2, h2, w2).movedim((-4, -3, -2, -1), (1, 2, 3, 4))
229+
return (
230+
F.conv_transpose3d(
231+
jac_in.movedim((1, 2, 3, 4), (-4, -3, -2, -1)).reshape(-1, c1, d1, h1, w1),
232+
weight=self.weight,
233+
bias=None,
234+
stride=self.stride,
235+
padding=self.padding,
236+
dilation=self.dilation,
237+
groups=self.groups,
238+
output_padding=self.output_padding,
239+
)
240+
.reshape(b, *jac_in.shape[5:], c2, d2, h2, w2)
241+
.movedim((-4, -3, -2, -1), (1, 2, 3, 4))
242+
)
153243

154244

155245
class Reshape(AbstractJacobian, nn.Module):
@@ -186,14 +276,14 @@ class AbstractActivationJacobian:
186276
def _jacobian_mult(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
187277
jac = self._jacobian(x, val)
188278
n = jac_in.ndim - jac.ndim
189-
return jac_in * jac.reshape(jac.shape + (1,)*n)
279+
return jac_in * jac.reshape(jac.shape + (1,) * n)
190280

191281
def __call__(self, x: Tensor, jacobian: bool = False) -> Union[Tensor, Tuple[Tensor, Tensor]]:
192282
val = self._call_impl(x)
193283
if jacobian:
194284
jac = self._jacobian(x, val)
195285
return val, jac
196-
return val
286+
return val
197287

198288

199289
class Sigmoid(AbstractActivationJacobian, nn.Sigmoid):

0 commit comments

Comments
 (0)