6
6
7
7
from math import prod
8
8
9
+
9
10
class Identity (nn .Module ):
10
11
""" Identity module that will return the same input as it receives. """
12
+
11
13
def __init__ (self ):
12
14
super ().__init__ ()
13
15
14
16
def forward (self , x : Tensor , jacobian : bool = False ) -> Union [Tensor , Tuple [Tensor , Tensor ]]:
15
17
val = x
16
-
18
+
17
19
if jacobian :
18
20
xs = x .shape
19
21
jac = torch .eye (prod (xs [1 :]), prod (xs [1 :])).repeat (xs [0 ], 1 , 1 ).reshape (xs [0 ], * xs [1 :], * xs [1 :])
@@ -32,7 +34,10 @@ def identity(x: Tensor) -> Tensor:
32
34
33
35
class Sequential (nn .Sequential ):
34
36
""" Subclass of sequential that also supports calculating the jacobian through an network """
35
- def forward (self , x : Tensor , jacobian : Union [Tensor , bool ] = False ) -> Union [Tensor , Tuple [Tensor , Tensor ]]:
37
+
38
+ def forward (
39
+ self , x : Tensor , jacobian : Union [Tensor , bool ] = False
40
+ ) -> Union [Tensor , Tuple [Tensor , Tensor ]]:
36
41
if jacobian :
37
42
j = identity (x ) if (not isinstance (jacobian , Tensor ) and jacobian ) else jacobian
38
43
for module in self ._modules .values ():
@@ -46,9 +51,10 @@ def forward(self, x: Tensor, jacobian: Union[Tensor, bool] = False) -> Union[Ten
46
51
47
52
48
53
class AbstractJacobian :
49
- """ Abstract class that will overwrite the default behaviour of the forward method such that it
50
- is also possible to return the jacobian
54
+ """Abstract class that will overwrite the default behaviour of the forward method such that it
55
+ is also possible to return the jacobian
51
56
"""
57
+
52
58
def _jacobian (self , x : Tensor , val : Tensor ) -> Tensor :
53
59
return self ._jacobian_mult (x , val , identity (x ))
54
60
@@ -62,7 +68,7 @@ def __call__(self, x: Tensor, jacobian: bool = False) -> Union[Tensor, Tuple[Ten
62
68
63
69
class Linear (AbstractJacobian , nn .Linear ):
64
70
def _jacobian_mult (self , x : Tensor , val : Tensor , jac_in : Tensor ) -> Tensor :
65
- return F .linear (jac_in .movedim (1 ,- 1 ), self .weight , bias = None ).movedim (- 1 ,1 )
71
+ return F .linear (jac_in .movedim (1 , - 1 ), self .weight , bias = None ).movedim (- 1 , 1 )
66
72
67
73
68
74
class PosLinear (AbstractJacobian , nn .Linear ):
@@ -74,82 +80,166 @@ def forward(self, x: Tensor):
74
80
return val
75
81
76
82
def _jacobian_mult (self , x : Tensor , val : Tensor , jac_in : Tensor ) -> Tensor :
77
- return F .linear (jac_in .movedim (1 ,- 1 ), F .softplus (self .weight ), bias = None ).movedim (- 1 ,1 )
83
+ return F .linear (jac_in .movedim (1 , - 1 ), F .softplus (self .weight ), bias = None ).movedim (- 1 , 1 )
78
84
79
85
80
86
class Upsample (AbstractJacobian , nn .Upsample ):
81
87
def _jacobian_mult (self , x : Tensor , val : Tensor , jac_in : Tensor ) -> Tensor :
82
88
xs = x .shape
83
89
vs = val .shape
84
90
if x .ndim == 3 :
85
- return F .interpolate (jac_in .movedim ((1 ,2 ),(- 2 ,- 1 )).reshape (- 1 , * xs [1 :]),
86
- self .size , self .scale_factor , self .mode , self .align_corners
87
- ).reshape (xs [0 ], * jac_in .shape [3 :], * vs [1 :]).movedim ((- 2 , - 1 ), (1 , 2 ))
91
+ return (
92
+ F .interpolate (
93
+ jac_in .movedim ((1 , 2 ), (- 2 , - 1 )).reshape (- 1 , * xs [1 :]),
94
+ self .size ,
95
+ self .scale_factor ,
96
+ self .mode ,
97
+ self .align_corners ,
98
+ )
99
+ .reshape (xs [0 ], * jac_in .shape [3 :], * vs [1 :])
100
+ .movedim ((- 2 , - 1 ), (1 , 2 ))
101
+ )
88
102
if x .ndim == 4 :
89
- return F .interpolate (jac_in .movedim ((1 ,2 ,3 ),(- 3 ,- 2 ,- 1 )).reshape (- 1 , * xs [1 :]),
90
- self .size , self .scale_factor , self .mode , self .align_corners
91
- ).reshape (xs [0 ], * jac_in .shape [4 :], * vs [1 :]).movedim ((- 3 , - 2 , - 1 ), (1 , 2 , 3 ))
103
+ return (
104
+ F .interpolate (
105
+ jac_in .movedim ((1 , 2 , 3 ), (- 3 , - 2 , - 1 )).reshape (- 1 , * xs [1 :]),
106
+ self .size ,
107
+ self .scale_factor ,
108
+ self .mode ,
109
+ self .align_corners ,
110
+ )
111
+ .reshape (xs [0 ], * jac_in .shape [4 :], * vs [1 :])
112
+ .movedim ((- 3 , - 2 , - 1 ), (1 , 2 , 3 ))
113
+ )
92
114
if x .ndim == 5 :
93
- return F .interpolate (jac_in .movedim ((1 ,2 ,3 ,4 ),(- 4 ,- 3 ,- 2 ,- 1 )).reshape (- 1 , * xs [1 :]),
94
- self .size , self .scale_factor , self .mode , self .align_corners
95
- ).reshape (xs [0 ], * jac_in .shape [5 :], * vs [1 :]).movedim ((- 4 ,- 3 ,- 2 , - 1 ), (1 , 2 , 3 , 4 ))
115
+ return (
116
+ F .interpolate (
117
+ jac_in .movedim ((1 , 2 , 3 , 4 ), (- 4 , - 3 , - 2 , - 1 )).reshape (- 1 , * xs [1 :]),
118
+ self .size ,
119
+ self .scale_factor ,
120
+ self .mode ,
121
+ self .align_corners ,
122
+ )
123
+ .reshape (xs [0 ], * jac_in .shape [5 :], * vs [1 :])
124
+ .movedim ((- 4 , - 3 , - 2 , - 1 ), (1 , 2 , 3 , 4 ))
125
+ )
96
126
97
127
98
128
class Conv1d (AbstractJacobian , nn .Conv1d ):
99
129
def _jacobian_mult (self , x : Tensor , val : Tensor , jac_in : Tensor ) -> Tensor :
100
130
b , c1 , l1 = x .shape
101
131
c2 , l2 = val .shape [1 :]
102
- return F .conv1d (jac_in .movedim ((1 , 2 ), (- 2 , - 1 )).reshape (- 1 , c1 , l1 ), weight = self .weight ,
103
- bias = None , stride = self .stride , padding = self .padding , dilation = self .dilation , groups = self .groups ,
104
- ).reshape (b , * jac_in .shape [3 :], c2 , l2 ).movedim ((- 2 , - 1 ), (1 , 2 ))
132
+ return (
133
+ F .conv1d (
134
+ jac_in .movedim ((1 , 2 ), (- 2 , - 1 )).reshape (- 1 , c1 , l1 ),
135
+ weight = self .weight ,
136
+ bias = None ,
137
+ stride = self .stride ,
138
+ padding = self .padding ,
139
+ dilation = self .dilation ,
140
+ groups = self .groups ,
141
+ )
142
+ .reshape (b , * jac_in .shape [3 :], c2 , l2 )
143
+ .movedim ((- 2 , - 1 ), (1 , 2 ))
144
+ )
105
145
106
146
107
147
class ConvTranspose1d (AbstractJacobian , nn .ConvTranspose1d ):
108
148
def _jacobian_mult (self , x : Tensor , val : Tensor , jac_in : Tensor ) -> Tensor :
109
149
b , c1 , l1 = x .shape
110
150
c2 , l2 = val .shape [1 :]
111
- return F .conv_transpose1d (jac_in .movedim ((1 , 2 ), (- 2 , - 1 )).reshape (- 1 , c1 , l1 ), weight = self .weight ,
112
- bias = None , stride = self .stride , padding = self .padding , dilation = self .dilation , groups = self .groups ,
113
- output_padding = self .output_padding
114
- ).reshape (b , * jac_in .shape [3 :], c2 , l2 ).movedim ((- 2 , - 1 ), (1 , 2 ))
151
+ return (
152
+ F .conv_transpose1d (
153
+ jac_in .movedim ((1 , 2 ), (- 2 , - 1 )).reshape (- 1 , c1 , l1 ),
154
+ weight = self .weight ,
155
+ bias = None ,
156
+ stride = self .stride ,
157
+ padding = self .padding ,
158
+ dilation = self .dilation ,
159
+ groups = self .groups ,
160
+ output_padding = self .output_padding ,
161
+ )
162
+ .reshape (b , * jac_in .shape [3 :], c2 , l2 )
163
+ .movedim ((- 2 , - 1 ), (1 , 2 ))
164
+ )
115
165
116
166
117
167
class Conv2d (AbstractJacobian , nn .Conv2d ):
118
168
def _jacobian_mult (self , x : Tensor , val : Tensor , jac_in : Tensor ) -> Tensor :
119
169
b , c1 , h1 , w1 = x .shape
120
170
c2 , h2 , w2 = val .shape [1 :]
121
- return F .conv2d (jac_in .movedim ((1 , 2 , 3 ), (- 3 , - 2 , - 1 )).reshape (- 1 , c1 , h1 , w1 ), weight = self .weight ,
122
- bias = None , stride = self .stride , padding = self .padding , dilation = self .dilation , groups = self .groups ,
123
- ).reshape (b , * jac_in .shape [4 :], c2 , h2 , w2 ).movedim ((- 3 , - 2 , - 1 ), (1 , 2 , 3 ))
171
+ return (
172
+ F .conv2d (
173
+ jac_in .movedim ((1 , 2 , 3 ), (- 3 , - 2 , - 1 )).reshape (- 1 , c1 , h1 , w1 ),
174
+ weight = self .weight ,
175
+ bias = None ,
176
+ stride = self .stride ,
177
+ padding = self .padding ,
178
+ dilation = self .dilation ,
179
+ groups = self .groups ,
180
+ )
181
+ .reshape (b , * jac_in .shape [4 :], c2 , h2 , w2 )
182
+ .movedim ((- 3 , - 2 , - 1 ), (1 , 2 , 3 ))
183
+ )
124
184
125
185
126
186
class ConvTranspose2d (AbstractJacobian , nn .ConvTranspose2d ):
127
187
def _jacobian_mult (self , x : Tensor , val : Tensor , jac_in : Tensor ) -> Tensor :
128
188
b , c1 , h1 , w1 = x .shape
129
189
c2 , h2 , w2 = val .shape [1 :]
130
- return F .conv_transpose2d (jac_in .movedim ((1 , 2 , 3 ), (- 3 , - 2 , - 1 )).reshape (- 1 , c1 , h1 , w1 ), weight = self .weight ,
131
- bias = None , stride = self .stride , padding = self .padding , dilation = self .dilation , groups = self .groups ,
132
- output_padding = self .output_padding ,
133
- ).reshape (b , * jac_in .shape [4 :], c2 , h2 , w2 ).movedim ((- 3 , - 2 , - 1 ), (1 , 2 , 3 ))
190
+ return (
191
+ F .conv_transpose2d (
192
+ jac_in .movedim ((1 , 2 , 3 ), (- 3 , - 2 , - 1 )).reshape (- 1 , c1 , h1 , w1 ),
193
+ weight = self .weight ,
194
+ bias = None ,
195
+ stride = self .stride ,
196
+ padding = self .padding ,
197
+ dilation = self .dilation ,
198
+ groups = self .groups ,
199
+ output_padding = self .output_padding ,
200
+ )
201
+ .reshape (b , * jac_in .shape [4 :], c2 , h2 , w2 )
202
+ .movedim ((- 3 , - 2 , - 1 ), (1 , 2 , 3 ))
203
+ )
134
204
135
205
136
206
class Conv3d (AbstractJacobian , nn .Conv3d ):
137
207
def _jacobian_mult (self , x : Tensor , val : Tensor , jac_in : Tensor ) -> Tensor :
138
208
b , c1 , d1 , h1 , w1 = x .shape
139
209
c2 , d2 , h2 , w2 = val .shape [1 :]
140
- return F .conv3d (jac_in .movedim ((1 , 2 , 3 , 4 ), (- 4 , - 3 , - 2 , - 1 )).reshape (- 1 , c1 , d1 , h1 , w1 ), weight = self .weight ,
141
- bias = None , stride = self .stride , padding = self .padding , dilation = self .dilation , groups = self .groups ,
142
- ).reshape (b , * jac_in .shape [5 :], c2 , d2 , h2 , w2 ).movedim ((- 4 , - 3 , - 2 , - 1 ), (1 , 2 , 3 , 4 ))
210
+ return (
211
+ F .conv3d (
212
+ jac_in .movedim ((1 , 2 , 3 , 4 ), (- 4 , - 3 , - 2 , - 1 )).reshape (- 1 , c1 , d1 , h1 , w1 ),
213
+ weight = self .weight ,
214
+ bias = None ,
215
+ stride = self .stride ,
216
+ padding = self .padding ,
217
+ dilation = self .dilation ,
218
+ groups = self .groups ,
219
+ )
220
+ .reshape (b , * jac_in .shape [5 :], c2 , d2 , h2 , w2 )
221
+ .movedim ((- 4 , - 3 , - 2 , - 1 ), (1 , 2 , 3 , 4 ))
222
+ )
143
223
144
224
145
225
class ConvTranspose3d (AbstractJacobian , nn .ConvTranspose3d ):
146
226
def _jacobian_mult (self , x : Tensor , val : Tensor , jac_in : Tensor ) -> Tensor :
147
227
b , c1 , d1 , h1 , w1 = x .shape
148
228
c2 , d2 , h2 , w2 = val .shape [1 :]
149
- return F .conv_transpose3d (jac_in .movedim ((1 , 2 , 3 , 4 ), (- 4 , - 3 , - 2 , - 1 )).reshape (- 1 , c1 , d1 , h1 , w1 ), weight = self .weight ,
150
- bias = None , stride = self .stride , padding = self .padding , dilation = self .dilation , groups = self .groups ,
151
- output_padding = self .output_padding
152
- ).reshape (b , * jac_in .shape [5 :], c2 , d2 , h2 , w2 ).movedim ((- 4 , - 3 , - 2 , - 1 ), (1 , 2 , 3 , 4 ))
229
+ return (
230
+ F .conv_transpose3d (
231
+ jac_in .movedim ((1 , 2 , 3 , 4 ), (- 4 , - 3 , - 2 , - 1 )).reshape (- 1 , c1 , d1 , h1 , w1 ),
232
+ weight = self .weight ,
233
+ bias = None ,
234
+ stride = self .stride ,
235
+ padding = self .padding ,
236
+ dilation = self .dilation ,
237
+ groups = self .groups ,
238
+ output_padding = self .output_padding ,
239
+ )
240
+ .reshape (b , * jac_in .shape [5 :], c2 , d2 , h2 , w2 )
241
+ .movedim ((- 4 , - 3 , - 2 , - 1 ), (1 , 2 , 3 , 4 ))
242
+ )
153
243
154
244
155
245
class Reshape (AbstractJacobian , nn .Module ):
@@ -186,14 +276,14 @@ class AbstractActivationJacobian:
186
276
def _jacobian_mult (self , x : Tensor , val : Tensor , jac_in : Tensor ) -> Tensor :
187
277
jac = self ._jacobian (x , val )
188
278
n = jac_in .ndim - jac .ndim
189
- return jac_in * jac .reshape (jac .shape + (1 ,)* n )
279
+ return jac_in * jac .reshape (jac .shape + (1 ,) * n )
190
280
191
281
def __call__ (self , x : Tensor , jacobian : bool = False ) -> Union [Tensor , Tuple [Tensor , Tensor ]]:
192
282
val = self ._call_impl (x )
193
283
if jacobian :
194
284
jac = self ._jacobian (x , val )
195
285
return val , jac
196
- return val
286
+ return val
197
287
198
288
199
289
class Sigmoid (AbstractActivationJacobian , nn .Sigmoid ):
0 commit comments