Skip to content

Commit 057c28d

Browse files
committed
update: SAM optimizer
1 parent 13d0ddf commit 057c28d

File tree

1 file changed

+47
-63
lines changed
  • pytorch_optimizer/optimizer

1 file changed

+47
-63
lines changed

pytorch_optimizer/optimizer/sam.py

Lines changed: 47 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,21 @@
1515
from pytorch_optimizer.optimizer.utils import disable_running_stats, enable_running_stats
1616

1717

18+
def get_global_gradient_norm(param_groups: 'PARAMETERS', device: torch.device) -> torch.Tensor:
    r"""Get global gradient norm.

    Compute the L2 norm over the gradients of every parameter in ``param_groups``.
    When a group sets ``adaptive=True`` (Adaptive SAM), each gradient is scaled
    element-wise by ``|p|`` before its norm is taken. Parameters without a
    gradient are skipped.

    :param param_groups: PARAMETERS. iterable of parameter groups; each group is a dict
        holding ``'params'`` and an ``'adaptive'`` flag.
    :param device: torch.device. device onto which per-parameter norms are gathered
        before the final reduction (keeps multi-device models consistent).
    :return: torch.Tensor. scalar global L2 gradient norm; ``0.0`` if no parameter
        has a gradient.
    """
    per_param_norms = [
        ((torch.abs(p) if group['adaptive'] else 1.0) * p.grad).norm(p=2).to(device)
        for group in param_groups
        for p in group['params']
        if p.grad is not None
    ]

    # torch.stack raises on an empty list; treat "no gradients yet" as a zero norm
    # so callers (e.g. first_step before any backward) fail gracefully.
    if not per_param_norms:
        return torch.tensor(0.0, device=device)

    return torch.norm(torch.stack(per_param_norms), p=2)
31+
32+
1833
class SAM(BaseOptimizer):
1934
r"""Sharpness-Aware Minimization for Efficiently Improving Generalization.
2035
@@ -80,8 +95,8 @@ def __init__(
8095
self.use_gc = use_gc
8196
self.perturb_eps = perturb_eps
8297

83-
defaults: DEFAULTS = {'rho': rho, 'adaptive': adaptive}
84-
defaults.update(kwargs)
98+
defaults: DEFAULTS = {'rho': rho, 'adaptive': adaptive, **kwargs}
99+
85100
super().__init__(params, defaults)
86101

87102
self.base_optimizer: Optimizer = base_optimizer(self.param_groups, **kwargs)
@@ -90,13 +105,15 @@ def __init__(
90105
def __str__(self) -> str:
91106
return 'SAM'
92107

93-
@torch.no_grad()
94-
def init_group(self):
108+
def init_group(self, group: GROUP, **kwargs) -> None:
95109
pass
96110

97111
@torch.no_grad()
98112
def first_step(self, zero_grad: bool = False):
99-
grad_norm = self.grad_norm().add_(self.perturb_eps)
113+
device = self.param_groups[0]['params'][0].device
114+
115+
grad_norm = get_global_gradient_norm(self.param_groups, device).add_(self.perturb_eps)
116+
100117
for group in self.param_groups:
101118
scale = group['rho'] / grad_norm
102119

@@ -109,6 +126,7 @@ def first_step(self, zero_grad: bool = False):
109126
centralize_gradient(grad, gc_conv_only=False)
110127

111128
self.state[p]['old_p'] = p.clone()
129+
112130
e_w = (torch.pow(p, 2) if group['adaptive'] else 1.0) * grad * scale.to(p)
113131

114132
p.add_(e_w)
@@ -142,20 +160,6 @@ def step(self, closure: CLOSURE = None):
142160

143161
self.second_step()
144162

145-
def grad_norm(self) -> torch.Tensor:
146-
shared_device = self.param_groups[0]['params'][0].device
147-
return torch.norm(
148-
torch.stack(
149-
[
150-
((torch.abs(p) if group['adaptive'] else 1.0) * p.grad).norm(p=2).to(shared_device)
151-
for group in self.param_groups
152-
for p in group['params']
153-
if p.grad is not None
154-
]
155-
),
156-
p=2,
157-
)
158-
159163
def load_state_dict(self, state_dict: Dict):
160164
super().load_state_dict(state_dict)
161165
self.base_optimizer.param_groups = self.param_groups
@@ -218,24 +222,23 @@ def __init__(
218222
if hasattr(ReduceOp, 'AVG'):
219223
self.grad_reduce = ReduceOp.AVG
220224
self.manual_average: bool = False
221-
else: # PyTorch <= 1.11.0 does not have AVG, need to manually average across processes
225+
else:
222226
self.grad_reduce = ReduceOp.SUM
223227
self.manual_average: bool = True
224228

225229
self.base_optimizer = base_optimizer
226230
self.param_groups = self.base_optimizer.param_groups
227231

228-
defaults: DEFAULTS = {'adaptive': adaptive}
229-
defaults.update(kwargs)
232+
defaults: DEFAULTS = {'adaptive': adaptive, **kwargs}
233+
230234
super().__init__(params, defaults)
231235

232236
self.update_rho_t()
233237

234238
def __str__(self) -> str:
235239
return 'GSAM'
236240

237-
@torch.no_grad()
238-
def init_group(self):
241+
def init_group(self, group: GROUP, **kwargs) -> None:
239242
pass
240243

241244
@torch.no_grad()
@@ -414,8 +417,7 @@ def __init__(
414417

415418
alpha: float = gamma / (1.0 - gamma)
416419

417-
defaults: DEFAULTS = {'rho': rho, 'alpha': alpha, 'adaptive': adaptive, 'sam_eps': eps}
418-
defaults.update(kwargs)
420+
defaults: DEFAULTS = {'rho': rho, 'alpha': alpha, 'adaptive': adaptive, 'sam_eps': eps, **kwargs}
419421

420422
super().__init__(params, defaults)
421423

@@ -425,13 +427,15 @@ def __init__(
425427
def __str__(self) -> str:
426428
return 'WSAM'
427429

428-
@torch.no_grad()
429-
def init_group(self):
430+
def init_group(self, group: GROUP, **kwargs) -> None:
430431
pass
431432

432433
@torch.no_grad()
433434
def first_step(self, zero_grad: bool = False):
434-
grad_norm = self.grad_norm()
435+
device = self.param_groups[0]['params'][0].device
436+
437+
grad_norm = get_global_gradient_norm(self.param_groups, device)
438+
435439
for group in self.param_groups:
436440
scale = group['rho'] / (grad_norm + group['sam_eps'])
437441

@@ -516,21 +520,6 @@ def step(self, closure: CLOSURE = None):
516520

517521
return loss
518522

519-
def grad_norm(self) -> torch.Tensor:
520-
shared_device = self.param_groups[0]['params'][0].device
521-
522-
return torch.norm(
523-
torch.stack(
524-
[
525-
((torch.abs(p) if group['adaptive'] else 1.0) * p.grad).norm(p=2).to(shared_device)
526-
for group in self.param_groups
527-
for p in group['params']
528-
if p.grad is not None
529-
]
530-
),
531-
p=2,
532-
)
533-
534523
def load_state_dict(self, state_dict: Dict):
535524
super().load_state_dict(state_dict)
536525
self.base_optimizer.param_groups = self.param_groups
@@ -591,8 +580,14 @@ def __init__(
591580
self.num_data = num_data
592581
self.damping = damping
593582

594-
defaults: DEFAULTS = {'lr': lr, 'betas': betas, 'weight_decay': weight_decay, 'rho': rho, 'adaptive': adaptive}
595-
defaults.update(kwargs)
583+
defaults: DEFAULTS = {
584+
'lr': lr,
585+
'betas': betas,
586+
'weight_decay': weight_decay,
587+
'rho': rho,
588+
'adaptive': adaptive,
589+
**kwargs,
590+
}
596591

597592
super().__init__(params, defaults)
598593

@@ -768,8 +763,7 @@ def __init__(
768763
def __str__(self) -> str:
769764
return 'LookSAM'
770765

771-
@torch.no_grad()
772-
def init_group(self):
766+
def init_group(self, group: GROUP, **kwargs) -> None:
773767
pass
774768

775769
def get_step(self):
@@ -784,7 +778,10 @@ def first_step(self, zero_grad: bool = False) -> None:
784778
if self.get_step() % self.k != 0:
785779
return
786780

787-
grad_norm = self.grad_norm().add_(self.perturb_eps)
781+
device = self.param_groups[0]['params'][0].device
782+
783+
grad_norm = get_global_gradient_norm(self.param_groups, device).add_(self.perturb_eps)
784+
788785
for group in self.param_groups:
789786
scale = group['rho'] / grad_norm
790787

@@ -800,6 +797,7 @@ def first_step(self, zero_grad: bool = False) -> None:
800797
self.state[f'old_grad_p_{i}']['old_grad_p'] = grad.clone()
801798

802799
e_w = (torch.pow(p, 2) if group['adaptive'] else 1.0) * grad * scale.to(p)
800+
803801
p.add_(e_w)
804802

805803
if zero_grad:
@@ -849,20 +847,6 @@ def step(self, closure: CLOSURE = None):
849847

850848
self.second_step()
851849

852-
def grad_norm(self) -> torch.Tensor:
853-
shared_device = self.param_groups[0]['params'][0].device
854-
return torch.norm(
855-
torch.stack(
856-
[
857-
((torch.abs(p) if group['adaptive'] else 1.0) * p.grad).norm(p=2).to(shared_device)
858-
for group in self.param_groups
859-
for p in group['params']
860-
if p.grad is not None
861-
]
862-
),
863-
p=2,
864-
)
865-
866850
def load_state_dict(self, state_dict: Dict):
867851
super().load_state_dict(state_dict)
868852
self.base_optimizer.param_groups = self.param_groups

0 commit comments

Comments
 (0)