Skip to content

Commit 5974aef

Browse files
authored
fix OrthoGrad state management bug (#327)
* fix state management bug * run black with proper args * execute ruff
1 parent 55c3553 commit 5974aef

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

pytorch_optimizer/optimizer/orthograd.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
1+
from collections import defaultdict
12
from typing import Callable, Dict
23

34
import torch
45
from torch.optim import Optimizer
56

67
from pytorch_optimizer.base.optimizer import BaseOptimizer
7-
from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, OPTIMIZER_INSTANCE_OR_CLASS
8+
from pytorch_optimizer.base.types import (
9+
CLOSURE,
10+
DEFAULTS,
11+
LOSS,
12+
OPTIMIZER_INSTANCE_OR_CLASS,
13+
STATE,
14+
)
815

916

1017
class OrthoGrad(BaseOptimizer):
@@ -20,6 +27,8 @@ def __init__(self, optimizer: OPTIMIZER_INSTANCE_OR_CLASS, **kwargs) -> None:
2027
self._optimizer_step_post_hooks: Dict[int, Callable] = {}
2128
self.eps: float = 1e-30
2229

30+
self.state: STATE = defaultdict(dict)
31+
2332
if isinstance(optimizer, Optimizer):
2433
self.optimizer = optimizer
2534
elif 'params' in kwargs:

0 commit comments

Comments
 (0)