Skip to content

Commit 458958f

Browse files
committed
Fix docstrings and type annot in cebra/models/jacobian_regularizer.py
1 parent 5e30829 commit 458958f

File tree

1 file changed

+57
-27
lines changed

1 file changed

+57
-27
lines changed

cebra/models/jacobian_regularizer.py

Lines changed: 57 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -45,26 +45,36 @@
4545

4646

4747
class JacobianReg(nn.Module):
48-
'''
49-
Loss criterion that computes the trace of the square of the Jacobian.
50-
51-
Arguments:
52-
n (int, optional): determines the number of random projections.
53-
If n=-1, then it is set to the dimension of the output
54-
space and projection is non-random and orthonormal, yielding
55-
the exact result. For any reasonable batch size, the default
56-
(n=1) should be sufficient.
57-
'''
58-
59-
def __init__(self, n=1):
48+
"""Loss criterion that computes the trace of the square of the Jacobian.
49+
50+
Args:
51+
n: Determines the number of random projections. If n=-1, then it is set to the dimension
52+
of the output space and projection is non-random and orthonormal, yielding the exact
53+
result. For any reasonable batch size, the default (n=1) should be sufficient.
54+
|Default:| ``1``
55+
56+
Note:
57+
This implementation is adapted from the Jacobian regularization described in [1].
58+
[1] Judy Hoffman, Daniel A. Roberts, and Sho Yaida,
59+
"Robust Learning with Jacobian Regularization," 2019.
60+
[arxiv:1908.02729](https://arxiv.org/abs/1908.02729)
61+
"""
62+
63+
def __init__(self, n: int = 1):
6064
assert n == -1 or n > 0
6165
self.n = n
6266
super(JacobianReg, self).__init__()
6367

64-
def forward(self, x, y):
65-
'''
66-
computes (1/2) tr |dy/dx|^2
67-
'''
68+
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
69+
"""Computes (1/2) tr |dy/dx|^2.
70+
71+
Args:
72+
x: Input tensor
73+
y: Output tensor
74+
75+
Returns:
76+
The computed regularization term
77+
"""
6878
B, C = y.shape
6979
if self.n == -1:
7080
num_proj = C
@@ -86,11 +96,18 @@ def forward(self, x, y):
8696
R = (1 / 2) * J2
8797
return R
8898

89-
def _random_vector(self, C, B):
90-
'''
91-
creates a random vector of dimension C with a norm of C^(1/2)
92-
(as needed for the projection formula to work)
93-
'''
99+
def _random_vector(self, C: int, B: int) -> torch.Tensor:
100+
"""Creates a random vector of dimension C with a norm of C^(1/2).
101+
102+
This is needed for the projection formula to work.
103+
104+
Args:
105+
C: Output dimension
106+
B: Batch size
107+
108+
Returns:
109+
A random normalized vector
110+
"""
94111
if C == 1:
95112
return torch.ones(B)
96113
v = torch.randn(B, C)
@@ -99,13 +116,26 @@ def _random_vector(self, C, B):
99116
v = torch.addcdiv(arxilirary_zero, 1.0, v, vnorm)
100117
return v
101118

102-
def _jacobian_vector_product(self, y, x, v, create_graph=False):
103-
'''
104-
Produce jacobian-vector product dy/dx dot v.
119+
def _jacobian_vector_product(self,
120+
y: torch.Tensor,
121+
x: torch.Tensor,
122+
v: torch.Tensor,
123+
create_graph: bool = False) -> torch.Tensor:
124+
"""Produce jacobian-vector product dy/dx dot v.
125+
126+
Args:
127+
y: Output tensor
128+
x: Input tensor
129+
v: Vector to compute product with
130+
create_graph: If True, graph of the derivative will be constructed, allowing
131+
to compute higher order derivative products. |Default:| ``False``
132+
133+
Returns:
134+
The Jacobian-vector product
105135
106-
Note that if you want to differentiate it,
107-
you need to make create_graph=True
108-
'''
136+
Note:
137+
If you want to differentiate the result, you need to make create_graph=True
138+
"""
109139
flat_y = y.reshape(-1)
110140
flat_v = v.reshape(-1)
111141
grad_x, = torch.autograd.grad(flat_y,

0 commit comments

Comments
 (0)