Skip to content

Commit d09e2fc

Browse files
committed
[skip ci] docs: SAM optimizer
1 parent b4ec632 commit d09e2fc

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

pytorch_optimizer/sam.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,45 @@
1212

1313

1414
class SAM(Optimizer):
15+
"""
16+
Reference : https://github.com/davda54/sam
17+
Example :
18+
from pytorch_optimizer import SAM
19+
...
20+
model = YourModel()
21+
base_optimizer = Ranger21
22+
optimizer = SAM(model.parameters(), base_optimizer)
23+
...
24+
for input, output in data:
25+
# first forward-backward pass
26+
loss = loss_function(output, model(input)) # use this loss for any training statistics
27+
loss.backward()
28+
optimizer.first_step(zero_grad=True)
29+
30+
# second forward-backward pass
31+
loss_function(output, model(input)).backward() # make sure to do a full forward pass
32+
optimizer.second_step(zero_grad=True)
33+
34+
Alternative Example with a single closure-based step function:
35+
from pytorch_optimizer import SAM
36+
...
37+
model = YourModel()
38+
base_optimizer = Ranger21
39+
optimizer = SAM(model.parameters(), base_optimizer)
40+
41+
def closure():
42+
loss = loss_function(output, model(input))
43+
loss.backward()
44+
return loss
45+
...
46+
47+
for input, output in data:
48+
loss = loss_function(output, model(input))
49+
loss.backward()
50+
optimizer.step(closure)
51+
optimizer.zero_grad()
52+
"""
53+
1554
def __init__(
1655
self,
1756
params: PARAMS,
@@ -20,6 +59,13 @@ def __init__(
2059
adaptive: bool = False,
2160
**kwargs,
2261
):
62+
"""(Adaptive) Sharpness-Aware Minimization
63+
:param params: PARAMS. iterable of parameters to optimize or dicts defining parameter groups
64+
:param base_optimizer:
65+
:param rho: float. size of the neighborhood for computing the max loss
66+
:param adaptive: bool. element-wise Adaptive SAM
67+
:param kwargs: Dict. parameters for optimizer.
68+
"""
2369
self.rho = rho
2470

2571
self.check_valid_parameters()

0 commit comments

Comments
 (0)