Commit ed97a2f

Merge pull request #158 from kozistr/update/dadaptation
[Update] D-Adaptation v3
2 parents 9110bbd + ddf81a7

File tree

3 files changed: +92 −115 lines

pytorch_optimizer/optimizer/adashift.py

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@ class AdaShift(Optimizer, BaseOptimizer):
     :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
     :param lr: float. learning rate.
     :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
-    :param keep_num: int. number of gradients used to compute first moment estimation.
+    :param keep_num: int. number of gradients used to compute first moment estimation.
     :param reduce_func: Optional[Callable]. function applied to squared gradients to further reduce the correlation.
         If None, no function is applied.
     :param eps: float. term added to the denominator to improve numerical stability.
@@ -69,7 +69,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:

             beta1, beta2 = group['betas']

-            exp_weight_sum: int = sum(beta1**i for i in range(group['keep_num']))
+            exp_weight_sum: int = sum(beta1 ** i for i in range(group['keep_num']))  # fmt: skip
             first_grad_weight: float = beta1 ** (group['keep_num'] - 1) / exp_weight_sum
             last_grad_weight: float = 1.0 / exp_weight_sum
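For context, the line touched in step() builds the geometric normalization term AdaShift uses to weight the keep_num most recent gradients by powers of beta1; first_grad_weight and last_grad_weight are the two boundary weights of that window. Below is a minimal standalone sketch of the same arithmetic, using assumed example values for beta1 and keep_num (not the repository's defaults):

# Sketch of the weight computation shown in the diff above; values are examples only.
beta1 = 0.9     # assumed example decay rate
keep_num = 10   # assumed example window size

# Geometric series 1 + beta1 + ... + beta1**(keep_num - 1) used as the normalizer.
exp_weight_sum = sum(beta1 ** i for i in range(keep_num))

# Boundary weights: one end of the window is scaled by beta1**(keep_num - 1),
# the other by 1.0, both divided by the same normalizer so the full set of
# window weights sums to 1.
first_grad_weight = beta1 ** (keep_num - 1) / exp_weight_sum
last_grad_weight = 1.0 / exp_weight_sum

print(round(exp_weight_sum, 4), round(first_grad_weight, 4), round(last_grad_weight, 4))
# 6.5132 0.0595 0.1535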

0 commit comments
