Skip to content

Commit c394b21

Browse files
committed
fix: second_step
1 parent 07fc0a1 commit c394b21

File tree

1 file changed

+9
-10
lines changed
  • pytorch_optimizer/optimizer

1 file changed

+9
-10
lines changed

pytorch_optimizer/optimizer/sam.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99

1010
class SAM(Optimizer, BaseOptimizer):
11-
"""
12-
Reference : https://github.com/davda54/sam
11+
r"""Sharpness-Aware Minimization for Efficiently Improving Generalization
12+
1313
Example :
1414
from pytorch_optimizer import SAM
1515
...
@@ -48,6 +48,12 @@ def closure():
4848
loss.backward()
4949
optimizer.step(closure)
5050
optimizer.zero_grad()
51+
52+
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
53+
:param base_optimizer: Optimizer. base optimizer
54+
:param rho: float. size of the neighborhood for computing the max loss
55+
:param adaptive: bool. element-wise Adaptive SAM
56+
:param kwargs: Dict. parameters for optimizer.
5157
"""
5258

5359
def __init__(
@@ -58,13 +64,6 @@ def __init__(
5864
adaptive: bool = False,
5965
**kwargs,
6066
):
61-
"""SAM
62-
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
63-
:param base_optimizer: Optimizer. base optimizer
64-
:param rho: float. size of the neighborhood for computing the max loss
65-
:param adaptive: bool. element-wise Adaptive SAM
66-
:param kwargs: Dict. parameters for optimizer.
67-
"""
6867
self.rho = rho
6968

7069
self.validate_parameters()
@@ -109,7 +108,7 @@ def second_step(self, zero_grad: bool = False):
109108
continue
110109

111110
# get back to "w" from "w + e(w)"
112-
p = self.state[p]['old_p']
111+
p.data = self.state[p]['old_p']
113112

114113
# do the actual "sharpness-aware" update
115114
self.base_optimizer.step()

0 commit comments

Comments
 (0)