
Commit a3c6561

refactor: staticmethod to utils
1 parent 8f5892c commit a3c6561

3 files changed: +29 / -47 lines


pytorch_optimizer/adamp.py

Lines changed: 4 additions & 23 deletions
@@ -1,13 +1,13 @@
 import math
-from typing import Callable, List, Tuple
+from typing import List, Tuple

 import torch
-import torch.nn.functional as F
 from torch.optim.optimizer import Optimizer

 from pytorch_optimizer.base_optimizer import BaseOptimizer
 from pytorch_optimizer.gc import centralize_gradient
 from pytorch_optimizer.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
+from pytorch_optimizer.utils import channel_view, cosine_similarity_by_view, layer_view


 class AdamP(Optimizer, BaseOptimizer):
@@ -80,25 +80,6 @@ def validate_parameters(self):
         self.validate_weight_decay_ratio(self.wd_ratio)
         self.validate_epsilon(self.eps)

-    @staticmethod
-    def channel_view(x: torch.Tensor) -> torch.Tensor:
-        return x.view(x.size()[0], -1)
-
-    @staticmethod
-    def layer_view(x: torch.Tensor) -> torch.Tensor:
-        return x.view(1, -1)
-
-    @staticmethod
-    def cosine_similarity(
-        x: torch.Tensor,
-        y: torch.Tensor,
-        eps: float,
-        view_func: Callable[[torch.Tensor], torch.Tensor],
-    ) -> torch.Tensor:
-        x = view_func(x)
-        y = view_func(y)
-        return F.cosine_similarity(x, y, dim=1, eps=eps).abs_()
-
     def projection(
         self,
         p,
@@ -110,8 +91,8 @@ def projection(
     ) -> Tuple[torch.Tensor, float]:
         wd: float = 1.0
         expand_size: List[int] = [-1] + [1] * (len(p.shape) - 1)
-        for view_func in (self.channel_view, self.layer_view):
-            cosine_sim = self.cosine_similarity(grad, p, eps, view_func)
+        for view_func in (channel_view, layer_view):
+            cosine_sim = cosine_similarity_by_view(grad, p, eps, view_func)

             if cosine_sim.max() < delta / math.sqrt(view_func(p).size()[1]):
                 p_n = p / view_func(p).norm(dim=1).view(expand_size).add_(eps)
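For context on the two views used in the projection check: channel_view keeps the leading output-channel dimension and flattens the rest, while layer_view flattens the whole tensor into a single row, so the cosine-similarity test can be evaluated per channel or for the layer as a whole. A minimal shape sketch (the tensor shape below is only an illustrative example, not taken from this repo):

import torch

p = torch.randn(8, 3, 3, 3)           # hypothetical conv weight with 8 output channels
print(p.view(p.size()[0], -1).shape)   # channel_view: torch.Size([8, 27]), one row per output channel
print(p.view(1, -1).shape)             # layer_view:   torch.Size([1, 216]), one row for the whole layer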

pytorch_optimizer/sgdp.py

Lines changed: 4 additions & 23 deletions
@@ -1,12 +1,12 @@
 import math
-from typing import Callable, List, Tuple
+from typing import List, Tuple

 import torch
-from torch.nn import functional as F
 from torch.optim.optimizer import Optimizer

 from pytorch_optimizer.base_optimizer import BaseOptimizer
 from pytorch_optimizer.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS
+from pytorch_optimizer.utils import channel_view, cosine_similarity_by_view, layer_view


 class SGDP(Optimizer, BaseOptimizer):
@@ -74,25 +74,6 @@ def validate_parameters(self):
         self.validate_weight_decay_ratio(self.wd_ratio)
         self.validate_epsilon(self.eps)

-    @staticmethod
-    def channel_view(x: torch.Tensor) -> torch.Tensor:
-        return x.view(x.size()[0], -1)
-
-    @staticmethod
-    def layer_view(x: torch.Tensor) -> torch.Tensor:
-        return x.view(1, -1)
-
-    @staticmethod
-    def cosine_similarity(
-        x: torch.Tensor,
-        y: torch.Tensor,
-        eps: float,
-        view_func: Callable[[torch.Tensor], torch.Tensor],
-    ):
-        x = view_func(x)
-        y = view_func(y)
-        return F.cosine_similarity(x, y, dim=1, eps=eps).abs_()
-
     def projection(
         self,
         p,
@@ -104,8 +85,8 @@ def projection(
     ) -> Tuple[torch.Tensor, float]:
         wd: float = 1.0
         expand_size: List[int] = [-1] + [1] * (len(p.shape) - 1)
-        for view_func in (self.channel_view, self.layer_view):
-            cosine_sim = self.cosine_similarity(grad, p, eps, view_func)
+        for view_func in (channel_view, layer_view):
+            cosine_sim = cosine_similarity_by_view(grad, p, eps, view_func)

             if cosine_sim.max() < delta / math.sqrt(view_func(p).size()[1]):
                 p_n = p / view_func(p).norm(dim=1).view(expand_size).add_(eps)

pytorch_optimizer/utils.py

Lines changed: 21 additions & 1 deletion
@@ -1,10 +1,11 @@
 import math
-from typing import List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union

 import numpy as np
 import torch
 from torch import nn
 from torch.distributed import all_reduce
+from torch.nn import functional as F
 from torch.nn.utils import clip_grad_norm_

 from pytorch_optimizer.types import PARAMETERS
@@ -50,6 +51,25 @@ def un_flatten_grad(grads: torch.Tensor, shapes: List[int]) -> List[torch.Tensor]:
     return un_flatten_grad


+def channel_view(x: torch.Tensor) -> torch.Tensor:
+    return x.view(x.size()[0], -1)
+
+
+def layer_view(x: torch.Tensor) -> torch.Tensor:
+    return x.view(1, -1)
+
+
+def cosine_similarity_by_view(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    eps: float,
+    view_func: Callable[[torch.Tensor], torch.Tensor],
+) -> torch.Tensor:
+    x = view_func(x)
+    y = view_func(y)
+    return F.cosine_similarity(x, y, dim=1, eps=eps).abs_()
+
+
 def clip_grad_norm(parameters: PARAMETERS, max_norm: float = 0, sync: bool = False) -> Union[torch.Tensor, float]:
     """Clips grad norms.
     During combination with FSDP, will also ensure that grad norms are aggregated
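With the helpers consolidated in pytorch_optimizer.utils, AdamP and SGDP call them identically as plain functions. A minimal usage sketch mirroring the projection check shown in the diffs above (the tensor shapes and the eps and delta values are placeholder assumptions, not defaults from this repo):

import math
import torch
from pytorch_optimizer.utils import channel_view, cosine_similarity_by_view, layer_view

p = torch.randn(8, 3, 3, 3)      # stand-in parameter tensor
grad = torch.randn_like(p)       # stand-in gradient
eps, delta = 1e-8, 0.1           # placeholder values

for view_func in (channel_view, layer_view):
    cosine_sim = cosine_similarity_by_view(grad, p, eps, view_func)
    # The optimizers only project when grad and p are nearly orthogonal under this view.
    if cosine_sim.max() < delta / math.sqrt(view_func(p).size()[1]):
        print(f"{view_func.__name__}: projection branch would fire")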
