Commit b4a32b0

update: cautious variant
1 parent 54e43e7 commit b4a32b0

File tree

1 file changed: +18 -3 lines changed

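For context: the diff below threads a new `cautious` flag through MARS and, when enabled, calls the existing `self.apply_cautious(exp_avg, grad)` helper on the first-moment buffer before the update is formed, apparently in the spirit of the "cautious optimizer" masking trick (drop update components whose sign disagrees with the current gradient). A minimal sketch of what such a mask computes, assuming the helper behaves roughly like this; the exact normalization inside `BaseOptimizer.apply_cautious` may differ:

import torch

def cautious_mask_(update: torch.Tensor, grad: torch.Tensor) -> None:
    # Illustrative stand-in for the library's apply_cautious helper, not the actual implementation.
    # Keep only the components of `update` whose sign agrees with `grad`,
    # then rescale so the average step magnitude stays comparable.
    mask = (update * grad > 0).to(grad.dtype)   # 1 where momentum and gradient point the same way
    mask.div_(mask.mean().clamp_(min=1e-3))     # renormalize; clamp guards against an all-zero mask
    update.mul_(mask)                           # mutate the update in place, as the diff's usage suggests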

pytorch_optimizer/optimizer/mars.py

Lines changed: 18 additions & 3 deletions
@@ -17,15 +17,18 @@ class MARS(BaseOptimizer):
     :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
     :param lr: float. learning rate.
     :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
+    :param gamma: float. the scaling parameter that controls the strength of gradient correction.
+    :param mars_type: MARS TYPE. type of MARS. `adamw`, `lion`, `shampoo` are supported.
+    :param optimize_1d: bool. whether MARS should optimize 1D parameters.
+    :param lr_1d: float. learning rate for AdamW when optimize_1d is set to False.
     :param betas_1d: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
         for 1d.
-    :param gamma: float. gamma.
-    :param mars_type: MARS TYPE. type of MARS. `adamw`, `lion`, `shampoo` are supported.
     :param weight_decay: float. weight decay (L2 penalty).
     :param weight_decay_1d: float. weight decay for 1d.
     :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
     :param fixed_decay: bool. fix weight decay.
     :param ams_bound: bool. whether to use the AMSBound variant.
+    :param cautious: bool. whether to use cautious feature.
     :param eps: float. term added to the denominator to improve numerical stability.
     """

@@ -39,11 +42,12 @@ def __init__(
         optimize_1d: bool = False,
         lr_1d: bool = 3e-3,
         betas_1d: BETAS = (0.9, 0.95),
-        weight_decay_1d: float = 1e-1,
         weight_decay: float = 0.0,
+        weight_decay_1d: float = 1e-1,
         weight_decouple: bool = True,
         fixed_decay: bool = False,
         ams_bound: bool = False,
+        cautious: bool = False,
         eps: float = 1e-8,
         **kwargs,
     ):

@@ -70,6 +74,7 @@ def __init__(
             'weight_decouple': weight_decouple,
             'fixed_decay': fixed_decay,
             'ams_bound': ams_bound,
+            'cautious': cautious,
             'eps': eps,
         }

@@ -104,6 +109,7 @@ def optimize_mixed(
         is_grad_2d: bool,
         step: int,
         ams_bound: bool,
+        cautious: bool,
         eps: float,
     ) -> torch.Tensor:
         beta1, beta2 = betas

@@ -115,6 +121,9 @@ def optimize_mixed(

         exp_avg.mul_(beta1).add_(c_t, alpha=1.0 - beta1)

+        if cautious:
+            self.apply_cautious(exp_avg, grad)
+
         if mars_type == 'adamw' or (mars_type == 'shampoo' and not is_grad_2d):
             exp_avg_sq.mul_(beta2).addcmul_(c_t, c_t, value=1.0 - beta2)

@@ -142,6 +151,7 @@ def optimize_1d(
         betas: BETAS,
         step: int,
         ams_bound: bool,
+        cautious: bool,
         eps: float,
     ) -> torch.Tensor:
         beta1, beta2 = betas

@@ -155,6 +165,9 @@ def optimize_1d(
         update = self.apply_ams_bound(ams_bound, exp_avg_sq, max_exp_avg_sq, eps)
         update.div_(bias_correction2_sq).mul_(bias_correction1)

+        if cautious:
+            self.apply_cautious(exp_avg, grad)
+
         return exp_avg.div(update)

     @torch.no_grad()

@@ -207,6 +220,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                         is_grad_2d,
                         group['step'],
                         group['ams_bound'],
+                        group['cautious'],
                         group['eps'],
                     )
                 else:

@@ -218,6 +232,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                         group['betas_1d'],
                         group['step'],
                         group['ams_bound'],
+                        group['cautious'],
                         group['eps'],
                     )

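A short usage sketch, assuming `MARS` is exported at the package top level like the library's other optimizers and takes the constructor arguments shown in the diff above; the model, data, and learning rate are placeholders:

import torch
from pytorch_optimizer import MARS  # assumes MARS is exposed at the package top level

model = torch.nn.Linear(10, 2)                                 # placeholder model
optimizer = MARS(model.parameters(), lr=3e-3, cautious=True)   # enable the new cautious masking

loss = model(torch.randn(4, 10)).sum()                         # dummy forward/backward pass
loss.backward()
optimizer.step()
optimizer.zero_grad()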