Commit 56c2be0

[Update] Expose the parameters of the GaloreProjector to the init params of the Conda optimizer (#444)
* build(deps): dev deps
* update: expose Galore parameters
* docs: v3.8.1 changelog
1 parent 63f6623 · commit 56c2be0
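
For context, a minimal usage sketch of the newly exposed parameters, assuming `Conda` is exported at the package top level like the repository's other optimizers; the toy model and training step are illustrative only:

```python
import torch

from pytorch_optimizer import Conda  # assumed top-level export

# 2-D weight, so the GaLore-style projection path in `step` applies.
model = torch.nn.Linear(128, 64)

optimizer = Conda(
    model.parameters(),
    lr=1e-3,
    update_proj_gap=2000,    # how often the projection is updated (per the docstring)
    scale=1.0,               # GaLore projection scaling factor
    projection_type='std',   # a PROJECTION_TYPE value
)

loss = model(torch.randn(8, 128)).sum()
loss.backward()
optimizer.step()
```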

File tree

4 files changed: +73 −43 lines

docs/changelogs/v3.8.1.md

Lines changed: 4 additions & 0 deletions
```diff
@@ -11,6 +11,10 @@
 * Implement `Conda` optimizer. (#440, #441)
     * [Conda: Column-Normalized Adam for Training Large Language Models Faster](https://arxiv.org/abs/2509.24218)
 
+### Update
+
+* Accept the `GaloreProjector` parameters in the init params of the `Conda` optimizer. (#443, #444)
+
 ### Bug
 
 * Fix NaN problem when grad norm is zero in StableSPAM optimizer. (#431)
```

poetry.lock

Lines changed: 47 additions & 37 deletions
Some generated files are not rendered by default.

pytorch_optimizer/optimizer/conda.py

Lines changed: 20 additions & 4 deletions
```diff
@@ -5,7 +5,7 @@
 from pytorch_optimizer.base.exception import NoComplexParameterError, NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
 from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, GROUP, LOSS, PARAMETERS
-from pytorch_optimizer.optimizer.galore_utils import GaLoreProjector
+from pytorch_optimizer.optimizer.galore_utils import PROJECTION_TYPE, GaLoreProjector
 
 
 class Conda(BaseOptimizer):
@@ -15,6 +15,9 @@ class Conda(BaseOptimizer):
     :param lr: float. learning rate.
     :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
     :param weight_decay: float. weight decay (L2 penalty).
+    :param update_proj_gap: int. update projection gap.
+    :param scale: float. galore projection scaling factor.
+    :param projection_type: PROJECTION_TYPE. the type of the projection.
     :param eps: float. term added to the denominator to improve numerical stability.
     :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
     """
@@ -25,18 +28,31 @@ def __init__(
         lr: float = 1e-3,
         betas: BETAS = (0.9, 0.999),
         weight_decay: float = 0.0,
+        update_proj_gap: int = 2000,
+        scale: float = 1.0,
+        projection_type: PROJECTION_TYPE = 'std',
         eps: float = 1e-8,
         maximize: bool = False,
         **kwargs,
     ):
         self.validate_learning_rate(lr)
         self.validate_betas(betas)
+        self.validate_positive(update_proj_gap, 'update_proj_gap')
         self.validate_non_negative(weight_decay, 'weight_decay')
         self.validate_non_negative(eps, 'eps')
 
         self.maximize = maximize
 
-        defaults: DEFAULTS = {'lr': lr, 'betas': betas, 'weight_decay': weight_decay, 'eps': eps, **kwargs}
+        defaults: DEFAULTS = {
+            'lr': lr,
+            'betas': betas,
+            'weight_decay': weight_decay,
+            'update_proj_gap': update_proj_gap,
+            'scale': scale,
+            'projection_type': projection_type,
+            'eps': eps,
+            **kwargs,
+        }
         super().__init__(params, defaults)
 
     def __str__(self) -> str:
@@ -94,7 +110,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                 exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
 
-                if 'update_proj_gap' in group and p.dim() == 2:
+                if p.dim() == 2:
                     if 'projector' not in state:
                         state['projector'] = GaLoreProjector(
                             rank=None,
@@ -112,7 +128,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 norm_grad = exp_avg / de_nom
 
-                if 'update_proj_gap' in group and p.dim() == 2:
+                if p.dim() == 2:
                     norm_grad = state['projector'].project_back(norm_grad)
 
                 p.add_(norm_grad, alpha=-step_size)
```
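
Two behavioral notes on the hunks above: the projection settings now come from explicit, validated init params rather than `**kwargs`, and the `if p.dim() == 2` gate routes every matrix-shaped parameter through the projector. For intuition only, here is a sketch of a GaLore-style project/project-back round trip; it is not the library's `GaLoreProjector` (which Conda constructs with `rank=None`, as the diff shows), and the fixed rank below is purely illustrative:

```python
import torch

# Illustrative only: a GaLore-style low-rank round trip, NOT the library's
# GaLoreProjector. `update_proj_gap` governs how often a basis like `basis`
# below is refreshed from a new SVD; `scale` multiplies the projected-back
# update before it is applied.
grad = torch.randn(64, 128)   # stands in for a gradient moment on a 2-D param
rank, scale = 4, 1.0

U, _, _ = torch.linalg.svd(grad, full_matrices=False)
basis = U[:, :rank]                    # orthogonal basis, held fixed between refreshes

low_rank = basis.T @ grad              # project: shape (rank, 128) instead of (64, 128)
restored = scale * (basis @ low_rank)  # project_back, scaled before the parameter update
```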

requirements-dev.txt

Lines changed: 2 additions & 2 deletions
```diff
@@ -32,12 +32,12 @@ pyright==1.1.406 ; python_version >= "3.8"
 pytest-cov==5.0.0 ; python_version == "3.8"
 pytest-cov==6.3.0 ; python_version >= "3.9"
 pytest==8.3.5 ; python_version >= "3.8"
-pytokens==0.1.10 ; python_version >= "3.9"
+pytokens==0.2.0 ; python_version >= "3.9"
 ruff==0.12.12 ; python_version >= "3.8"
 setuptools==80.9.0 ; python_version >= "3.12"
 sympy==1.13.3 ; python_version == "3.8"
 sympy==1.14.0 ; python_version >= "3.9"
-tomli==2.2.1 ; python_full_version <= "3.11.0a6" and python_version >= "3.8"
+tomli==2.3.0 ; python_full_version <= "3.11.0a6" and python_version >= "3.8"
 torch==2.4.1+cpu ; python_version == "3.8"
 torch==2.8.0+cpu ; python_version >= "3.9"
 typing-extensions==4.13.2 ; python_version == "3.8"
```
