11import math
2- from typing import Callable , List , Tuple
2+ from typing import List , Tuple
33
44import torch
5- import torch .nn .functional as F
65from torch .optim .optimizer import Optimizer
76
87from pytorch_optimizer .base_optimizer import BaseOptimizer
98from pytorch_optimizer .gc import centralize_gradient
109from pytorch_optimizer .types import BETAS , CLOSURE , DEFAULTS , LOSS , PARAMETERS
10+ from pytorch_optimizer .utils import channel_view , cosine_similarity_by_view , layer_view
1111
1212
1313class AdamP (Optimizer , BaseOptimizer ):
@@ -80,25 +80,6 @@ def validate_parameters(self):
8080 self .validate_weight_decay_ratio (self .wd_ratio )
8181 self .validate_epsilon (self .eps )
8282
83- @staticmethod
84- def channel_view (x : torch .Tensor ) -> torch .Tensor :
85- return x .view (x .size ()[0 ], - 1 )
86-
87- @staticmethod
88- def layer_view (x : torch .Tensor ) -> torch .Tensor :
89- return x .view (1 , - 1 )
90-
91- @staticmethod
92- def cosine_similarity (
93- x : torch .Tensor ,
94- y : torch .Tensor ,
95- eps : float ,
96- view_func : Callable [[torch .Tensor ], torch .Tensor ],
97- ) -> torch .Tensor :
98- x = view_func (x )
99- y = view_func (y )
100- return F .cosine_similarity (x , y , dim = 1 , eps = eps ).abs_ ()
101-
10283 def projection (
10384 self ,
10485 p ,
@@ -110,8 +91,8 @@ def projection(
11091 ) -> Tuple [torch .Tensor , float ]:
11192 wd : float = 1.0
11293 expand_size : List [int ] = [- 1 ] + [1 ] * (len (p .shape ) - 1 )
113- for view_func in (self . channel_view , self . layer_view ):
114- cosine_sim = self . cosine_similarity (grad , p , eps , view_func )
94+ for view_func in (channel_view , layer_view ):
95+ cosine_sim = cosine_similarity_by_view (grad , p , eps , view_func )
11596
11697 if cosine_sim .max () < delta / math .sqrt (view_func (p ).size ()[1 ]):
11798 p_n = p / view_func (p ).norm (dim = 1 ).view (expand_size ).add_ (eps )
0 commit comments