
Commit 75a023a

Merge pull request #101 from kozistr/feature/d-adaptation
[Feature] Implements D-Adaptation optimizers
2 parents 44c423a + a3a5557 commit 75a023a

File tree

16 files changed: +1000 −224 lines

README.rst

Lines changed: 4 additions & 0 deletions

@@ -114,6 +114,8 @@ Supported Optimizers
 +--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
 | GSAM | *Surrogate Gap Guided Sharpness-Aware Minimization* | `github <https://github.com/juntang-zhuang/GSAM>`__ | `https://openreview.net/pdf?id=edONMAnhLu- <https://openreview.net/pdf?id=edONMAnhLu->`__ |
 +--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
+| D-Adaptation | *Learning-Rate-Free Learning by D-Adaptation* | `github <https://github.com/facebookresearch/dadaptation>`__ | `https://arxiv.org/abs/2301.07733 <https://arxiv.org/abs/2301.07733>`__ |
++--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
 
 Useful Resources
 ----------------

@@ -307,6 +309,8 @@ Citations
 
 `GSAM <https://github.com/juntang-zhuang/GSAM#citation>`__
 
+`D-Adaptation <https://ui.adsabs.harvard.edu/abs/2023arXiv230107733D/exportcitation>`__
+
 Citation
 --------
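The table row only names the method, so a minimal usage sketch may help. This snippet is not part of the diff; it assumes `DAdaptAdam` follows the standard `torch.optim` constructor shape, and per the paper the base learning rate is normally left at its default so the optimizer can estimate the step-size scale (the distance-to-solution "D") itself:

```python
# Minimal sketch, not from this commit: D-Adaptation removes the need
# to tune the learning rate, so no lr argument is passed here.
import torch
from pytorch_optimizer import DAdaptAdam

model = torch.nn.Linear(10, 1)
optimizer = DAdaptAdam(model.parameters())  # lr left at its default

criterion = torch.nn.MSELoss()
x, y = torch.randn(32, 10), torch.randn(32, 1)

for _ in range(100):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
```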

docs/optimizer_api.rst

Lines changed: 24 additions & 0 deletions

@@ -200,3 +200,27 @@ GSAM
 
 .. autoclass:: pytorch_optimizer.GSAM
     :members:
+
+.. _DAdaptAdaGrad:
+
+DAdaptAdaGrad
+-------------
+
+.. autoclass:: pytorch_optimizer.DAdaptAdaGrad
+    :members:
+
+.. _DAdaptAdam:
+
+DAdaptAdam
+----------
+
+.. autoclass:: pytorch_optimizer.DAdaptAdam
+    :members:
+
+.. _DAdaptSGD:
+
+DAdaptSGD
+---------
+
+.. autoclass:: pytorch_optimizer.DAdaptSGD
+    :members:
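These entries only wire the three new classes into the Sphinx API docs. A hedged smoke test of the documented constructors (the class names come from this commit; the constructor shape is assumed to match the usual `torch.optim` convention) might look like:

```python
# Hedged smoke test: verifies the three documented classes construct
# and are torch Optimizer subclasses; signatures are assumed, not
# taken from this diff.
import torch
from pytorch_optimizer import DAdaptAdaGrad, DAdaptAdam, DAdaptSGD

params = [torch.nn.Parameter(torch.randn(4, 4))]
for opt_cls in (DAdaptAdaGrad, DAdaptAdam, DAdaptSGD):
    opt = opt_cls(params)
    print(opt_cls.__name__, isinstance(opt, torch.optim.Optimizer))  # True
```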

docs/util_api.rst

Lines changed: 105 additions & 0 deletions

@@ -65,3 +65,108 @@ disable_running_stats
 
 .. autoclass:: pytorch_optimizer.disable_running_stats
     :members:
+
+.. _LayerWiseGrafting:
+
+LayerWiseGrafting
+-----------------
+
+.. autoclass:: pytorch_optimizer.LayerWiseGrafting
+    :members:
+
+.. _Graft:
+
+Graft
+-----
+
+.. autoclass:: pytorch_optimizer.Graft
+    :members:
+
+.. _SGDGraft:
+
+SGDGraft
+--------
+
+.. autoclass:: pytorch_optimizer.SGDGraft
+    :members:
+
+.. _SQRTNGraft:
+
+SQRTNGraft
+----------
+
+.. autoclass:: pytorch_optimizer.SQRTNGraft
+    :members:
+
+.. _AdaGradGraft:
+
+AdaGradGraft
+------------
+
+.. autoclass:: pytorch_optimizer.AdaGradGraft
+    :members:
+
+.. _RMSPropGraft:
+
+RMSPropGraft
+------------
+
+.. autoclass:: pytorch_optimizer.RMSPropGraft
+    :members:
+
+.. _BlockPartitioner:
+
+BlockPartitioner
+----------------
+
+.. autoclass:: pytorch_optimizer.BlockPartitioner
+    :members:
+
+.. _PreConditionerType:
+
+PreConditionerType
+------------------
+
+.. autoclass:: pytorch_optimizer.PreConditionerType
+    :members:
+
+.. _PreConditioner:
+
+PreConditioner
+--------------
+
+.. autoclass:: pytorch_optimizer.PreConditioner
+    :members:
+
+.. _power_iter:
+
+power_iter
+----------
+
+.. autoclass:: pytorch_optimizer.power_iter
+    :members:
+
+.. _matrix_power:
+
+matrix_power
+------------
+
+.. autoclass:: pytorch_optimizer.matrix_power
+    :members:
+
+.. _compute_power:
+
+compute_power
+-------------
+
+.. autoclass:: pytorch_optimizer.compute_power
+    :members:
+
+.. _merge_small_dims:
+
+merge_small_dims
+----------------
+
+.. autoclass:: pytorch_optimizer.merge_small_dims
+    :members:
+
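These utilities come from the Shampoo implementation that this commit also exposes publicly. As one concrete illustration, a utility like `merge_small_dims` folds adjacent small tensor dimensions together so the preconditioner operates on fewer, larger matrices. The sketch below is an illustration of that idea, not the library's code; the exact semantics of the packaged function may differ:

```python
# Illustrative sketch (not the library's implementation): collapse
# adjacent dimensions while their product stays within max_dim, so a
# shape like [1, 2, 5, 1] becomes [10] for preconditioning purposes.
from typing import List

def merge_small_dims(shape: List[int], max_dim: int) -> List[int]:
    """Merge adjacent dims while the running product stays <= max_dim."""
    merged: List[int] = []
    acc = 1
    for d in shape:
        if acc * d <= max_dim:
            acc *= d            # still small enough: keep folding in
        else:
            merged.append(acc)  # flush the accumulated dimension
            acc = d
    merged.append(acc)
    return merged

print(merge_small_dims([1, 2, 5, 1], 10))  # -> [10]
print(merge_small_dims([2, 3, 64], 10))    # -> [6, 64]
```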

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "pytorch_optimizer"
-version = "2.3.1"
+version = "2.4.0"
 description = "optimizer & lr scheduler implementations in PyTorch with clean-code, strict types. Also, including useful optimization ideas."
 license = "Apache-2.0"
 authors = ["kozistr <[email protected]>"]
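The version bump suggests the D-Adaptation optimizers ship with the 2.4.0 release. A quick way to confirm an installed environment carries it:

```python
# Reads the installed distribution's version; expect '2.4.0' or later
# for the D-Adaptation optimizers, assuming the release was published
# under the name declared in pyproject.toml.
from importlib.metadata import version

print(version('pytorch_optimizer'))
```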

pytorch_optimizer/__init__.py

Lines changed: 19 additions & 0 deletions

@@ -21,6 +21,7 @@
 from pytorch_optimizer.optimizer.adan import Adan
 from pytorch_optimizer.optimizer.adapnm import AdaPNM
 from pytorch_optimizer.optimizer.agc import agc
+from pytorch_optimizer.optimizer.dadapt import DAdaptAdaGrad, DAdaptAdam, DAdaptSGD
 from pytorch_optimizer.optimizer.diffgrad import DiffGrad
 from pytorch_optimizer.optimizer.diffrgrad import DiffRGrad
 from pytorch_optimizer.optimizer.fp16 import DynamicLossScaler, SafeFP16Optimizer

@@ -40,6 +41,21 @@
 from pytorch_optimizer.optimizer.sam import SAM
 from pytorch_optimizer.optimizer.sgdp import SGDP
 from pytorch_optimizer.optimizer.shampoo import Shampoo
+from pytorch_optimizer.optimizer.shampoo_utils import (
+    AdaGradGraft,
+    BlockPartitioner,
+    Graft,
+    LayerWiseGrafting,
+    PreConditioner,
+    PreConditionerType,
+    RMSPropGraft,
+    SGDGraft,
+    SQRTNGraft,
+    compute_power,
+    matrix_power,
+    merge_small_dims,
+    power_iter,
+)
 from pytorch_optimizer.optimizer.utils import (
     clip_grad_norm,
     disable_running_stats,

@@ -69,6 +85,9 @@
     Ranger21,
     SGDP,
     Shampoo,
+    DAdaptAdaGrad,
+    DAdaptAdam,
+    DAdaptSGD,
 ]
 OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
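Because the `OPTIMIZERS` registry shown above keys each class by its lowercased class name, the new optimizers become reachable by string lookup as well. A small sketch, grounded in the dict comprehension in this diff:

```python
# The registry maps str(cls.__name__).lower() -> cls, so DAdaptAdam is
# registered under 'dadaptadam'. Only the lookup pattern is shown here;
# constructor defaults are assumed.
import torch
from pytorch_optimizer import OPTIMIZERS

opt_cls = OPTIMIZERS['dadaptadam']   # str(DAdaptAdam.__name__).lower()
model = torch.nn.Linear(8, 2)
optimizer = opt_cls(model.parameters())
print(optimizer.__class__.__name__)  # DAdaptAdam
```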
