 class SAM(Optimizer, BaseOptimizer):
-    """
-    Reference : https://github.com/davda54/sam
+    r"""Sharpness-Aware Minimization for Efficiently Improving Generalization
+
     Example :
         from pytorch_optimizer import SAM
         ...
@@ -48,6 +48,12 @@ def closure():
             loss.backward()
             optimizer.step(closure)
             optimizer.zero_grad()
+
+    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
+    :param base_optimizer: Optimizer. base optimizer
+    :param rho: float. size of the neighborhood for computing the max loss
+    :param adaptive: bool. element-wise Adaptive SAM
+    :param kwargs: Dict. parameters for optimizer.
     """

     def __init__(
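For context on the closure-based docstring example in the hunk above, here is a minimal, self-contained sketch of how SAM is typically driven: the base optimizer is passed as a class, and keyword arguments such as lr are forwarded to it via **kwargs. The toy model, loss, and data below are placeholders of my own, not part of the repository.

    import torch
    from torch import nn
    from pytorch_optimizer import SAM

    # toy model/data purely for illustration
    model = nn.Linear(10, 1)
    loss_function = nn.MSELoss()
    data = [(torch.randn(16, 10), torch.randn(16, 1)) for _ in range(4)]

    # base_optimizer is a class; lr is forwarded to it through **kwargs
    optimizer = SAM(model.parameters(), base_optimizer=torch.optim.SGD, lr=0.1, rho=0.05)

    for inputs, targets in data:
        def closure():
            # re-evaluates loss and gradients at the perturbed weights "w + e(w)"
            loss = loss_function(model(inputs), targets)
            loss.backward()
            return loss

        # first forward-backward pass at the current weights "w"
        loss = loss_function(model(inputs), targets)
        loss.backward()
        optimizer.step(closure)
        optimizer.zero_grad()

Per the docstring example and the second_step hunk below, the intent is that step(closure) perturbs the weights, lets the closure recompute gradients at the perturbed point, and then restores the weights before the base optimizer update.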
@@ -58,13 +64,6 @@ def __init__(
         adaptive: bool = False,
         **kwargs,
     ):
-        """SAM
-        :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
-        :param base_optimizer: Optimizer. base optimizer
-        :param rho: float. size of the neighborhood for computing the max loss
-        :param adaptive: bool. element-wise Adaptive SAM
-        :param kwargs: Dict. parameters for optimizer.
-        """
         self.rho = rho

         self.validate_parameters()
@@ -109,7 +108,7 @@ def second_step(self, zero_grad: bool = False):
                     continue

                 # get back to "w" from "w + e(w)"
-                p = self.state[p]['old_p']
+                p.data = self.state[p]['old_p']

         # do the actual "sharpness-aware" update
         self.base_optimizer.step()
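The one-line fix in this hunk is the substantive change: in the old code, `p = self.state[p]['old_p']` only rebinds the loop variable, so the parameters are never actually restored from "w + e(w)" before `base_optimizer.step()` runs. Assigning to `p.data` replaces the parameter's underlying tensor with the saved weights. A small standalone illustration of the difference, not taken from the repository:

    import torch

    p = torch.nn.Parameter(torch.tensor([1.0, 2.0]))
    old_p = p.data.clone()   # saved copy of "w"
    p.data += 0.5            # perturb to "w + e(w)"

    q = p                    # stand-in for the loop variable over group['params']
    q = old_p                # rebinds the name only; the parameter is unchanged
    print(p.data)            # tensor([1.5000, 2.5000]) -- still perturbed

    q = p
    q.data = old_p           # restores the parameter the base optimizer will update
    print(p.data)            # tensor([1., 2.])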