1 | | -import itertools |
2 | | -from enum import IntEnum |
3 | | -from typing import List, Tuple |
4 | | - |
5 | | -import numpy as np |
6 | 1 | import torch |
7 | 2 | from torch.optim.optimizer import Optimizer |
8 | 3 |
9 | 4 | from pytorch_optimizer.base.exception import NoSparseGradientError |
10 | 5 | from pytorch_optimizer.base.optimizer import BaseOptimizer |
11 | 6 | from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS |
12 | | -from pytorch_optimizer.optimizer.utils import compute_power, merge_small_dims |
13 | | - |
14 | | - |
15 | | -class LayerWiseGrafting(IntEnum): |
16 | | - r"""layer-wise grafting |
17 | | - Grafting is a technique to fix the layer-wise scale of Shampoo optimizer. |
18 | | - https://arxiv.org/pdf/2002.11803.pdf studies this in detail. This |
19 | | - allows us to plugin the Shampoo optimizer into settings where SGD/AdaGrad |
20 | | - is already well tuned. Grafting onto Shampoo means take the Shampoo direction, |
21 | | - but use the step magnitude from the grafted optimizer such as Adagrad or SGD. |
22 | | - """ |
23 | | - NONE = 0 |
24 | | - SGD = 1 |
25 | | - ADAGRAD = 2 |
26 | | - |
27 | | - |
28 | | -class Graft: |
29 | | - r"""Base class to perform grafting onto Shampoo. This class does no grafting.""" |
30 | | - |
31 | | - def __init__(self, *args): |
32 | | - pass |
33 | | - |
34 | | - def add_statistics(self, grad: torch.Tensor): |
35 | | - pass |
36 | | - |
37 | | - def precondition_gradient(self, grad: torch.Tensor) -> torch.Tensor: |
38 | | - return grad |
39 | | - |
40 | | - def update_momentum(self, update: torch.Tensor, unused_beta1: float) -> torch.Tensor: # noqa: ARG002 |
41 | | - return update |
42 | | - |
43 | | - |
44 | | -class SGDGraft(Graft): |
45 | | - r"""Graft using SGD + momentum. momentum maintains an exponentially weighted moving average of gradients.""" |
46 | | - |
47 | | - def __init__(self, var: torch.Tensor): |
48 | | - super().__init__(var) |
49 | | - self.momentum: torch.Tensor = torch.zeros_like(var, device=var.device) |
50 | | - |
51 | | - def update_momentum(self, update: torch.Tensor, beta1: float) -> torch.Tensor: |
52 | | - self.momentum.mul_(beta1).add_(update) |
53 | | - return self.momentum |
54 | | - |
55 | | - |
56 | | -class AdagradGraft(SGDGraft): |
57 | | - r"""Graft using Adagrad. Essentially an implementation of Adagrad with momentum.""" |
58 | | - |
59 | | - def __init__(self, var: torch.Tensor, diagonal_eps: float): |
60 | | - super().__init__(var) |
61 | | - self.diagonal_eps = diagonal_eps |
62 | | - self.statistics: torch.Tensor = torch.zeros_like(var, device=var.device) |
63 | | - |
64 | | - def add_statistics(self, grad: torch.Tensor): |
65 | | - self.statistics.add_(grad.pow(2)) |
66 | | - |
67 | | - def precondition_gradient(self, grad: torch.Tensor) -> torch.Tensor: |
68 | | - return grad / (torch.sqrt(self.statistics) + self.diagonal_eps) |
69 | | - |
70 | | - |
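The grafting flow above is easiest to see end to end. Below is a minimal, hypothetical sketch of how a graft supplies the step magnitude while Shampoo supplies the direction; the tensor shapes, the beta1 value, the 1e-16 stabilizer, and the use of the raw gradient as a stand-in for the Shampoo-preconditioned direction are illustrative assumptions, not part of this diff.

import torch

from pytorch_optimizer.optimizer.shampoo_utils import AdagradGraft

param = torch.randn(16, 8)
grad = torch.randn_like(param)

graft = AdagradGraft(param, diagonal_eps=1e-10)

graft.add_statistics(grad)                      # accumulate the sum of squared gradients
graft_step = graft.precondition_gradient(grad)  # Adagrad-scaled step: sets the magnitude
shampoo_step = grad                             # stand-in for the Shampoo-preconditioned direction

# keep the Shampoo direction, but rescale it to the grafted step's layer-wise norm
update = shampoo_step * (graft_step.norm() / (shampoo_step.norm() + 1e-16))
update = graft.update_momentum(update, 0.9)     # beta1 = 0.9 is illustrative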
71 | | -class BlockPartitioner: |
72 | | - r"""Partitions a tensor into smaller tensors for preconditioning. |
73 | | - For example, if a variable has shape (4096, 512), we might split the 4096 into 4 blocks, |
74 | | - so we effectively have 4 variables of size (1024, 512) each. |
75 | | -
76 | | - :param var: torch.Tensor. tensor variable. |
77 | | - :param block_size: int. block size. |
78 | | - """ |
79 | | - |
80 | | - def __init__(self, var: torch.Tensor, block_size: int): |
81 | | - self.shape: List[int] = var.shape |
82 | | - self.splits: List[Tuple[int, np.ndarray]] = [] |
83 | | - self.split_sizes: List[Tuple[int, np.ndarray]] = [] |
84 | | - |
85 | | - split_sizes: List[np.ndarray] = [] |
86 | | - |
87 | | - # We split var into smaller blocks. Here we store the metadata to make that split. |
88 | | - for i, d in enumerate(self.shape): |
89 | | - if 0 < block_size < d: |
90 | | - # d - 1, otherwise split appends a 0-size array. |
91 | | - num_split: int = (d - 1) // block_size |
92 | | - indices = (np.arange(num_split, dtype=np.int32) + 1) * block_size |
93 | | - sizes = np.ones(num_split + 1, dtype=np.int32) * block_size |
94 | | - sizes[-1] = d - indices[-1] |
95 | | - self.splits.append((i, indices)) |
96 | | - self.split_sizes.append((i, sizes)) |
97 | | - split_sizes.append(sizes) |
98 | | - else: |
99 | | - split_sizes.append(np.array([d], dtype=np.int32)) |
100 | | - |
101 | | - self.num_splits: int = len(split_sizes) |
102 | | - self.pre_conditioner_shapes: List[List[int]] = [] |
103 | | - for t in itertools.product(*split_sizes): |
104 | | - self.pre_conditioner_shapes.extend([[d, d] for d in t]) |
105 | | - |
106 | | - def shapes_for_pre_conditioners(self) -> List[List[int]]: |
107 | | - return self.pre_conditioner_shapes |
108 | | - |
109 | | - def partition(self, x: torch.Tensor) -> List[torch.Tensor]: |
110 | | - r"""Partition tensor into blocks.""" |
111 | | - if x.shape != self.shape: |
112 | | - raise ValueError(f'self.shape != x.shape ({self.shape} vs {x.shape})') |
113 | | - |
114 | | - tensors: List[torch.Tensor] = [x] |
115 | | - for i, sizes in self.split_sizes: |
116 | | - tensors_local: List[torch.Tensor] = [] |
117 | | - for t in tensors: |
118 | | - tensors_local.extend(torch.split(t, list(sizes), dim=i)) |
119 | | - tensors = tensors_local |
120 | | - return tensors |
121 | | - |
122 | | - def merge_partitions(self, partitions: List[torch.Tensor]) -> torch.Tensor: |
123 | | - r"""Merge partitions back to original shape.""" |
124 | | - for i, indices in reversed(self.splits): |
125 | | - n: int = len(indices) + 1 |
126 | | - |
127 | | - partitions: List[torch.Tensor] = [ |
128 | | - torch.cat(partitions[idx:idx + n], axis=i) for idx in range(0, len(partitions), n) # fmt: skip |
129 | | - ] |
130 | | - |
131 | | - # if len(partitions) == 1: |
132 | | - # raise ValueError('[-] num of partitions is 1') |
133 | | - |
134 | | - return partitions[0] |
135 | | - |
136 | | - |
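To make the (4096, 512) example from the docstring concrete, here is a small usage sketch; it assumes BlockPartitioner is importable from pytorch_optimizer.optimizer.shampoo_utils (where these helpers appear to live after this commit), and the block size is illustrative.

import torch

from pytorch_optimizer.optimizer.shampoo_utils import BlockPartitioner

var = torch.zeros(4096, 512)
partitioner = BlockPartitioner(var, block_size=1024)

# one (d, d) statistics shape per dimension of each block: [[1024, 1024], [512, 512]] repeated 4 times
shapes = partitioner.shapes_for_pre_conditioners()
assert len(shapes) == 8

# the 4096 dimension is split into four 1024-sized chunks; the 512 dimension stays whole
blocks = partitioner.partition(torch.randn_like(var))
assert len(blocks) == 4 and blocks[0].shape == (1024, 512)

# merge_partitions inverts partition
merged = partitioner.merge_partitions(blocks)
assert merged.shape == var.shape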
137 | | -class PreConditioner: |
138 | | - r"""Compute statistics/shape from gradients for preconditioning.""" |
139 | | - |
140 | | - def __init__( |
141 | | - self, |
142 | | - var: torch.Tensor, |
143 | | - beta2: float, |
144 | | - inverse_exponent_override: int, |
145 | | - block_size: int, |
146 | | - shape_interpretation: bool, |
147 | | - matrix_eps: float, |
148 | | - ): |
149 | | - self.beta2 = beta2 |
150 | | - self.inverse_exponent_override = inverse_exponent_override |
151 | | - self.matrix_eps = matrix_eps |
152 | | - |
153 | | - self.original_shape: List[int] = var.shape |
154 | | - self.transformed_shape: List[int] = var.shape |
155 | | - if shape_interpretation: |
156 | | - self.transformed_shape = merge_small_dims(self.original_shape, block_size) |
157 | | - |
158 | | - self.statistics: List[torch.Tensor] = [] |
159 | | - self.pre_conditioners: List[torch.Tensor] = [] |
160 | | - if len(self.transformed_shape) > 1: |
161 | | - reshaped_var = torch.reshape(var, self.transformed_shape) |
162 | | - self.partitioner = BlockPartitioner(reshaped_var, block_size) |
163 | | - |
164 | | - shapes = self.partitioner.shapes_for_pre_conditioners() |
165 | | - self.statistics = [self.matrix_eps * torch.eye(s[0], device=var.device) for s in shapes] |
166 | | - self.pre_conditioners = [torch.eye(s[0], device=var.device) for s in shapes] |
167 | | - |
168 | | - def add_statistics(self, grad: torch.Tensor): |
169 | | - r"""Compute statistics from gradients and add to the correct state entries. |
170 | | -
171 | | - :param grad: torch.Tensor. gradient to compute statistics from. |
172 | | - """ |
173 | | - if not self.statistics: |
174 | | - return |
175 | | - |
176 | | - reshaped_grad: torch.Tensor = torch.reshape(grad, self.transformed_shape) |
177 | | - partitioned_grads: List[torch.Tensor] = self.partitioner.partition(reshaped_grad) |
178 | | - |
179 | | - w2: float = 1.0 if self.beta2 == 1.0 else (1.0 - self.beta2) |
180 | | - rank: int = len(self.transformed_shape) |
181 | | - for j, partitioned_grad in enumerate(partitioned_grads): |
182 | | - for i in range(rank): |
183 | | - axes: List[int] = list(range(i)) + list(range(i + 1, rank)) |
184 | | - stat: torch.Tensor = torch.tensordot(partitioned_grad, partitioned_grad, [axes, axes]) |
185 | | - self.statistics[j * rank + i].mul_(self.beta2).add_(stat, alpha=w2) |
186 | | - |
187 | | - def exponent_for_pre_conditioner(self) -> int: |
188 | | - r"""Returns exponent to use for inverse-pth root M^{-1/p}.""" |
189 | | - return ( |
190 | | - self.inverse_exponent_override if self.inverse_exponent_override > 0 else 2 * len(self.transformed_shape) |
191 | | - ) |
192 | | - |
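For intuition: with no override, a rank-2 parameter (a weight matrix) gets exponent 2 * 2 = 4, which recovers the familiar matrix-case Shampoo update L^{-1/4} G R^{-1/4}. The sketch below illustrates that case with a plain eigendecomposition-based inverse root; the actual code delegates this to compute_power, and the ridge value and shapes here are illustrative assumptions.

import torch

def inv_pth_root(mat: torch.Tensor, p: int, ridge: float = 1e-6) -> torch.Tensor:
    # illustrative eigendecomposition-based M^{-1/p}; not the routine behind compute_power
    vals, vecs = torch.linalg.eigh(mat + ridge * torch.eye(mat.shape[0]))
    return vecs @ torch.diag(vals.clamp(min=ridge).pow(-1.0 / p)) @ vecs.t()

g = torch.randn(8, 4)   # gradient of a rank-2 (matrix) parameter
left = g @ g.t()        # statistics over dim 0, shape (8, 8)
right = g.t() @ g       # statistics over dim 1, shape (4, 4)

exp = 2 * g.dim()       # 4, matching exponent_for_pre_conditioner with no override
precond_grad = inv_pth_root(left, exp) @ g @ inv_pth_root(right, exp)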
193 | | - def compute_pre_conditioners(self): |
194 | | - r"""Compute L^{-1/exp} for each stats matrix L.""" |
195 | | - exp: int = self.exponent_for_pre_conditioner() |
196 | | - for i, stat in enumerate(self.statistics): |
197 | | - self.pre_conditioners[i] = compute_power(stat, exp, ridge_epsilon=self.matrix_eps) |
198 | | - |
199 | | - def preconditioned_grad(self, grad: torch.Tensor) -> torch.Tensor: |
200 | | - r"""Precondition the gradient. |
201 | | -
202 | | - :param grad: torch.Tensor. a gradient tensor to precondition. |
203 | | - """ |
204 | | - if not self.pre_conditioners: |
205 | | - return grad |
206 | | - |
207 | | - reshaped_grad = torch.reshape(grad, self.transformed_shape) |
208 | | - partitioned_grads = self.partitioner.partition(reshaped_grad) |
209 | | - |
210 | | - num_splits: int = self.partitioner.num_splits |
211 | | - pre_cond_partitioned_grads: List[torch.Tensor] = [] |
212 | | - for i, partitioned_grad in enumerate(partitioned_grads): |
213 | | - pre_conditioners_for_grad = self.pre_conditioners[i * num_splits:(i + 1) * num_splits] # fmt: skip |
214 | | - rank: int = len(partitioned_grad.shape) |
215 | | - |
216 | | - pre_cond_grad = partitioned_grad |
217 | | - for j in range(rank): |
218 | | - pre_cond_grad = torch.tensordot(pre_cond_grad, pre_conditioners_for_grad[j], [[0], [0]]) |
219 | | - |
220 | | - pre_cond_partitioned_grads.append(pre_cond_grad) |
221 | | - |
222 | | - merged_grad = self.partitioner.merge_partitions(pre_cond_partitioned_grads) |
223 | | - |
224 | | - return torch.reshape(merged_grad, self.original_shape) |
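Putting the pieces together, a typical PreConditioner lifecycle looks like the sketch below. The hyper-parameter values are illustrative, and in practice compute_pre_conditioners is refreshed only every so many steps because the inverse-root computation is expensive.

import torch

from pytorch_optimizer.optimizer.shampoo_utils import PreConditioner

param = torch.randn(128, 64)
grad = torch.randn_like(param)

pre_conditioner = PreConditioner(
    param,
    beta2=0.999,
    inverse_exponent_override=0,  # 0 means: use 2 * rank
    block_size=256,
    shape_interpretation=True,
    matrix_eps=1e-6,
)

pre_conditioner.add_statistics(grad)        # every step: update the Kronecker-factor statistics
pre_conditioner.compute_pre_conditioners()  # periodically: recompute the inverse-pth roots
shampoo_grad = pre_conditioner.preconditioned_grad(grad)
assert shampoo_grad.shape == param.shape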
| 7 | +from pytorch_optimizer.optimizer.shampoo_utils import AdagradGraft, Graft, LayerWiseGrafting, PreConditioner, SGDGraft |
225 | 8 |
226 | 9 |
227 | 10 | class Shampoo(Optimizer, BaseOptimizer): |