
Commit 7086629

docs: docstring

1 parent 59ed2a8 commit 7086629

File tree: 1 file changed (+26, -5 lines)


pytorch_optimizer/optimizer/shampoo_utils.py

Lines changed: 26 additions & 5 deletions
@@ -9,7 +9,8 @@
 
 
 class LayerWiseGrafting(IntEnum):
-    r"""layer-wise grafting
+    r"""Layer-wise grafting.
+
     Grafting is a technique to fix the layer-wise scale of Shampoo optimizer.
     https://arxiv.org/pdf/2002.11803.pdf studies this in detail. This
     allows us to plugin the Shampoo optimizer into settings where SGD/AdaGrad
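The grafting described in this docstring keeps the update direction from Shampoo and borrows the step magnitude from a simpler optimizer such as SGD or AdaGrad, layer by layer. A minimal sketch of that idea, not the library's API (the function name and `eps` below are illustrative only):

```python
import torch

def graft(shampoo_update: torch.Tensor, grafted_update: torch.Tensor, eps: float = 1e-16) -> torch.Tensor:
    # Keep Shampoo's direction, but rescale it so its norm matches the
    # norm of the grafted optimizer's update (the layer-wise scale fix).
    return shampoo_update * (grafted_update.norm() / (shampoo_update.norm() + eps))
```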
@@ -29,12 +30,15 @@ def __init__(self, *args):
         pass
 
     def add_statistics(self, grad: torch.Tensor):
+        r"""Add the statistics."""
         pass
 
     def precondition_gradient(self, grad: torch.Tensor) -> torch.Tensor:
+        r"""Get preconditioned gradient."""
        return grad
 
     def update_momentum(self, update: torch.Tensor, unused_beta1: float) -> torch.Tensor:  # noqa: ARG002
+        r"""Update momentum."""
         return update
 
 
@@ -46,27 +50,35 @@ def __init__(self, var: torch.Tensor):
         self.momentum: torch.Tensor = torch.zeros_like(var, device=var.device)
 
     def update_momentum(self, update: torch.Tensor, beta1: float) -> torch.Tensor:
+        r"""Update momentum."""
         self.momentum.mul_(beta1).add_(update)
         return self.momentum
 
 
 class AdagradGraft(SGDGraft):
-    r"""Graft using Adagrad. Essentially an implementation of Adagrad with momentum."""
+    r"""Graft using Adagrad. Essentially an implementation of Adagrad with momentum.
+
+    :param var: torch.Tensor. variable.
+    :param diagonal_eps: float. diagonal epsilon.
+    """
 
     def __init__(self, var: torch.Tensor, diagonal_eps: float):
         super().__init__(var)
         self.diagonal_eps = diagonal_eps
         self.statistics: torch.Tensor = torch.zeros_like(var, device=var.device)
 
     def add_statistics(self, grad: torch.Tensor):
+        r"""Add the statistics."""
         self.statistics.add_(grad.pow(2))
 
     def precondition_gradient(self, grad: torch.Tensor) -> torch.Tensor:
+        r"""Get preconditioned gradient."""
         return grad / (torch.sqrt(self.statistics) + self.diagonal_eps)
 
 
 class BlockPartitioner:
-    r"""Partitions a tensor into smaller tensors for preconditioning.
+    r"""Partition a tensor into smaller tensors for preconditioning.
+
     For example, if a variable has shape (4096, 512), we might split the 4096 into 4 blocks,
     so we effectively have 4 variables of size (1024, 512) each.
 
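Taken together, the AdagradGraft methods touched in this hunk implement AdaGrad with momentum. A standalone sketch of the same arithmetic using plain tensors rather than the class (the constants `diagonal_eps` and `beta1` are illustrative values):

```python
import torch

grad = torch.randn(4, 3)
statistics = torch.zeros_like(grad)
momentum = torch.zeros_like(grad)
diagonal_eps, beta1 = 1e-10, 0.9

# add_statistics: accumulate squared gradients.
statistics.add_(grad.pow(2))
# precondition_gradient: AdaGrad scaling of the raw gradient.
update = grad / (torch.sqrt(statistics) + diagonal_eps)
# update_momentum (inherited from SGDGraft): heavy-ball accumulation.
momentum.mul_(beta1).add_(update)
```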
@@ -101,6 +113,7 @@ def __init__(self, var: torch.Tensor, block_size: int):
             self.pre_conditioner_shapes.extend([[d, d] for d in t])
 
     def shapes_for_pre_conditioners(self) -> List[List[int]]:
+        r"""Get shapes of pre-conditioner."""
         return self.pre_conditioner_shapes
 
     def partition(self, x: torch.Tensor) -> List[torch.Tensor]:
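The (4096, 512) example from the BlockPartitioner docstring can be reproduced with plain `torch.split`; this is only a sketch of the blocking idea, not the class's own partition logic:

```python
import torch

var = torch.randn(4096, 512)
block_size = 1024

# Split the oversized first dimension into 4 blocks of shape (1024, 512).
blocks = torch.split(var, block_size, dim=0)
assert [tuple(b.shape) for b in blocks] == [(1024, 512)] * 4

# Each block then gets its own pair of pre-conditioner statistics,
# of shapes [1024, 1024] (rows) and [512, 512] (columns).
```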
@@ -132,7 +145,15 @@ def merge_partitions(self, partitions: List[torch.Tensor]) -> torch.Tensor:
 
 
 class PreConditioner:
-    r"""Compute statistics/shape from gradients for preconditioning."""
+    r"""Compute statistics/shape from gradients for preconditioning.
+
+    :param var: torch.Tensor. variable.
+    :param beta2: float. beta2.
+    :param inverse_exponent_override: int.
+    :param block_size: int.
+    :param shape_interpretation: bool.
+    :param matrix_eps: float.
+    """
 
     def __init__(
         self,
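For a matrix-shaped gradient, the statistics that `PreConditioner.add_statistics` accumulates are the standard Shampoo Kronecker factors, smoothed with the moving average visible in the next hunk. The hunk shows `alpha=w2`; the sketch below assumes `w2 = 1 - beta2` and skips the class's blocking and reshaping:

```python
import torch

beta2 = 0.99
grad = torch.randn(1024, 512)

stat_left = grad @ grad.t()     # (1024, 1024): contraction over all axes except dim 0
stat_right = grad.t() @ grad    # (512, 512): contraction over all axes except dim 1

statistics = [torch.zeros(1024, 1024), torch.zeros(512, 512)]
for s, stat in zip(statistics, (stat_left, stat_right)):
    s.mul_(beta2).add_(stat, alpha=1.0 - beta2)
```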
@@ -182,7 +203,7 @@ def add_statistics(self, grad: torch.Tensor):
                 self.statistics[j * rank + i].mul_(self.beta2).add_(stat, alpha=w2)
 
     def exponent_for_pre_conditioner(self) -> int:
-        r"""Returns exponent to use for inverse-pth root M^{-1/p}."""
+        r"""Return exponent to use for inverse-pth root M^{-1/p}."""
         return (
             self.inverse_exponent_override if self.inverse_exponent_override > 0 else 2 * len(self.transformed_shape)
         )
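A worked instance of the exponent rule in `exponent_for_pre_conditioner` (values below are illustrative): a matrix-shaped variable has a rank-2 transformed shape, so the inverse-pth root uses p = 2 * 2 = 4, i.e. each Kronecker factor enters the update as a matrix power of -1/4, unless a positive `inverse_exponent_override` is set.

```python
inverse_exponent_override = 0
transformed_shape = (1024, 512)

exponent = inverse_exponent_override if inverse_exponent_override > 0 else 2 * len(transformed_shape)
assert exponent == 4  # update uses L^{-1/4} @ grad @ R^{-1/4}
```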
