Skip to content

Commit aa13de3

Browse files
committed
update: name
1 parent 3178f10 commit aa13de3

File tree

1 file changed

+6
-1
lines changed
  • pytorch_optimizer/optimizer

1 file changed

+6
-1
lines changed

pytorch_optimizer/optimizer/sam.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from torch.optim.optimizer import Optimizer
55

66
from pytorch_optimizer.base.base_optimizer import BaseOptimizer
7+
from pytorch_optimizer.base.exception import ClosureError
78
from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, OPTIMIZER, PARAMETERS
89

910

@@ -77,6 +78,10 @@ def __init__(
7778
def validate_parameters(self):
7879
self.validate_rho(self.rho)
7980

81+
@property
82+
def __name__(self) -> str:
83+
return 'SAM'
84+
8085
@torch.no_grad()
8186
def reset(self):
8287
pass
@@ -119,7 +124,7 @@ def second_step(self, zero_grad: bool = False):
119124
@torch.no_grad()
120125
def step(self, closure: CLOSURE = None):
121126
if closure is None:
122-
raise RuntimeError('[-] Sharpness Aware Minimization (SAM) requires closure')
127+
raise ClosureError(self.__name__)
123128

124129
self.first_step(zero_grad=True)
125130

0 commit comments

Comments
 (0)