

 class LayerWiseGrafting(IntEnum):
-    r"""layer-wise grafting
+    r"""Layer-wise grafting.
+
     Grafting is a technique to fix the layer-wise scale of Shampoo optimizer.
     https://arxiv.org/pdf/2002.11803.pdf studies this in detail. This
     allows us to plugin the Shampoo optimizer into settings where SGD/AdaGrad
@@ -29,12 +30,15 @@ def __init__(self, *args):
         pass

     def add_statistics(self, grad: torch.Tensor):
+        r"""Add the statistics."""
         pass

     def precondition_gradient(self, grad: torch.Tensor) -> torch.Tensor:
+        r"""Get preconditioned gradient."""
         return grad

     def update_momentum(self, update: torch.Tensor, unused_beta1: float) -> torch.Tensor:  # noqa: ARG002
+        r"""Update momentum."""
         return update


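The docstring above refers to grafting from the linked paper: the grafted step keeps the direction of the Shampoo update but borrows its per-layer magnitude from an already well-tuned optimizer such as SGD or AdaGrad. A minimal sketch of that idea (the helper name below is illustrative, not part of this file):

import torch

def grafted_step(shampoo_update: torch.Tensor, graft_update: torch.Tensor, eps: float = 1e-16) -> torch.Tensor:
    # Layer-wise grafting: keep Shampoo's direction, rescale it to the norm
    # of the SGD/AdaGrad update so the per-layer step size matches.
    return shampoo_update * (graft_update.norm() / (shampoo_update.norm() + eps))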
@@ -46,27 +50,35 @@ def __init__(self, var: torch.Tensor):
         self.momentum: torch.Tensor = torch.zeros_like(var, device=var.device)

     def update_momentum(self, update: torch.Tensor, beta1: float) -> torch.Tensor:
+        r"""Update momentum."""
         self.momentum.mul_(beta1).add_(update)
         return self.momentum


 class AdagradGraft(SGDGraft):
-    r"""Graft using Adagrad. Essentially an implementation of Adagrad with momentum."""
+    r"""Graft using Adagrad. Essentially an implementation of Adagrad with momentum.
+
+    :param var: torch.Tensor. variable.
+    :param diagonal_eps: float. diagonal epsilon.
+    """

     def __init__(self, var: torch.Tensor, diagonal_eps: float):
         super().__init__(var)
         self.diagonal_eps = diagonal_eps
         self.statistics: torch.Tensor = torch.zeros_like(var, device=var.device)

     def add_statistics(self, grad: torch.Tensor):
+        r"""Add the statistics."""
         self.statistics.add_(grad.pow(2))

     def precondition_gradient(self, grad: torch.Tensor) -> torch.Tensor:
+        r"""Get preconditioned gradient."""
         return grad / (torch.sqrt(self.statistics) + self.diagonal_eps)


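A rough usage sketch for the graft classes above, based only on the signatures visible in this diff (in practice the Shampoo optimizer drives these calls; the tensor sizes and hyperparameters below are illustrative):

import torch

param = torch.zeros(128, 64)
graft = AdagradGraft(param, diagonal_eps=1e-10)

grad = torch.randn_like(param)
graft.add_statistics(grad)                   # accumulate squared gradients
scaled = graft.precondition_gradient(grad)   # grad / (sqrt(statistics) + diagonal_eps)
update = graft.update_momentum(scaled, 0.9)  # momentum buffer inherited from SGDGraft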
 class BlockPartitioner:
-    r"""Partitions a tensor into smaller tensors for preconditioning.
+    r"""Partition a tensor into smaller tensors for preconditioning.
+
     For example, if a variable has shape (4096, 512), we might split the 4096 into 4 blocks,
     so we effectively have 4 variables of size (1024, 512) each.

@@ -101,6 +113,7 @@ def __init__(self, var: torch.Tensor, block_size: int):
             self.pre_conditioner_shapes.extend([[d, d] for d in t])

     def shapes_for_pre_conditioners(self) -> List[List[int]]:
+        r"""Get shapes of pre-conditioner."""
         return self.pre_conditioner_shapes

     def partition(self, x: torch.Tensor) -> List[torch.Tensor]:
@@ -132,7 +145,15 @@ def merge_partitions(self, partitions: List[torch.Tensor]) -> torch.Tensor:


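A hedged sketch of how BlockPartitioner might be exercised on the (4096, 512) example from its docstring, assuming block_size=1024; the partitioning and merging logic is elided from this diff, so the commented shapes are expectations rather than verified output:

import torch

var = torch.zeros(4096, 512)
partitioner = BlockPartitioner(var, block_size=1024)

blocks = partitioner.partition(var)                 # expected: 4 tensors of shape (1024, 512)
shapes = partitioner.shapes_for_pre_conditioners()  # expected: [[1024, 1024], [512, 512]] per block
merged = partitioner.merge_partitions(blocks)       # expected: the original (4096, 512) tensor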
 class PreConditioner:
-    r"""Compute statistics/shape from gradients for preconditioning."""
+    r"""Compute statistics/shape from gradients for preconditioning.
+
+    :param var: torch.Tensor. variable.
+    :param beta2: float. EMA decay rate used to accumulate the statistics.
+    :param inverse_exponent_override: int. override for the inverse-pth root exponent; non-positive values use the default.
+    :param block_size: int. block size used to partition large dimensions.
+    :param shape_interpretation: bool. whether to reinterpret the tensor shape (e.g. merging small dimensions) before preconditioning.
+    :param matrix_eps: float. epsilon added to the statistics matrices.
+    """

     def __init__(
         self,
@@ -182,7 +203,7 @@ def add_statistics(self, grad: torch.Tensor):
                 self.statistics[j * rank + i].mul_(self.beta2).add_(stat, alpha=w2)

     def exponent_for_pre_conditioner(self) -> int:
-        r"""Returns exponent to use for inverse-pth root M^{-1/p}."""
+        r"""Return exponent to use for inverse-pth root M^{-1/p}."""
         return (
             self.inverse_exponent_override if self.inverse_exponent_override > 0 else 2 * len(self.transformed_shape)
         )
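For a rank-k transformed shape this returns p = 2k unless overridden, following the Shampoo rule that each of the k Kronecker factors enters the update through an inverse-pth root M^{-1/(2k)}; a 2-D weight therefore uses the -1/4 power. The routine that actually computes the root is not shown in this diff; an eigendecomposition-based sketch for a symmetric statistics matrix would look like:

import torch

def inverse_pth_root(mat: torch.Tensor, p: int, eps: float = 1e-6) -> torch.Tensor:
    # Illustrative only: M^{-1/p} for a symmetric PSD matrix via eigendecomposition.
    eigenvalues, eigenvectors = torch.linalg.eigh(mat + eps * torch.eye(mat.shape[0]))
    return eigenvectors @ torch.diag(eigenvalues.clamp(min=eps).pow(-1.0 / p)) @ eigenvectors.t()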