55from pytorch_optimizer .base .exception import NoComplexParameterError , NoSparseGradientError
66from pytorch_optimizer .base .optimizer import BaseOptimizer
77from pytorch_optimizer .base .type import BETAS , CLOSURE , DEFAULTS , GROUP , LOSS , PARAMETERS
8- from pytorch_optimizer .optimizer .galore_utils import GaLoreProjector
8+ from pytorch_optimizer .optimizer .galore_utils import PROJECTION_TYPE , GaLoreProjector
99
1010
1111class Conda (BaseOptimizer ):
@@ -15,6 +15,9 @@ class Conda(BaseOptimizer):
1515 :param lr: float. learning rate.
1616 :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
1717 :param weight_decay: float. weight decay (L2 penalty).
18+ :param update_proj_gap: int. number of steps between updates of the projection matrix.
19+ :param scale: float. GaLore projection scaling factor.
20+ :param projection_type: PROJECTION_TYPE. the type of the GaLore projection.
1821 :param eps: float. term added to the denominator to improve numerical stability.
1922 :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
2023 """
@@ -25,18 +28,31 @@ def __init__(
2528 lr : float = 1e-3 ,
2629 betas : BETAS = (0.9 , 0.999 ),
2730 weight_decay : float = 0.0 ,
31+ update_proj_gap : int = 2000 ,
32+ scale : float = 1.0 ,
33+ projection_type : PROJECTION_TYPE = 'std' ,
2834 eps : float = 1e-8 ,
2935 maximize : bool = False ,
3036 ** kwargs ,
3137 ):
3238 self .validate_learning_rate (lr )
3339 self .validate_betas (betas )
40+ self .validate_positive (update_proj_gap , 'update_proj_gap' )
3441 self .validate_non_negative (weight_decay , 'weight_decay' )
3542 self .validate_non_negative (eps , 'eps' )
3643
3744 self .maximize = maximize
3845
39- defaults : DEFAULTS = {'lr' : lr , 'betas' : betas , 'weight_decay' : weight_decay , 'eps' : eps , ** kwargs }
46+ defaults : DEFAULTS = {
47+ 'lr' : lr ,
48+ 'betas' : betas ,
49+ 'weight_decay' : weight_decay ,
50+ 'update_proj_gap' : update_proj_gap ,
51+ 'scale' : scale ,
52+ 'projection_type' : projection_type ,
53+ 'eps' : eps ,
54+ ** kwargs ,
55+ }
4056 super ().__init__ (params , defaults )
4157
4258 def __str__ (self ) -> str :
@@ -94,7 +110,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
94110 exp_avg , exp_avg_sq = state ['exp_avg' ], state ['exp_avg_sq' ]
95111 exp_avg .mul_ (beta1 ).add_ (grad , alpha = 1.0 - beta1 )
96112
97- if 'update_proj_gap' in group and p .dim () == 2 :
113+ if p .dim () == 2 :
98114 if 'projector' not in state :
99115 state ['projector' ] = GaLoreProjector (
100116 rank = None ,
@@ -112,7 +128,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
112128
113129 norm_grad = exp_avg / de_nom
114130
115- if 'update_proj_gap' in group and p .dim () == 2 :
131+ if p .dim () == 2 :
116132 norm_grad = state ['projector' ].project_back (norm_grad )
117133
118134 p .add_ (norm_grad , alpha = - step_size )
0 commit comments