Skip to content

Commit 46138fb

Browse files
committed
update: CosineDecay
1 parent 239024e commit 46138fb

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

pytorch_optimizer/optimizer/spam.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ def step(self, current_step: int) -> None:
3131
3232
:param current_step: int. Current step index.
3333
"""
34-
self.cosine_stepper.step(current_step)
34+
self.cosine_stepper.last_epoch = current_step
35+
self.cosine_stepper.step()
3536

3637
def get_death_rate(self, current_step: int) -> float:
3738
r"""Get the updated rate (death_rate) at the given step.
@@ -266,9 +267,9 @@ class StableSPAM(BaseOptimizer):
266267
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
267268
:param lr: float. learning rate.
268269
:param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
269-
:param gamma1: float.
270-
:param gamma2: float.
271-
:param theta: float.
270+
:param gamma1: float. gamma1 parameter.
271+
:param gamma2: float. gamma2 parameter.
272+
:param theta: float. theta parameter.
272273
:param t_max: Optional[int]. total number of steps.
273274
:param eta_min: float. eta_min of CosineDecay.
274275
:param weight_decay: float. weight decay (L2 penalty).

0 commit comments

Comments
 (0)