@@ -7,7 +7,7 @@
 from pytorch_optimizer.base_optimizer import BaseOptimizer
 from pytorch_optimizer.gc import centralize_gradient
 from pytorch_optimizer.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
-from pytorch_optimizer.utils import channel_view, cosine_similarity_by_view, layer_view
+from pytorch_optimizer.utils import projection
 
 
 class AdamP(Optimizer, BaseOptimizer):
@@ -80,28 +80,6 @@ def validate_parameters(self): |
         self.validate_weight_decay_ratio(self.wd_ratio)
         self.validate_epsilon(self.eps)
 
-    def projection(
-        self,
-        p,
-        grad,
-        perturb: torch.Tensor,
-        delta: float,
-        wd_ratio: float,
-        eps: float,
-    ) -> Tuple[torch.Tensor, float]:
-        wd: float = 1.0
-        expand_size: List[int] = [-1] + [1] * (len(p.shape) - 1)
-        for view_func in (channel_view, layer_view):
-            cosine_sim = cosine_similarity_by_view(grad, p, eps, view_func)
-
-            if cosine_sim.max() < delta / math.sqrt(view_func(p).size()[1]):
-                p_n = p / view_func(p).norm(dim=1).view(expand_size).add_(eps)
-                perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size)
-                wd = wd_ratio
-                return perturb, wd
-
-        return perturb, wd
-
     @torch.no_grad()
     def reset(self):
         for group in self.param_groups:
@@ -157,7 +135,7 @@ def step(self, closure: CLOSURE = None) -> LOSS: |
 
                wd_ratio: float = 1
                if len(p.shape) > 1:
-                    perturb, wd_ratio = self.projection(
+                    perturb, wd_ratio = projection(
                        p,
                        grad,
                        perturb,
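The diff shows the `projection` method being deleted from `AdamP` and the call site switching to a module-level function, but not the new function itself. Below is a minimal sketch of what the relocated `pytorch_optimizer.utils.projection` would look like, assuming it keeps the removed method's body verbatim minus `self`; the `channel_view`, `layer_view`, and `cosine_similarity_by_view` helpers (which the old import pulled from `utils`) are reconstructed here for self-containment, and their exact bodies in the repository may differ.

import math
from typing import Callable, List, Tuple

import torch
import torch.nn.functional as F


def channel_view(x: torch.Tensor) -> torch.Tensor:
    # Flatten every dimension after the first: one row per output channel.
    return x.view(x.size()[0], -1)


def layer_view(x: torch.Tensor) -> torch.Tensor:
    # Flatten the whole tensor into a single row: one row per layer.
    return x.view(1, -1)


def cosine_similarity_by_view(
    x: torch.Tensor,
    y: torch.Tensor,
    eps: float,
    view_func: Callable[[torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    # Absolute cosine similarity between x and y, computed row-wise after
    # reshaping both tensors with view_func.
    x = view_func(x)
    y = view_func(y)
    return F.cosine_similarity(x, y, dim=1, eps=eps).abs_()


def projection(
    p: torch.Tensor,
    grad: torch.Tensor,
    perturb: torch.Tensor,
    delta: float,
    wd_ratio: float,
    eps: float,
) -> Tuple[torch.Tensor, float]:
    wd: float = 1.0
    expand_size: List[int] = [-1] + [1] * (len(p.shape) - 1)
    for view_func in (channel_view, layer_view):
        cosine_sim = cosine_similarity_by_view(grad, p, eps, view_func)

        # A near-zero cosine similarity between gradient and weight marks p as
        # (approximately) scale-invariant: remove the radial component from
        # the update and signal the caller to scale down weight decay.
        if cosine_sim.max() < delta / math.sqrt(view_func(p).size()[1]):
            p_n = p / view_func(p).norm(dim=1).view(expand_size).add_(eps)
            perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size)
            wd = wd_ratio
            return perturb, wd

    return perturb, wd

Returning the weight-decay multiplier alongside the projected perturbation lets `step` apply only a fraction (`wd_ratio`) of the weight decay to scale-invariant parameters, which is the core AdamP mechanism; moving the function to `utils` simply lets other optimizers in the package reuse it.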