Commit cdbd9bc

Merge pull request #328 from kozistr/update/wrapper
[Update] proper property

2 parents 5974aef + 4a1b3b7

File tree: 6 files changed, +42 -34 lines

- docs/changelogs/v3.3.5.md
- poetry.lock
- pytorch_optimizer/optimizer/orthograd.py
- requirements-dev.txt
- tests/test_gradients.py
- tests/test_utils.py

docs/changelogs/v3.3.5.md

Lines changed: 9 additions & 0 deletions
```diff
@@ -0,0 +1,9 @@
+### Change Log
+
+### Fix
+
+* Add the missing `state` property in `OrthoGrad` optimizer. (#326, #327)
+
+### Contributions
+
+thanks to @Vectorrent
```

poetry.lock

Lines changed: 19 additions & 19 deletions
Some generated files are not rendered by default.

pytorch_optimizer/optimizer/orthograd.py

Lines changed: 8 additions & 12 deletions
```diff
@@ -1,17 +1,10 @@
-from collections import defaultdict
 from typing import Callable, Dict

 import torch
 from torch.optim import Optimizer

 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import (
-    CLOSURE,
-    DEFAULTS,
-    LOSS,
-    OPTIMIZER_INSTANCE_OR_CLASS,
-    STATE,
-)
+from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, OPTIMIZER_INSTANCE_OR_CLASS


 class OrthoGrad(BaseOptimizer):
@@ -27,8 +20,6 @@ def __init__(self, optimizer: OPTIMIZER_INSTANCE_OR_CLASS, **kwargs) -> None:
         self._optimizer_step_post_hooks: Dict[int, Callable] = {}
         self.eps: float = 1e-30

-        self.state: STATE = defaultdict(dict)
-
         if isinstance(optimizer, Optimizer):
             self.optimizer = optimizer
         elif 'params' in kwargs:
@@ -46,8 +37,13 @@ def __str__(self) -> str:
     def param_groups(self):
         return self.optimizer.param_groups

-    def __getstate__(self):
-        return {'optimizer': self.optimizer}
+    @property
+    def state(self):
+        return self.optimizer.state
+
+    @torch.no_grad()
+    def zero_grad(self) -> None:
+        self.optimizer.zero_grad(set_to_none=True)

     @torch.no_grad()
     def reset(self):
```
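The substantive change: `OrthoGrad` previously shadowed `state` with its own empty `defaultdict(dict)`, so inspecting the wrapper never showed the per-parameter state the wrapped optimizer was actually accumulating. The new read-only property forwards to the inner optimizer instead. A minimal sketch of the pattern outside this repository (`StateDelegatingWrapper` is an illustrative name, not part of pytorch-optimizer):

```python
import torch
from torch.optim import SGD


class StateDelegatingWrapper:
    """Illustrative wrapper that forwards `state` to the inner optimizer."""

    def __init__(self, optimizer: torch.optim.Optimizer) -> None:
        self.optimizer = optimizer

    @property
    def state(self):
        # Forward instead of shadowing with a separate, never-updated dict.
        return self.optimizer.state


param = torch.nn.Parameter(torch.zeros(3))
wrapper = StateDelegatingWrapper(SGD([param], lr=0.1, momentum=0.9))

param.grad = torch.ones_like(param)
wrapper.optimizer.step()

# The property reflects the momentum buffer the inner SGD just created;
# a shadowing attribute would still report an empty mapping here.
assert 'momentum_buffer' in wrapper.state[param]
```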

requirements-dev.txt

Lines changed: 1 addition & 1 deletion
```diff
@@ -22,7 +22,7 @@ platformdirs==4.3.6 ; python_version >= "3.8"
 pluggy==1.5.0 ; python_version >= "3.8"
 pytest-cov==5.0.0 ; python_version >= "3.8"
 pytest==8.3.4 ; python_version >= "3.8"
-ruff==0.9.1 ; python_version >= "3.8"
+ruff==0.9.2 ; python_version >= "3.8"
 setuptools==75.8.0 ; python_version >= "3.12"
 sympy==1.12.1 ; python_version == "3.8"
 sympy==1.13.1 ; python_version >= "3.9"
```

tests/test_gradients.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -37,7 +37,7 @@ def test_no_gradients(optimizer_name):
     sphere_loss(p1 + p3).backward(create_graph=True)

     optimizer.step(lambda: 0.1)  # for AliG optimizer
-    if optimizer_name not in {'lookahead', 'trac'}:
+    if optimizer_name not in {'lookahead', 'trac', 'orthograd'}:
        optimizer.zero_grad(set_to_none=True)
```
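The exclusion follows from the wrapper change above: the new `OrthoGrad.zero_grad()` takes no arguments and always passes `set_to_none=True` to the inner optimizer, so the test's `zero_grad(set_to_none=True)` call would not match its signature. A hedged sketch of the post-commit behavior (assuming `OrthoGrad` is importable from the package top level, as the project's other optimizers are):

```python
import torch
from pytorch_optimizer import OrthoGrad

param = torch.nn.Parameter(torch.ones(2))
optimizer = OrthoGrad(torch.optim.SGD([param], lr=0.1))

param.grad = torch.ones_like(param)
optimizer.zero_grad()  # no arguments: the wrapper hardcodes set_to_none=True

assert param.grad is None  # gradients were set to None, not zero-filled
```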

tests/test_utils.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -265,6 +265,9 @@ def test_cpu_offload_optimizer():

 def test_orthograd_name():
     optimizer = build_orthograd(Example().parameters())
+    optimizer.zero_grad()
+
     _ = optimizer.param_groups
-    _ = optimizer.__getstate__()
+    _ = optimizer.state
+
     assert str(optimizer).lower() == 'orthograd'
```
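Outside the test harness, the same surface can be exercised with plain PyTorch; `build_orthograd` and `Example` above are test helpers, so this sketch substitutes a `torch.nn.Linear` and again assumes the top-level `OrthoGrad` export:

```python
import torch
from pytorch_optimizer import OrthoGrad

model = torch.nn.Linear(2, 1)
optimizer = OrthoGrad(torch.optim.SGD(model.parameters(), lr=0.1))

optimizer.zero_grad()

groups = optimizer.param_groups  # delegated to the wrapped SGD
state = optimizer.state          # the property added in this commit

assert str(optimizer).lower() == 'orthograd'
```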

0 commit comments