
Commit 5b6d5f6

refactor: AdaGradGraft
1 parent aad3c60 commit 5b6d5f6

3 files changed (+19, -4 lines)

pytorch_optimizer/__init__.py

Lines changed: 15 additions & 0 deletions
@@ -41,6 +41,21 @@
 from pytorch_optimizer.optimizer.sam import SAM
 from pytorch_optimizer.optimizer.sgdp import SGDP
 from pytorch_optimizer.optimizer.shampoo import Shampoo
+from pytorch_optimizer.optimizer.shampoo_utils import (
+    AdaGradGraft,
+    BlockPartitioner,
+    Graft,
+    LayerWiseGrafting,
+    PreConditioner,
+    PreConditionerType,
+    RMSPropGraft,
+    SGDGraft,
+    SQRTNGraft,
+    compute_power,
+    matrix_power,
+    merge_small_dims,
+    power_iter,
+)
 from pytorch_optimizer.optimizer.utils import (
     clip_grad_norm,
     disable_running_stats,
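
With these re-exports in place, the grafting and preconditioning helpers should be importable straight from the package root instead of from pytorch_optimizer.optimizer.shampoo_utils. A minimal sketch of the intended usage, assuming the package is installed as pytorch_optimizer:

    # Names taken from the re-export list added above; the print is only a sanity check.
    from pytorch_optimizer import AdaGradGraft, LayerWiseGrafting, PreConditioner, compute_power

    print(AdaGradGraft.__mro__)  # AdaGradGraft subclasses SGDGraft, per shampoo_utils.py below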

pytorch_optimizer/optimizer/shampoo.py

Lines changed: 3 additions & 3 deletions
@@ -5,7 +5,7 @@
 from pytorch_optimizer.base.optimizer import BaseOptimizer
 from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
 from pytorch_optimizer.optimizer.shampoo_utils import (
-    AdagradGraft,
+    AdaGradGraft,
     Graft,
     LayerWiseGrafting,
     PreConditioner,
@@ -130,7 +130,7 @@ def reset(self):
             self.pre_conditioner_type,
         )
         if self.graft_type == LayerWiseGrafting.ADAGRAD:
-            state['graft'] = AdagradGraft(p, self.diagonal_eps)
+            state['graft'] = AdaGradGraft(p, self.diagonal_eps)
         elif self.graft_type == LayerWiseGrafting.RMSPROP:
             state['graft'] = RMSPropGraft(p, self.diagonal_eps)
         elif self.graft_type == LayerWiseGrafting.SGD:
@@ -172,7 +172,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
             self.pre_conditioner_type,
         )
         if self.graft_type == LayerWiseGrafting.ADAGRAD:
-            state['graft'] = AdagradGraft(p, self.diagonal_eps)
+            state['graft'] = AdaGradGraft(p, self.diagonal_eps)
         elif self.graft_type == LayerWiseGrafting.RMSPROP:
             state['graft'] = RMSPropGraft(p, self.diagonal_eps)
         elif self.graft_type == LayerWiseGrafting.SGD:
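
The two hunks above only rename the class used when AdaGrad grafting is selected; the selection logic itself is unchanged. A hedged usage sketch follows: graft_type and diagonal_eps as constructor arguments are an assumption inferred from the self.graft_type / self.diagonal_eps attributes visible in the hunks, not something this diff shows.

    import torch
    from pytorch_optimizer import Shampoo
    from pytorch_optimizer.optimizer.shampoo_utils import LayerWiseGrafting

    model = torch.nn.Linear(16, 4)
    # graft_type / diagonal_eps keyword names are assumptions based on the attributes above
    optimizer = Shampoo(model.parameters(), lr=1e-3, graft_type=LayerWiseGrafting.ADAGRAD, diagonal_eps=1e-10)

    loss = model(torch.randn(8, 16)).sum()
    loss.backward()
    optimizer.step()  # takes the ADAGRAD branch, so state['graft'] becomes an AdaGradGraft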

pytorch_optimizer/optimizer/shampoo_utils.py

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ def precondition_gradient(self, grad: torch.Tensor) -> torch.Tensor:
         return torch.ones_like(grad) * torch.sign(grad)
 
 
-class AdagradGraft(SGDGraft):
+class AdaGradGraft(SGDGraft):
     r"""Graft using Adagrad. Essentially an implementation of Adagrad with momentum.
 
     :param var: torch.Tensor. variable.
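
For context, the renamed class implements layer-wise grafting with AdaGrad statistics, as its docstring says. The toy sketch below illustrates the idea only and is not the library's implementation; precondition_gradient mirrors the method name in the hunk header, while add_statistics and the exact update rule are assumptions.

    import torch

    class ToyAdaGradGraft:
        """Illustrative AdaGrad-style graft: accumulate squared gradients, rescale by their root."""

        def __init__(self, var: torch.Tensor, diagonal_eps: float = 1e-10):
            self.diagonal_eps = diagonal_eps
            self.statistics = torch.zeros_like(var)  # running sum of squared gradients

        def add_statistics(self, grad: torch.Tensor) -> None:  # method name is an assumption
            self.statistics.add_(grad * grad)

        def precondition_gradient(self, grad: torch.Tensor) -> torch.Tensor:
            return grad / (torch.sqrt(self.statistics) + self.diagonal_eps)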
