3
3
from enum import Enum
4
4
from typing import Optional , Tuple , Union
5
5
6
+ import numpy as np
6
7
import torch
7
8
import torch .nn .functional as F
8
9
from torch import nn
9
- import numpy as np
10
10
11
11
12
12
class JacType (Enum ):
@@ -17,40 +17,43 @@ class JacType(Enum):
17
17
FULL: The Jacobian is a matrix of whatever size.
18
18
"""
19
19
20
- DIAG = ' diag'
21
- FULL = ' full'
22
- CONV = ' conv'
23
-
20
+ DIAG = " diag"
21
+ FULL = " full"
22
+ CONV = " conv"
23
+
24
24
def __eq__ (self , other : Union [str , Enum ]) -> bool :
25
25
other = other .value if isinstance (other , Enum ) else str (other )
26
26
return self .value .lower () == other .lower ()
27
-
27
+
28
28
29
29
class Jacobian (torch .Tensor ):
30
- """ Class representing a jacobian tensor, subclasses from torch.Tensor
31
- Requires the additional `jactype` parameter to initialize, which
32
- is a string indicating the jacobian type
30
+ """Class representing a jacobian tensor, subclasses from torch.Tensor
31
+ Requires the additional `jactype` parameter to initialize, which
32
+ is a string indicating the jacobian type
33
33
"""
34
+
34
35
def __init__ (self , tensor , jactype ):
35
36
available_jactype = [item .value for item in JacType ]
36
37
if jactype not in available_jactype :
37
- raise ValueError (f'Tried to initialize jacobian tensor with unknown jacobian type { jactype } .'
38
- f' Please choose between { available_jactype } ' )
38
+ raise ValueError (
39
+ f"Tried to initialize jacobian tensor with unknown jacobian type { jactype } ."
40
+ f" Please choose between { available_jactype } "
41
+ )
39
42
self .jactype = jactype
40
-
43
+
41
44
@staticmethod
42
45
def __new__ (cls , x , jactype , * args , ** kwargs ):
43
46
cls .jactype = jactype
44
47
return super ().__new__ (cls , x , * args , ** kwargs )
45
-
48
+
46
49
def __repr__ (self ):
47
50
tensor_repr = super ().__repr__ ()
48
- tensor_repr = tensor_repr .replace (' tensor' , ' jacobian' )
49
- tensor_repr += f' \n jactype={ self .jactype .value if isinstance (self .jactype , Enum ) else self .jactype } '
51
+ tensor_repr = tensor_repr .replace (" tensor" , " jacobian" )
52
+ tensor_repr += f" \n jactype={ self .jactype .value if isinstance (self .jactype , Enum ) else self .jactype } "
50
53
return tensor_repr
51
-
54
+
52
55
def __add__ (self , other ):
53
- if isinstance (other , Jacobian ):
56
+ if isinstance (other , Jacobian ):
54
57
if self .jactype == other .jactype :
55
58
res = torch .add (self , other )
56
59
return jacobian (res , self .jactype )
@@ -59,14 +62,14 @@ def __add__(self, other):
59
62
return jacobian (res , JacType .FULL )
60
63
if self .jactype == JacType .DIAG and other .jactype == JacType .FULL :
61
64
res = torch .add (torch .diag_embed (self ), other )
62
- return jacobian (res , JacType .FULL )
65
+ return jacobian (res , JacType .FULL )
63
66
if self .jactype == JacType .CONV and other .jactype == JacType .CONV :
64
67
res = torch .add (self , other )
65
68
return jacobian (res , JacType .CONV )
66
- raise ValueError (' Unknown addition of jacobian matrices' )
67
-
69
+ raise ValueError (" Unknown addition of jacobian matrices" )
70
+
68
71
return super ().__add__ (other )
69
-
72
+
70
73
def __matmul__ (self , other ):
71
74
if isinstance (other , Jacobian ):
72
75
# diag * diag
@@ -90,9 +93,9 @@ def __matmul__(self, other):
90
93
if other == JacType .CONV :
91
94
res = self * other
92
95
return jacobian (res , JacType .CONV )
93
-
94
- raise ValueError (' Unknown matrix multiplication of jacobian matrices' )
95
-
96
+
97
+ raise ValueError (" Unknown matrix multiplication of jacobian matrices" )
98
+
96
99
97
100
def jacobian (tensor , jactype ):
98
101
""" Initialize a jacobian tensor by a specified jacobian type """
@@ -126,8 +129,7 @@ def _jacobian(self, x: torch.Tensor, val: torch.Tensor) -> Jacobian:
126
129
attains value val."""
127
130
pass
128
131
129
- def _jac_mul (
130
- self , x : torch .Tensor , val : torch .Tensor , jac_in : torch .Tensor ) -> Jacobian :
132
+ def _jac_mul (self , x : torch .Tensor , val : torch .Tensor , jac_in : torch .Tensor ) -> Jacobian :
131
133
"""Multiply the Jacobian at x with M.
132
134
This can potentially be done more efficiently than
133
135
first computing the Jacobian, and then performing the
@@ -557,9 +559,9 @@ def forward(self, x: torch.Tensor, jacobian: bool = False):
557
559
def _jacobian (self , x : torch .Tensor , val : torch .Tensor ) -> Jacobian :
558
560
w = self ._conv_to_toeplitz (x .shape [1 :])
559
561
w = w .unsqueeze (0 ).repeat (x .shape [0 ], 1 , 1 )
560
- return jacobian (w , JacType .CONV )
562
+ return jacobian (w , JacType .CONV )
563
+
561
564
562
-
563
565
class Conv1d (_BaseJacConv , nn .Conv1d ):
564
566
def _conv_to_toeplitz (self , input_shape ):
565
567
identity = torch .eye (np .prod (input_shape ).item ()).reshape ([- 1 ] + list (input_shape ))
0 commit comments