class SAM(Optimizer):
    """
    Reference: https://github.com/davda54/sam
    Example:
        from pytorch_optimizer import SAM
        ...
        model = YourModel()
        base_optimizer = Ranger21
        optimizer = SAM(model.parameters(), base_optimizer)
        ...
        for input, output in data:
            # first forward-backward pass
            loss = loss_function(output, model(input))  # use this loss for any training statistics
            loss.backward()
            optimizer.first_step(zero_grad=True)

            # second forward-backward pass
            loss_function(output, model(input)).backward()  # make sure to do a full forward pass
            optimizer.second_step(zero_grad=True)

    Alternative Example with a single closure-based step function:
        from pytorch_optimizer import SAM
        ...
        model = YourModel()
        base_optimizer = Ranger21
        optimizer = SAM(model.parameters(), base_optimizer)

        def closure():
            loss = loss_function(output, model(input))
            loss.backward()
            return loss
        ...

        for input, output in data:
            loss = loss_function(output, model(input))
            loss.backward()
            optimizer.step(closure)
            optimizer.zero_grad()
    """

    def __init__(
        self,
        params: PARAMS,
        base_optimizer,
        rho: float = 0.05,
        adaptive: bool = False,
        **kwargs,
    ):
        """(Adaptive) Sharpness-Aware Minimization
        :param params: PARAMS. iterable of parameters to optimize or dicts defining parameter groups
        :param base_optimizer: Optimizer. base optimizer to wrap with SAM (e.g. Ranger21 in the example above)
        :param rho: float. size of the neighborhood for computing the max loss
        :param adaptive: bool. element-wise Adaptive SAM
        :param kwargs: Dict. parameters for the base optimizer
        """
        self.rho = rho

        self.check_valid_parameters()
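
        # Illustrative sketch (not part of the original source): the "first step" used in the
        # class docstring above can be read as perturbing each parameter by
        # e(w) = rho * grad / ||grad||, scaled element-wise by w^2 when `adaptive` is True,
        # before the base optimizer takes its step. Kept as a comment only; `params`, `rho`
        # and `adaptive` refer to the constructor arguments above, and `torch` is assumed
        # to be imported in this module.
        #
        #     grad_norm = torch.norm(
        #         torch.stack([
        #             ((torch.abs(p) if adaptive else 1.0) * p.grad).norm(p=2)
        #             for p in params
        #             if p.grad is not None
        #         ]),
        #         p=2,
        #     )
        #     for p in params:
        #         if p.grad is None:
        #             continue
        #         e_w = (torch.pow(p, 2) if adaptive else 1.0) * p.grad * rho / (grad_norm + 1e-12)
        #         p.add_(e_w)  # climb towards the locally sharpest point w + e(w)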