
Commit e792181

refactor: shampoo_utils
1 parent 502861a commit e792181

File tree

3 files changed: +222 -219 lines changed


pytorch_optimizer/optimizer/shampoo.py

Lines changed: 1 addition & 218 deletions
@@ -1,227 +1,10 @@
-import itertools
-from enum import IntEnum
-from typing import List, Tuple
-
-import numpy as np
 import torch
 from torch.optim.optimizer import Optimizer
 
 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
 from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS
-from pytorch_optimizer.optimizer.utils import compute_power, merge_small_dims
-
-
-class LayerWiseGrafting(IntEnum):
-    r"""Layer-wise grafting.
-    Grafting is a technique to fix the layer-wise scale of the Shampoo optimizer.
-    https://arxiv.org/pdf/2002.11803.pdf studies this in detail. It allows us to
-    plug the Shampoo optimizer into settings where SGD/AdaGrad is already well tuned.
-    Grafting onto Shampoo means taking the Shampoo direction, but using the step
-    magnitude from the grafted optimizer such as AdaGrad or SGD.
-    """
-    NONE = 0
-    SGD = 1
-    ADAGRAD = 2
-
-
-class Graft:
-    r"""Base class to perform grafting onto Shampoo. This class does no grafting."""
-
-    def __init__(self, *args):
-        pass
-
-    def add_statistics(self, grad: torch.Tensor):
-        pass
-
-    def precondition_gradient(self, grad: torch.Tensor) -> torch.Tensor:
-        return grad
-
-    def update_momentum(self, update: torch.Tensor, unused_beta1: float) -> torch.Tensor:  # noqa: ARG002
-        return update
-
-
-class SGDGraft(Graft):
-    r"""Graft using SGD with momentum. Momentum maintains an exponentially weighted moving average of gradients."""
-
-    def __init__(self, var: torch.Tensor):
-        super().__init__(var)
-        self.momentum: torch.Tensor = torch.zeros_like(var, device=var.device)
-
-    def update_momentum(self, update: torch.Tensor, beta1: float) -> torch.Tensor:
-        self.momentum.mul_(beta1).add_(update)
-        return self.momentum
-
-
-class AdagradGraft(SGDGraft):
-    r"""Graft using AdaGrad. Essentially an implementation of AdaGrad with momentum."""
-
-    def __init__(self, var: torch.Tensor, diagonal_eps: float):
-        super().__init__(var)
-        self.diagonal_eps = diagonal_eps
-        self.statistics: torch.Tensor = torch.zeros_like(var, device=var.device)
-
-    def add_statistics(self, grad: torch.Tensor):
-        self.statistics.add_(grad.pow(2))
-
-    def precondition_gradient(self, grad: torch.Tensor) -> torch.Tensor:
-        return grad / (torch.sqrt(self.statistics) + self.diagonal_eps)
-
-
-class BlockPartitioner:
-    r"""Partitions a tensor into smaller tensors for preconditioning.
-    For example, if a variable has shape (4096, 512), we might split the 4096 into 4 blocks,
-    so we effectively have 4 variables of size (1024, 512) each.
-
-    :param var: torch.Tensor. tensor variable.
-    :param block_size: int. block size.
-    """
-
-    def __init__(self, var: torch.Tensor, block_size: int):
-        self.shape: List[int] = var.shape
-        self.splits: List[Tuple[int, np.ndarray]] = []
-        self.split_sizes: List[Tuple[int, np.ndarray]] = []
-
-        split_sizes: List[np.ndarray] = []
-
-        # We split var into smaller blocks. Here we store the metadata to make that split.
-        for i, d in enumerate(self.shape):
-            if 0 < block_size < d:
-                # d - 1, otherwise split appends a 0-size array.
-                num_split: int = (d - 1) // block_size
-                indices = (np.arange(num_split, dtype=np.int32) + 1) * block_size
-                sizes = np.ones(num_split + 1, dtype=np.int32) * block_size
-                sizes[-1] = d - indices[-1]
-                self.splits.append((i, indices))
-                self.split_sizes.append((i, sizes))
-                split_sizes.append(sizes)
-            else:
-                split_sizes.append(np.array([d], dtype=np.int32))
-
-        self.num_splits: int = len(split_sizes)
-        self.pre_conditioner_shapes: List[List[int]] = []
-        for t in itertools.product(*split_sizes):
-            self.pre_conditioner_shapes.extend([[d, d] for d in t])
-
-    def shapes_for_pre_conditioners(self) -> List[List[int]]:
-        return self.pre_conditioner_shapes
-
-    def partition(self, x: torch.Tensor) -> List[torch.Tensor]:
-        r"""Partition tensor into blocks."""
-        if x.shape != self.shape:
-            raise ValueError(f'self._shape != x.shape ({self.shape} vs {x.shape})')
-
-        tensors: List[torch.Tensor] = [x]
-        for i, sizes in self.split_sizes:
-            tensors_local: List[torch.Tensor] = []
-            for t in tensors:
-                tensors_local.extend(torch.split(t, list(sizes), dim=i))
-            tensors = tensors_local
-        return tensors
-
-    def merge_partitions(self, partitions: List[torch.Tensor]) -> torch.Tensor:
-        r"""Merge partitions back to original shape."""
-        for i, indices in reversed(self.splits):
-            n: int = len(indices) + 1
-
-            partitions: List[torch.Tensor] = [
-                torch.cat(partitions[idx:idx + n], axis=i) for idx in range(0, len(partitions), n)  # fmt: skip
-            ]
-
-            # if len(partitions) == 1:
-            #     raise ValueError('[-] num of partitions is 1')
-
-        return partitions[0]
-
-
-class PreConditioner:
-    r"""Compute statistics/shape from gradients for preconditioning."""
-
-    def __init__(
-        self,
-        var: torch.Tensor,
-        beta2: float,
-        inverse_exponent_override: int,
-        block_size: int,
-        shape_interpretation: bool,
-        matrix_eps: float,
-    ):
-        self.beta2 = beta2
-        self.inverse_exponent_override = inverse_exponent_override
-        self.matrix_eps = matrix_eps
-
-        self.original_shape: List[int] = var.shape
-        self.transformed_shape: List[int] = var.shape
-        if shape_interpretation:
-            self.transformed_shape = merge_small_dims(self.original_shape, block_size)
-
-        self.statistics: List[torch.Tensor] = []
-        self.pre_conditioners: List[torch.Tensor] = []
-        if len(self.transformed_shape) > 1:
-            reshaped_var = torch.reshape(var, self.transformed_shape)
-            self.partitioner = BlockPartitioner(reshaped_var, block_size)
-
-            shapes = self.partitioner.shapes_for_pre_conditioners()
-            self.statistics = [self.matrix_eps * torch.eye(s[0], device=var.device) for s in shapes]
-            self.pre_conditioners = [torch.eye(s[0], device=var.device) for s in shapes]
-
-    def add_statistics(self, grad: torch.Tensor):
-        r"""Compute statistics from gradients and add them to the correct state entries.
-
-        :param grad: torch.Tensor. gradient to compute statistics from.
-        """
-        if not self.statistics:
-            return
-
-        reshaped_grad: torch.Tensor = torch.reshape(grad, self.transformed_shape)
-        partitioned_grads: List[torch.Tensor] = self.partitioner.partition(reshaped_grad)
-
-        w2: float = 1.0 if self.beta2 == 1.0 else (1.0 - self.beta2)
-        rank: int = len(self.transformed_shape)
-        for j, partitioned_grad in enumerate(partitioned_grads):
-            for i in range(rank):
-                axes: List[int] = list(range(i)) + list(range(i + 1, rank))
-                stat: torch.Tensor = torch.tensordot(partitioned_grad, partitioned_grad, [axes, axes])
-                self.statistics[j * rank + i].mul_(self.beta2).add_(stat, alpha=w2)
-
-    def exponent_for_pre_conditioner(self) -> int:
-        r"""Return the exponent to use for the inverse-pth root M^{-1/p}."""
-        return (
-            self.inverse_exponent_override if self.inverse_exponent_override > 0 else 2 * len(self.transformed_shape)
-        )
-
-    def compute_pre_conditioners(self):
-        r"""Compute L^{-1/exp} for each statistics matrix L."""
-        exp: int = self.exponent_for_pre_conditioner()
-        for i, stat in enumerate(self.statistics):
-            self.pre_conditioners[i] = compute_power(stat, exp, ridge_epsilon=self.matrix_eps)
-
-    def preconditioned_grad(self, grad: torch.Tensor) -> torch.Tensor:
-        r"""Precondition the gradient.
-
-        :param grad: torch.Tensor. a gradient tensor to precondition.
-        """
-        if not self.pre_conditioners:
-            return grad
-
-        reshaped_grad = torch.reshape(grad, self.transformed_shape)
-        partitioned_grads = self.partitioner.partition(reshaped_grad)
-
-        num_splits: int = self.partitioner.num_splits
-        pre_cond_partitioned_grads: List[torch.Tensor] = []
-        for i, partitioned_grad in enumerate(partitioned_grads):
-            pre_conditioners_for_grad = self.pre_conditioners[i * num_splits:(i + 1) * num_splits]  # fmt: skip
-            rank: int = len(partitioned_grad.shape)
-
-            pre_cond_grad = partitioned_grad
-            for j in range(rank):
-                pre_cond_grad = torch.tensordot(pre_cond_grad, pre_conditioners_for_grad[j], [[0], [0]])
-
-            pre_cond_partitioned_grads.append(pre_cond_grad)
-
-        merged_grad = self.partitioner.merge_partitions(pre_cond_partitioned_grads)
-
-        return torch.reshape(merged_grad, self.original_shape)
+from pytorch_optimizer.optimizer.shampoo_utils import AdagradGraft, Graft, LayerWiseGrafting, PreConditioner, SGDGraft
 
 
 class Shampoo(Optimizer, BaseOptimizer):
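
For reference, a minimal sketch of the grafting helpers after the move, assuming they keep the signatures shown in the removed code above; the import path is the one added to shampoo.py, while the tensor shape, beta1, and diagonal_eps values are illustrative only, not the optimizer's defaults.

import torch

from pytorch_optimizer.optimizer.shampoo_utils import AdagradGraft

var = torch.zeros(512, 128)                            # parameter the graft is attached to
graft = AdagradGraft(var, diagonal_eps=1e-10)          # signature from the removed AdagradGraft

grad = torch.randn_like(var)
graft.add_statistics(grad)                             # accumulate squared gradients
graft_direction = graft.precondition_gradient(grad)    # grad / (sqrt(statistics) + diagonal_eps)
graft_direction = graft.update_momentum(graft_direction, 0.9)  # SGD-style momentum buffer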

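Likewise, a sketch of the PreConditioner (and, indirectly, BlockPartitioner) that now lives in shampoo_utils, again assuming the constructor arguments shown in the removed code; the hyperparameter values below are illustrative assumptions.

import torch

from pytorch_optimizer.optimizer.shampoo_utils import PreConditioner

var = torch.zeros(512, 128)
pre_conditioner = PreConditioner(
    var,
    beta2=0.999,                  # EMA factor for the statistics
    inverse_exponent_override=0,  # 0 falls back to 2 * rank for the inverse-pth root
    block_size=256,               # BlockPartitioner splits dims larger than this
    shape_interpretation=True,    # merge small dims via merge_small_dims first
    matrix_eps=1e-6,              # ridge epsilon used by compute_power
)

grad = torch.randn_like(var)
pre_conditioner.add_statistics(grad)          # per-axis gradient statistics
pre_conditioner.compute_pre_conditioners()    # L^{-1/exp} for each statistics matrix
update = pre_conditioner.preconditioned_grad(grad)
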
0 commit comments
