File tree (Expand file tree / Collapse file tree): 1 file changed, +5 −4 lines changed
pytorch_optimizer/optimizer: 1 file changed, +5 −4 lines changed
Original file line number | Diff line number | Diff line change @@ -31,7 +31,8 @@ def step(self, current_step: int) -> None:
3131
3232 :param current_step: int. Current step index.
3333 """
34- self .cosine_stepper .step (current_step )
34+ self .cosine_stepper .last_epoch = current_step
35+ self .cosine_stepper .step ()
3536
3637 def get_death_rate (self , current_step : int ) -> float :
3738 r"""Get the updated rate (death_rate) at the given step.
@@ -266,9 +267,9 @@ class StableSPAM(BaseOptimizer):
266267 :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
267268 :param lr: float. learning rate.
268269 :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
269- :param gamma1: float.
270- :param gamma2: float.
271- :param theta: float.
270+ :param gamma1: float. smoothing factor for the gradient norm (presumably the AdaGN coefficient — confirm against the StableSPAM paper).
271+ :param gamma2: float. smoothing factor for the maximum gradient norm (presumably used in adaptive gradient normalization — confirm).
272+ :param theta: float. smoothing factor for spike-aware gradient clipping (presumably the AdaClip coefficient — confirm).
272273 :param t_max: Optional[int]. total number of steps.
273274 :param eta_min: float. eta_min of CosineDecay.
274275 :param weight_decay: float. weight decay (L2 penalty).
You can’t perform that action at this time.
0 commit comments