Commit fe2ff95

refactor: staticmethod to utils

1 parent a3c6561 commit fe2ff95

3 files changed: +26 -49 lines changed

pytorch_optimizer/adamp.py

Lines changed: 2 additions & 24 deletions
@@ -7,7 +7,7 @@
 from pytorch_optimizer.base_optimizer import BaseOptimizer
 from pytorch_optimizer.gc import centralize_gradient
 from pytorch_optimizer.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
-from pytorch_optimizer.utils import channel_view, cosine_similarity_by_view, layer_view
+from pytorch_optimizer.utils import projection


 class AdamP(Optimizer, BaseOptimizer):
@@ -80,28 +80,6 @@ def validate_parameters(self):
         self.validate_weight_decay_ratio(self.wd_ratio)
         self.validate_epsilon(self.eps)

-    def projection(
-        self,
-        p,
-        grad,
-        perturb: torch.Tensor,
-        delta: float,
-        wd_ratio: float,
-        eps: float,
-    ) -> Tuple[torch.Tensor, float]:
-        wd: float = 1.0
-        expand_size: List[int] = [-1] + [1] * (len(p.shape) - 1)
-        for view_func in (channel_view, layer_view):
-            cosine_sim = cosine_similarity_by_view(grad, p, eps, view_func)
-
-            if cosine_sim.max() < delta / math.sqrt(view_func(p).size()[1]):
-                p_n = p / view_func(p).norm(dim=1).view(expand_size).add_(eps)
-                perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size)
-                wd = wd_ratio
-                return perturb, wd
-
-        return perturb, wd
-
     @torch.no_grad()
     def reset(self):
         for group in self.param_groups:
@@ -157,7 +135,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:

                 wd_ratio: float = 1
                 if len(p.shape) > 1:
-                    perturb, wd_ratio = self.projection(
+                    perturb, wd_ratio = projection(
                         p,
                         grad,
                         perturb,
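The hunk truncates the updated call. Judging from the signature of `projection` in utils.py, the remaining arguments are presumably the group's `delta`, `wd_ratio`, and `eps` hyper-parameters; a minimal sketch of what the full call site likely looks like after this refactor (the `group[...]` lookups are assumptions, not shown in this diff):

    # Hypothetical completion of the AdamP.step call site shown above.
    # The trailing argument sources are inferred from projection()'s
    # signature in utils.py, not from this hunk.
    wd_ratio: float = 1
    if len(p.shape) > 1:
        perturb, wd_ratio = projection(
            p,
            grad,
            perturb,
            group['delta'],     # assumed: threshold for detecting scale-invariant layers
            group['wd_ratio'],  # assumed: reduced weight-decay ratio applied on projection
            group['eps'],       # assumed: numerical-stability term
        )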

pytorch_optimizer/sgdp.py

Lines changed: 2 additions & 25 deletions
@@ -6,7 +6,7 @@

 from pytorch_optimizer.base_optimizer import BaseOptimizer
 from pytorch_optimizer.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS
-from pytorch_optimizer.utils import channel_view, cosine_similarity_by_view, layer_view
+from pytorch_optimizer.utils import projection


 class SGDP(Optimizer, BaseOptimizer):
@@ -74,29 +74,6 @@ def validate_parameters(self):
         self.validate_weight_decay_ratio(self.wd_ratio)
         self.validate_epsilon(self.eps)

-    def projection(
-        self,
-        p,
-        grad,
-        perturb: torch.Tensor,
-        delta: float,
-        wd_ratio: float,
-        eps: float,
-    ) -> Tuple[torch.Tensor, float]:
-        wd: float = 1.0
-        expand_size: List[int] = [-1] + [1] * (len(p.shape) - 1)
-        for view_func in (channel_view, layer_view):
-            cosine_sim = cosine_similarity_by_view(grad, p, eps, view_func)
-
-            if cosine_sim.max() < delta / math.sqrt(view_func(p).size()[1]):
-                p_n = p / view_func(p).norm(dim=1).view(expand_size).add_(eps)
-                perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size)
-                wd = wd_ratio
-
-                return perturb, wd
-
-        return perturb, wd
-
     @torch.no_grad()
     def reset(self):
         for group in self.param_groups:
@@ -137,7 +114,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:

                 wd_ratio: float = 1.0
                 if len(p.shape) > 1:
-                    d_p, wd_ratio = self.projection(
+                    d_p, wd_ratio = projection(
                         p,
                         grad,
                         d_p,
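Since both optimizers now route through the same module-level helper, nothing changes for callers. A quick usage sketch, assuming the package exports `AdamP` and `SGDP` at the top level and that the constructors accept the usual `lr` keyword:

    import torch
    from pytorch_optimizer import AdamP, SGDP  # assumed top-level exports

    model = torch.nn.Linear(10, 2)
    criterion = torch.nn.MSELoss()

    # Either optimizer now performs its projection step via
    # pytorch_optimizer.utils.projection instead of a per-class staticmethod.
    optimizer = AdamP(model.parameters(), lr=1e-3)  # or SGDP(model.parameters(), lr=1e-3)

    x, y = torch.randn(4, 10), torch.randn(4, 2)
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()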

pytorch_optimizer/utils.py

Lines changed: 22 additions & 0 deletions
@@ -105,6 +105,28 @@ def clip_grad_norm(parameters: PARAMETERS, max_norm: float = 0, sync: bool = False
     return grad_norm


+def projection(
+    p,
+    grad,
+    perturb: torch.Tensor,
+    delta: float,
+    wd_ratio: float,
+    eps: float,
+) -> Tuple[torch.Tensor, float]:
+    wd: float = 1.0
+    expand_size: List[int] = [-1] + [1] * (len(p.shape) - 1)
+    for view_func in (channel_view, layer_view):
+        cosine_sim = cosine_similarity_by_view(grad, p, eps, view_func)
+
+        if cosine_sim.max() < delta / math.sqrt(view_func(p).size()[1]):
+            p_n = p / view_func(p).norm(dim=1).view(expand_size).add_(eps)
+            perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size)
+            wd = wd_ratio
+            return perturb, wd
+
+    return perturb, wd
+
+
 def unit_norm(x: torch.Tensor, norm: float = 2.0) -> torch.Tensor:
     keep_dim: bool = True
     dim: Optional[Union[int, Tuple[int, ...]]] = None
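For context on what the moved helper computes: when the gradient looks nearly orthogonal to a weight tensor under either a channel-wise or layer-wise view (low maximum cosine similarity), the layer is treated as scale-invariant; `projection` then subtracts from the update its component along the normalized weight direction and switches to the reduced weight-decay ratio. A self-contained sketch, with minimal stand-ins for `channel_view`, `layer_view`, and `cosine_similarity_by_view` (their real definitions live elsewhere in utils.py and may differ):

    import math
    from typing import List, Tuple

    import torch
    import torch.nn.functional as F


    def channel_view(x: torch.Tensor) -> torch.Tensor:
        # Stand-in: one row per output channel.
        return x.view(x.size(0), -1)


    def layer_view(x: torch.Tensor) -> torch.Tensor:
        # Stand-in: the whole tensor as a single row.
        return x.view(1, -1)


    def cosine_similarity_by_view(x, y, eps, view_func) -> torch.Tensor:
        # Stand-in: row-wise |cosine similarity| under the given view.
        return F.cosine_similarity(view_func(x), view_func(y), dim=1, eps=eps).abs_()


    def projection(p, grad, perturb, delta, wd_ratio, eps) -> Tuple[torch.Tensor, float]:
        # Same body as the function added to utils.py above: drop the radial
        # component of the update when grad and p look orthogonal under some view.
        wd: float = 1.0
        expand_size: List[int] = [-1] + [1] * (len(p.shape) - 1)
        for view_func in (channel_view, layer_view):
            cosine_sim = cosine_similarity_by_view(grad, p, eps, view_func)

            if cosine_sim.max() < delta / math.sqrt(view_func(p).size()[1]):
                p_n = p / view_func(p).norm(dim=1).view(expand_size).add_(eps)
                perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size)
                wd = wd_ratio
                return perturb, wd

        return perturb, wd


    p = torch.randn(4, 8)
    grad = torch.randn(4, 8)
    perturb, wd = projection(p, grad, grad.clone(), delta=0.1, wd_ratio=0.01, eps=1e-8)
    print(wd)  # 0.01 if a scale-invariant view was detected, else 1.0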
