Commit b1b5ed4

Merge pull request #228 from kozistr/feature/galore-optimizer
[Feature] Implement GaLore optimizer
2 parents 523f140 + f2d6f14 commit b1b5ed4

12 files changed (+489 / -204 lines)

README.md

Lines changed: 2 additions & 1 deletion
@@ -10,7 +10,7 @@

 **pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch.
 I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas.
-Currently, **62 optimizers (+ `bitsandbytes`)**, **11 lr schedulers**, and **13 loss functions** are supported!
+Currently, **63 optimizers (+ `bitsandbytes`)**, **11 lr schedulers**, and **13 loss functions** are supported!

 Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).

@@ -160,6 +160,7 @@ supported_optimizers = get_supported_optimizers()
 | CAME | *Confidence-guided Adaptive Memory Efficient Optimization* | [github](https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/CAME) | <https://aclanthology.org/2023.acl-long.243/> | [cite](https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/CAME#citation) |
 | WSAM | *Sharpness-Aware Minimization Revisited: Weighted Sharpness as a Regularization Term* | [github](https://github.com/intelligent-machine-learning/dlrover/blob/master/atorch/atorch/optimizers/wsam.py) | <https://arxiv.org/abs/2305.15817> | [cite](https://github.com/intelligent-machine-learning/dlrover) |
 | Aida | *A DNN Optimizer that Improves over AdaBelief by Suppression of the Adaptive Stepsize Range* | [github](https://github.com/guoqiang-zhang-x/Aida-Optimizer) | <https://arxiv.org/abs/2203.13273> | [cite](https://github.com/guoqiang-zhang-x/Aida-Optimizer?tab=readme-ov-file#1-brief-description-of-aida) |
+| GaLore | *Memory-Efficient LLM Training by Gradient Low-Rank Projection* | [github](https://github.com/jiaweizzhao/GaLore) | <https://arxiv.org/abs/2403.03507> | [cite](https://github.com/jiaweizzhao/GaLore/tree/master?tab=readme-ov-file#citation) |

 ## Supported LR Scheduler
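For orientation, the new table row maps to a class that is importable straight from the package; a minimal sketch (the tensor below is an illustrative parameter, not something from the PR):

```python
import torch

from pytorch_optimizer import GaLore

param = torch.nn.Parameter(torch.randn(8, 8))

# without a `rank` keyword the optimizer runs as a plain AdamW-style update
optimizer = GaLore([param], lr=1e-3)
param.grad = torch.randn_like(param)
optimizer.step()
```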

docs/changelogs/v3.0.0.md

Lines changed: 12 additions & 5 deletions
@@ -10,21 +10,28 @@ Major version is updated! (`v2.12.0` -> `v3.0.0`) (#164)
   * [A DNN Optimizer that Improves over AdaBelief by Suppression of the Adaptive Stepsize Range](https://arxiv.org/abs/2203.13273)
 * Implement `WSAM` optimizer. (#213, #216)
   * [Sharpness-Aware Minimization Revisited: Weighted Sharpness as a Regularization Term](https://arxiv.org/abs/2305.15817)
+* Implement `GaLore` optimizer. (#224, #228)
+  * [Memory-Efficient LLM Training by Gradient Low-Rank Projection](https://arxiv.org/abs/2403.03507)

-## Dependency
+### Fix
+
+* Fix SRMM to allow operation beyond memory_length. (#227)
+
+### Dependency

 * Drop `Python 3.7` support officially. (#221)
   * Please check the [README](https://github.com/kozistr/pytorch_optimizer?tab=readme-ov-file#getting-started).
+* Update `bitsandbytes` to `0.43.0`. (#228)

-## Docs
+### Docs

 * Add missing parameters in `Ranger21 optimizer` document. (#214, #215)
 * Fix `WSAM` optimizer paper link. (#219)

-### Contributions
+## Contributions

-thanks to @sdbds
+thanks to @sdbds, @i404788

-### Diff
+## Diff

 [2.12.0...3.0.0](https://github.com/kozistr/pytorch_optimizer/compare/v2.12.0...v3.0.0)

docs/optimizer.md

Lines changed: 4 additions & 0 deletions
@@ -132,6 +132,10 @@
     :docstring:
     :members:

+::: pytorch_optimizer.GaLoreProjector
+    :docstring:
+    :members:
+
 ::: pytorch_optimizer.centralize_gradient
     :docstring:
     :members:

poetry.lock

Lines changed: 140 additions & 173 deletions
Generated lockfile; diff not rendered.

pyproject.toml

Lines changed: 9 additions & 9 deletions
@@ -13,11 +13,11 @@ keywords = [
     "pytorch", "deep-learning", "optimizer", "lr scheduler", "A2Grad", "ASGD", "AccSGD", "AdaBelief", "AdaBound",
     "AdaDelta", "AdaFactor", "AdaMax", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdaHessian", "Adai", "AdamP",
     "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos", "Apollo", "AvaGrad", "CAME", "DAdaptAdaGrad", "DAdaptAdam",
-    "DAdaptAdan", "DAdaptSGD", "DAdaptLion", "DiffGrad", "Fromage", "Gravity", "GSAM", "LARS", "Lamb", "Lion", "LOMO",
-    "Lookahead", "MADGRAD", "MSVAG", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM",
-    "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3",
-    "SopihaH", "SRMM", "SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice",
-    "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes",
+    "DAdaptAdan", "DAdaptSGD", "DAdaptLion", "DiffGrad", "Fromage", "GaLore", "Gravity", "GSAM", "LARS", "Lamb", "Lion",
+    "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam",
+    "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD",
+    "SM3", "SopihaH", "SRMM", "SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1",
+    "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes",
 ]
 classifiers = [
     "License :: OSI Approved :: Apache Software License",
@@ -45,14 +45,14 @@ classifiers = [
 python = ">=3.8,<4.0.0"
 numpy = { version = "*", python = ">=3.8" }
 torch = { version = ">=1.10", python = ">=3.8", source = "torch" }
-bitsandbytes = { version = "^0.42", optional = true }
+bitsandbytes = { version = "^0.43", optional = true }

 [tool.poetry.dev-dependencies]
 isort = { version = "^5", python = ">=3.8" }
 black = { version = "^24", python = ">=3.8"}
-ruff = "^0.3"
-pytest = "^8"
-pytest-cov = "^4"
+ruff = "*"
+pytest = "*"
+pytest-cov = "*"

 [tool.poetry.extras]
 bitsandbytes = ["bitsandbytes"]
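Since `bitsandbytes` stays behind the `[tool.poetry.extras]` flag, downstream code generally guards the import; a generic, hypothetical guard (not code from this repository):

```python
try:
    import bitsandbytes as bnb  # only present when installed with the `bitsandbytes` extra

    HAS_BNB: bool = True
except ImportError:
    HAS_BNB = False
```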

pytorch_optimizer/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -55,6 +55,7 @@
 from pytorch_optimizer.optimizer.diffgrad import DiffGrad
 from pytorch_optimizer.optimizer.fp16 import DynamicLossScaler, SafeFP16Optimizer
 from pytorch_optimizer.optimizer.fromage import Fromage
+from pytorch_optimizer.optimizer.galore import GaLore, GaLoreProjector
 from pytorch_optimizer.optimizer.gc import centralize_gradient
 from pytorch_optimizer.optimizer.gravity import Gravity
 from pytorch_optimizer.optimizer.lamb import Lamb
@@ -182,6 +183,7 @@
     CAME,
     DAdaptLion,
     Aida,
+    GaLore,
 ]
 OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
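With the import and `OPTIMIZER_LIST` entries above, the new class is also reachable by its lowercase name through the generated `OPTIMIZERS` mapping; a small sketch (`load_optimizer` is the library's usual lookup helper and is assumed here rather than shown in this diff):

```python
from pytorch_optimizer import GaLore, load_optimizer

# the registry key is the lowercase class name, per the dict comprehension above
optimizer_class = load_optimizer(optimizer='galore')
assert optimizer_class is GaLore
```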

pytorch_optimizer/optimizer/galore.py

Lines changed: 249 additions & 0 deletions

@@ -0,0 +1,249 @@
import math
from typing import Literal, Optional, Tuple, Union

import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS

PROJECTION_TYPE = Literal['std', 'reverse_std', 'right', 'left', 'full']


class GaLoreProjector:
    r"""Memory-Efficient LLM Training by Gradient Low-Rank Projection.

    :param rank: int. low rank to project.
    :param update_proj_gap: int. num steps to update the projection.
    :param scale: float. scale factor.
    :param projection_type: PROJECTION_TYPE. type of projection. 'std', 'reverse_std', 'right', 'left', 'full' are
        supported.
    """

    def __init__(
        self, rank: int = 128, update_proj_gap: int = 50, scale: float = 1.0, projection_type: PROJECTION_TYPE = 'std'
    ):
        self.rank = rank
        self.update_proj_gap = update_proj_gap
        self.scale = scale
        self.projection_type = projection_type

        self.ortho_matrix: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None

    @staticmethod
    def get_orthogonal_matrix(
        weights: torch.Tensor, rank: int, projection_type: str
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if projection_type not in {'right', 'left', 'full'}:
            raise ValueError('projection_type should be one of left, right or full')

        original_type = weights.data.dtype
        original_device = weights.data.device
        is_float: bool = original_type == torch.float

        u, s, vh = torch.linalg.svd(weights if is_float else weights.float(), full_matrices=False)

        if projection_type == 'right':
            b = vh[:rank, :]
            return b if is_float else b.to(original_device).type(original_type)
        if projection_type == 'left':
            a = u[:, :rank]
            return a if is_float else a.to(original_device).type(original_type)

        a = u[:, :rank]
        b = vh[:rank, :]

        return (
            (a, b)
            if is_float
            else (a.to(original_device).type(original_type), b.to(original_device).type(original_type))
        )

    def get_low_rank_grad_std(self, grad: torch.Tensor, steps: int) -> torch.Tensor:
        if grad.shape[0] >= grad.shape[1]:
            if self.ortho_matrix is None or steps % self.update_proj_gap == 0:
                self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='right')
            return torch.matmul(grad, self.ortho_matrix.t())

        if self.ortho_matrix is None or steps % self.update_proj_gap == 0:
            self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='left')

        return torch.matmul(self.ortho_matrix.t(), grad)

    def get_low_rank_grad_reverse_std(self, grad: torch.Tensor, steps: int) -> torch.Tensor:
        if grad.shape[0] >= grad.shape[1]:
            if self.ortho_matrix is None or steps % self.update_proj_gap == 0:
                self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='left')
            return torch.matmul(self.ortho_matrix.t(), grad)

        if self.ortho_matrix is None or steps % self.update_proj_gap == 0:
            self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='right')

        return torch.matmul(grad, self.ortho_matrix.t())

    def get_low_rank_grad_right(self, grad: torch.Tensor, steps: int) -> torch.Tensor:
        if self.ortho_matrix is None or steps % self.update_proj_gap == 0:
            self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='right')
        return torch.matmul(grad, self.ortho_matrix.t())

    def get_low_rank_grad_left(self, grad: torch.Tensor, steps: int) -> torch.Tensor:
        if self.ortho_matrix is None or steps % self.update_proj_gap == 0:
            self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='left')
        return torch.matmul(self.ortho_matrix.t(), grad)

    def get_low_rank_grad_full(self, grad: torch.Tensor, steps: int) -> torch.Tensor:
        if self.ortho_matrix is None or steps % self.update_proj_gap == 0:
            self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='full')
        return torch.matmul(self.ortho_matrix[0].t(), grad) @ self.ortho_matrix[1].t()

    def project(self, full_rank_grad: torch.Tensor, steps: int) -> torch.Tensor:
        if self.projection_type == 'std':
            return self.get_low_rank_grad_std(full_rank_grad, steps)
        if self.projection_type == 'reverse_std':
            return self.get_low_rank_grad_reverse_std(full_rank_grad, steps)
        if self.projection_type == 'right':
            return self.get_low_rank_grad_right(full_rank_grad, steps)
        if self.projection_type == 'left':
            return self.get_low_rank_grad_left(full_rank_grad, steps)
        if self.projection_type == 'full':
            return self.get_low_rank_grad_full(full_rank_grad, steps)
        raise NotImplementedError

    def project_back(self, low_rank_grad: torch.Tensor) -> torch.Tensor:
        if self.projection_type == 'std':
            return (
                torch.matmul(low_rank_grad, self.ortho_matrix)
                if low_rank_grad.shape[0] >= low_rank_grad.shape[1]
                else torch.matmul(self.ortho_matrix, low_rank_grad)
            ) * self.scale
        if self.projection_type == 'reverse_std':
            return (
                torch.matmul(self.ortho_matrix, low_rank_grad.t())
                if low_rank_grad.shape[0] <= low_rank_grad.shape[1]
                else torch.matmul(low_rank_grad, self.ortho_matrix.t())
            ) * self.scale
        if self.projection_type == 'right':
            return torch.matmul(low_rank_grad, self.ortho_matrix.t()) * self.scale
        if self.projection_type == 'left':
            return torch.matmul(self.ortho_matrix, low_rank_grad) * self.scale
        if self.projection_type == 'full':
            return torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1].t() * self.scale

        raise NotImplementedError


class GaLore(Optimizer, BaseOptimizer):
    r"""AdamW optimizer with GaLore projector.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param weight_decay: float. weight decay (L2 penalty).
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 0.0,
        eps: float = 1e-6,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'GaLore'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]

                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, beta2 = group['betas']

            bias_correction1: float = 1.0 - beta1 ** group['step']
            bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step'])

            step_size: float = group['lr'] * bias_correction2_sq / bias_correction1

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                if 'rank' in group and p.dim() > 1:
                    if 'projector' not in state:
                        state['projector'] = GaLoreProjector(
                            rank=group['rank'],
                            update_proj_gap=group['update_proj_gap'],
                            scale=group['scale'],
                            projection_type=group['projection_type'],
                        )

                    grad = state['projector'].project(grad, group['step'])

                if 'exp_avg' not in state:
                    # moment buffers are created lazily so they match the (possibly low-rank projected) gradient shape
                    state['exp_avg'] = torch.zeros_like(grad)
                    state['exp_avg_sq'] = torch.zeros_like(grad)

                self.apply_weight_decay(
                    p=p,
                    grad=None,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=True,
                    fixed_decay=False,
                )

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                de_nom = exp_avg_sq.sqrt().add_(group['eps'])

                norm_grad = exp_avg / de_nom

                if 'rank' in group and p.dim() > 1:
                    norm_grad = state['projector'].project_back(norm_grad)

                p.add_(norm_grad, alpha=-step_size)

        return loss
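Reading `step` above, any `rank`, `update_proj_gap`, `scale`, and `projection_type` keyword arguments flow through `**kwargs` into every param group, and 2D parameters are then updated in the projected low-rank space while biases fall back to the plain AdamW path. A minimal end-to-end sketch under that reading (the toy model and hyper-parameter values are illustrative, not taken from the PR):

```python
import torch
import torch.nn.functional as F

from pytorch_optimizer import GaLore

# toy model: the 2D Linear weights get projected, the 1D biases do not
model = torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.ReLU(), torch.nn.Linear(256, 10))

optimizer = GaLore(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-2,
    rank=64,              # low-rank dimension handed to GaLoreProjector
    update_proj_gap=200,  # recompute the SVD-based projection every 200 steps
    scale=0.25,           # scaling applied when projecting the update back to full rank
    projection_type='std',
)

x, y = torch.randn(32, 256), torch.randint(0, 10, (32,))
for _ in range(3):
    optimizer.zero_grad()
    F.cross_entropy(model(x), y).backward()
    optimizer.step()
```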
