Skip to content

Commit beea7da

Browse files
committed
update: AdamMini optimizer
1 parent 337bb0e commit beea7da

File tree

5 files changed

+17
-1
lines changed

5 files changed

+17
-1
lines changed

pytorch_optimizer/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from pytorch_optimizer.optimizer.adahessian import AdaHessian
4141
from pytorch_optimizer.optimizer.adai import Adai
4242
from pytorch_optimizer.optimizer.adalite import Adalite
43+
from pytorch_optimizer.optimizer.adam_mini import AdamMini
4344
from pytorch_optimizer.optimizer.adamax import AdaMax
4445
from pytorch_optimizer.optimizer.adamod import AdaMod
4546
from pytorch_optimizer.optimizer.adamp import AdamP
@@ -203,6 +204,7 @@
203204
GrokFastAdamW,
204205
Kate,
205206
StableAdamW,
207+
AdamMini,
206208
]
207209
OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
208210

pytorch_optimizer/base/optimizer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,11 @@ def validate_learning_rate(learning_rate: Optional[float]) -> None:
258258
if learning_rate is not None and learning_rate < 0.0:
259259
raise NegativeLRError(learning_rate)
260260

261+
@staticmethod
262+
def validate_mod(x: int, y: int) -> None:
263+
if x % y != 0:
264+
raise ValueError(f'[-] {x} must be divisible by {y}')
265+
261266
def validate_betas(self, betas: BETAS) -> None:
262267
if betas[0] is not None:
263268
self.validate_range(betas[0], 'beta1', 0.0, 1.0, range_type='[]')

tests/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@
134134
'fadam',
135135
'grokfastadamw',
136136
'stableadamw',
137+
'adammini',
137138
]
138139

139140
VALID_LR_SCHEDULER_NAMES: List[str] = [

tests/test_load_modules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def test_load_lr_scheduler_invalid(invalid_lr_scheduler_names):
3838

3939

4040
def test_get_supported_optimizers():
    """The optimizer registry must expose exactly 71 entries after AdamMini was added."""
    supported = get_supported_optimizers()
    assert len(supported) == 71
4242

4343

4444
def test_get_supported_lr_schedulers():

tests/test_optimizers.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -650,3 +650,11 @@ def test_stableadamw_optimizer(environment):
650650
optimizer = load_optimizer('StableAdamW')(model.parameters())
651651
optimizer.reset()
652652
optimizer.step()
653+
654+
655+
def test_adam_mini_optimizer(environment):
    """Smoke-test AdamMini: load it from the registry, reset its state, take one step.

    NOTE(review): AdamMini receives the whole ``model`` (not ``model.parameters()``)
    unlike the sibling tests — presumably it needs named parameters; confirm against
    the optimizer's constructor.
    """
    _, model, _ = environment

    opt = load_optimizer('AdamMini')(model)
    opt.reset()
    opt.step()

0 commit comments

Comments
 (0)