@@ -52,36 +52,36 @@ def __init__(
5252 self .lr = lr
5353 self .betas = betas
5454 self .weight_decay = weight_decay
55+ self .eps = eps
5556 self .wd_ratio = wd_ratio
5657 self .use_gc = use_gc
57- self .eps = eps
5858
5959 self .check_valid_parameters ()
6060
6161 defaults : DEFAULTS = dict (
6262 lr = lr ,
6363 betas = betas ,
64- eps = eps ,
6564 weight_decay = weight_decay ,
6665 delta = delta ,
6766 wd_ratio = wd_ratio ,
6867 nesterov = nesterov ,
68+ eps = eps ,
6969 )
7070 super ().__init__ (params , defaults )
7171
def check_valid_parameters(self) -> None:
    """Validate the hyper-parameters stored on the instance.

    The checks run in a fixed order and the first violation wins:
    ``lr``, ``betas[0]``, ``betas[1]``, ``weight_decay``, ``eps``,
    ``wd_ratio``.

    :raises ValueError: if ``lr``, ``weight_decay`` or ``eps`` is negative,
        or if either beta or ``wd_ratio`` lies outside ``[0.0, 1.0)``.
    """
    if self.lr < 0.0:
        raise ValueError(f'Invalid learning rate : {self.lr}')
    # Both betas are checked against the half-open interval [0, 1).
    if not 0.0 <= self.betas[0] < 1.0:
        raise ValueError(f'Invalid beta_0 : {self.betas[0]}')
    if not 0.0 <= self.betas[1] < 1.0:
        raise ValueError(f'Invalid beta_1 : {self.betas[1]}')
    if self.weight_decay < 0.0:
        raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
    if self.eps < 0.0:
        raise ValueError(f'Invalid eps : {self.eps}')
    # wd_ratio uses the same half-open interval as the betas.
    if not 0.0 <= self.wd_ratio < 1.0:
        raise ValueError(f'Invalid wd_ratio : {self.wd_ratio}')
8585
8686 @staticmethod
8787 def channel_view (x : torch .Tensor ) -> torch .Tensor :
@@ -97,7 +97,7 @@ def cosine_similarity(
9797 y : torch .Tensor ,
9898 eps : float ,
9999 view_func : Callable [[torch .Tensor ], torch .Tensor ],
100- ):
100+ ) -> torch . Tensor :
101101 x = view_func (x )
102102 y = view_func (y )
103103 return F .cosine_similarity (x , y , dim = 1 , eps = eps ).abs_ ()
0 commit comments