@@ -45,26 +45,36 @@
 
 
 class JacobianReg(nn.Module):
-    '''
-    Loss criterion that computes the trace of the square of the Jacobian.
-
-    Arguments:
-        n (int, optional): determines the number of random projections.
-            If n=-1, then it is set to the dimension of the output
-            space and projection is non-random and orthonormal, yielding
-            the exact result. For any reasonable batch size, the default
-            (n=1) should be sufficient.
-    '''
-
-    def __init__(self, n=1):
+    """Loss criterion that computes the trace of the square of the Jacobian.
+
+    Args:
+        n: Determines the number of random projections. If n=-1, then it is set to the dimension
+            of the output space and projection is non-random and orthonormal, yielding the exact
+            result. For any reasonable batch size, the default (n=1) should be sufficient.
+            |Default:| ``1``
+
+    Note:
+        This implementation is adapted from the Jacobian regularization described in [1].
+        [1] Judy Hoffman, Daniel A. Roberts, and Sho Yaida,
+            "Robust Learning with Jacobian Regularization," 2019.
+            [arxiv:1908.02729](https://arxiv.org/abs/1908.02729)
+    """
+
+    def __init__(self, n: int = 1):
         assert n == -1 or n > 0
         self.n = n
         super(JacobianReg, self).__init__()
 
-    def forward(self, x, y):
-        '''
-        computes (1/2) tr |dy/dx|^2
-        '''
+    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        """Computes (1/2) tr |dy/dx|^2.
+
+        Args:
+            x: Input tensor
+            y: Output tensor
+
+        Returns:
+            The computed regularization term
+        """
         B, C = y.shape
         if self.n == -1:
             num_proj = C
@@ -86,11 +96,18 @@ def forward(self, x, y):
         R = (1/2)*J2
         return R
 
-    def _random_vector(self, C, B):
-        '''
-        creates a random vector of dimension C with a norm of C^(1/2)
-        (as needed for the projection formula to work)
-        '''
+    def _random_vector(self, C: int, B: int) -> torch.Tensor:
+        """Creates a random vector of dimension C with a norm of C^(1/2).
+
+        This is needed for the projection formula to work.
+
+        Args:
+            C: Output dimension
+            B: Batch size
+
+        Returns:
+            A random normalized vector
+        """
         if C == 1:
             return torch.ones(B)
         v = torch.randn(B, C)
@@ -99,13 +116,26 @@ def _random_vector(self, C, B):
         v = torch.addcdiv(arxilirary_zero, 1.0, v, vnorm)
         return v
 
-    def _jacobian_vector_product(self, y, x, v, create_graph=False):
-        '''
-        Produce jacobian-vector product dy/dx dot v.
+    def _jacobian_vector_product(self,
+                                 y: torch.Tensor,
+                                 x: torch.Tensor,
+                                 v: torch.Tensor,
+                                 create_graph: bool = False) -> torch.Tensor:
+        """Produce jacobian-vector product dy/dx dot v.
+
+        Args:
+            y: Output tensor
+            x: Input tensor
+            v: Vector to compute the product with
+            create_graph: If True, a graph of the derivative will be constructed, allowing
+                higher-order derivative products to be computed. |Default:| ``False``
+
+        Returns:
+            The Jacobian-vector product
 
-        Note that if you want to differentiate it,
-        you need to make create_graph=True
-        '''
+        Note:
+            If you want to differentiate the result, you need to set create_graph=True.
+        """
         flat_y = y.reshape(-1)
         flat_v = v.reshape(-1)
         grad_x, = torch.autograd.grad(flat_y,
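
As a usage sketch (not part of this commit): the docstrings above imply the criterion is applied to an input/output pair where the input tracks gradients. The model, data shapes, and the weight lambda_JR below are placeholders chosen for illustration.

import torch
import torch.nn as nn

model = nn.Linear(10, 3)              # hypothetical model
criterion = nn.CrossEntropyLoss()
reg = JacobianReg(n=1)                # one random projection per batch
lambda_JR = 0.01                      # assumed regularization strength

x = torch.randn(8, 10)
target = torch.randint(0, 3, (8,))

x.requires_grad = True                # input must require grad so dy/dx exists
y = model(x)
loss = criterion(y, target) + lambda_JR * reg(x, y)
loss.backward()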
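The autograd call that _jacobian_vector_product is built on can be checked in isolation: passing v.reshape(-1) as grad_outputs makes torch.autograd.grad compute a reverse-mode vector-Jacobian product, i.e. v^T (dy/dx) per sample. A small self-contained verification against an explicit linear map, where all names are illustrative:

import torch

torch.manual_seed(0)
B, D, C = 3, 4, 2
W = torch.randn(D, C)
x = torch.randn(B, D, requires_grad=True)
y = x @ W                             # per-sample Jacobian dy/dx is W^T
v = torch.randn(B, C)

# d(sum of v*y)/dx, computed by reverse-mode autograd
grad_x, = torch.autograd.grad(y.reshape(-1), x, grad_outputs=v.reshape(-1))

# for y = x @ W the same product in closed form is v @ W^T
assert torch.allclose(grad_x, v @ W.T, atol=1e-6)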