@@ -17,15 +17,18 @@ class MARS(BaseOptimizer):
     :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
     :param lr: float. learning rate.
     :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
+    :param gamma: float. the scaling parameter that controls the strength of gradient correction.
+    :param mars_type: MARS TYPE. type of MARS. `adamw`, `lion`, `shampoo` are supported.
+    :param optimize_1d: bool. whether MARS should optimize 1D parameters.
+    :param lr_1d: float. learning rate for AdamW when optimize_1d is set to False.
     :param betas_1d: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
         for 1d.
-    :param gamma: float. gamma.
-    :param mars_type: MARS TYPE. type of MARS. `adamw`, `lion`, `shampoo` are supported.
     :param weight_decay: float. weight decay (L2 penalty).
     :param weight_decay_1d: float. weight decay for 1d.
     :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
     :param fixed_decay: bool. fix weight decay.
     :param ams_bound: bool. whether to use the AMSBound variant.
+    :param cautious: bool. whether to use the cautious feature.
     :param eps: float. term added to the denominator to improve numerical stability.
     """

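For context, a minimal usage sketch of the options documented above; the `from pytorch_optimizer import MARS` import path and the `gamma=0.025` default shown here are assumptions for illustration, not taken from this diff:

```python
# Hedged usage sketch: constructs MARS with the new `cautious` flag and the
# 1D-related options documented above.
import torch
from pytorch_optimizer import MARS  # assumed public import path

model = torch.nn.Linear(16, 4)

optimizer = MARS(
    model.parameters(),
    lr=3e-3,
    gamma=0.025,        # strength of the gradient correction (assumed default)
    mars_type='adamw',  # one of 'adamw', 'lion', 'shampoo'
    optimize_1d=False,  # 1D params fall back to AdamW with lr_1d / betas_1d
    cautious=True,      # enable the cautious masking added in this change
)

loss = model(torch.randn(8, 16)).sum()
loss.backward()
optimizer.step()
```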
@@ -39,11 +42,12 @@ def __init__(
         optimize_1d: bool = False,
         lr_1d: float = 3e-3,
         betas_1d: BETAS = (0.9, 0.95),
-        weight_decay_1d: float = 1e-1,
         weight_decay: float = 0.0,
+        weight_decay_1d: float = 1e-1,
         weight_decouple: bool = True,
         fixed_decay: bool = False,
         ams_bound: bool = False,
+        cautious: bool = False,
         eps: float = 1e-8,
         **kwargs,
     ):
@@ -70,6 +74,7 @@ def __init__(
             'weight_decouple': weight_decouple,
             'fixed_decay': fixed_decay,
             'ams_bound': ams_bound,
+            'cautious': cautious,
             'eps': eps,
         }

@@ -104,6 +109,7 @@ def optimize_mixed(
         is_grad_2d: bool,
         step: int,
         ams_bound: bool,
+        cautious: bool,
         eps: float,
     ) -> torch.Tensor:
         beta1, beta2 = betas
@@ -115,6 +121,9 @@ def optimize_mixed(

         exp_avg.mul_(beta1).add_(c_t, alpha=1.0 - beta1)

+        if cautious:
+            self.apply_cautious(exp_avg, grad)
+
         if mars_type == 'adamw' or (mars_type == 'shampoo' and not is_grad_2d):
             exp_avg_sq.mul_(beta2).addcmul_(c_t, c_t, value=1.0 - beta2)

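The new branch calls `apply_cautious` from the optimizer's base class, which is not shown in this diff. Below is a rough sketch of the usual cautious masking rule (keep only momentum components whose sign agrees with the current gradient, then rescale by the surviving fraction); the actual base-class helper may differ:

```python
import torch

def apply_cautious(update: torch.Tensor, grad: torch.Tensor) -> None:
    # Sketch of the masking rule assumed to back `self.apply_cautious`:
    # keep components where the update direction agrees with the gradient,
    # zero the rest, and rescale by the fraction of surviving entries.
    mask = (update * grad > 0).to(grad.dtype)  # 1 where signs agree, else 0
    mask.div_(mask.mean().clamp_(min=1e-3))    # renormalize by surviving fraction
    update.mul_(mask)                          # apply the mask in place
```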
@@ -142,6 +151,7 @@ def optimize_1d(
         betas: BETAS,
         step: int,
         ams_bound: bool,
+        cautious: bool,
         eps: float,
     ) -> torch.Tensor:
         beta1, beta2 = betas
@@ -155,6 +165,9 @@ def optimize_1d(
         update = self.apply_ams_bound(ams_bound, exp_avg_sq, max_exp_avg_sq, eps)
         update.div_(bias_correction2_sq).mul_(bias_correction1)

+        if cautious:
+            self.apply_cautious(exp_avg, grad)
+
         return exp_avg.div(update)

     @torch.no_grad()
@@ -207,6 +220,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                         is_grad_2d,
                         group['step'],
                         group['ams_bound'],
+                        group['cautious'],
                         group['eps'],
                     )
                 else:
@@ -218,6 +232,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                         group['betas_1d'],
                         group['step'],
                         group['ams_bound'],
+                        group['cautious'],
                         group['eps'],
                     )

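Since `step` reads `group['cautious']` for every parameter group, the flag can also be overridden per group through the standard `torch.optim` parameter-group mechanism (group entries override the constructor defaults). A hedged sketch:

```python
# Hedged sketch: per-group override of `cautious`; relies only on the standard
# torch.optim parameter-group behavior, not on anything specific to this diff.
import torch
from pytorch_optimizer import MARS  # assumed public import path

model = torch.nn.Linear(16, 4)

optimizer = MARS(
    [
        {'params': model.weight, 'cautious': True},  # 2D weight: cautious masking on
        {'params': model.bias, 'cautious': False},   # 1D bias: plain update
    ],
    lr=3e-3,
)
```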