1515from pytorch_optimizer .optimizer .utils import disable_running_stats , enable_running_stats
1616
1717
def get_global_gradient_norm(param_groups: 'PARAMETERS', device: torch.device) -> torch.Tensor:
    r"""Get the global L2 gradient norm across every parameter group.

    Each parameter's gradient norm is computed individually (scaled by ``|p|``
    when its group is flagged ``adaptive``, as in Adaptive SAM), moved onto
    ``device``, stacked, and reduced with a final L2 norm into one scalar
    tensor. Parameters without a gradient are skipped.

    :param param_groups: PARAMETERS. optimizer parameter groups.
    :param device: torch.device. device on which the norms are accumulated.
    """
    return torch.norm(
        torch.stack(
            [
                ((torch.abs(p) if group['adaptive'] else 1.0) * p.grad).norm(p=2).to(device)
                for group in param_groups
                for p in group['params']
                if p.grad is not None
            ]
        ),
        p=2,
    )
31+
32+
1833class SAM (BaseOptimizer ):
1934 r"""Sharpness-Aware Minimization for Efficiently Improving Generalization.
2035
@@ -80,8 +95,8 @@ def __init__(
8095 self .use_gc = use_gc
8196 self .perturb_eps = perturb_eps
8297
83- defaults : DEFAULTS = {'rho' : rho , 'adaptive' : adaptive }
84- defaults . update ( kwargs )
98+ defaults : DEFAULTS = {'rho' : rho , 'adaptive' : adaptive , ** kwargs }
99+
85100 super ().__init__ (params , defaults )
86101
87102 self .base_optimizer : Optimizer = base_optimizer (self .param_groups , ** kwargs )
@@ -90,13 +105,15 @@ def __init__(
90105 def __str__ (self ) -> str :
91106 return 'SAM'
92107
93- @torch .no_grad ()
94- def init_group (self ):
108+ def init_group (self , group : GROUP , ** kwargs ) -> None :
95109 pass
96110
97111 @torch .no_grad ()
98112 def first_step (self , zero_grad : bool = False ):
99- grad_norm = self .grad_norm ().add_ (self .perturb_eps )
113+ device = self .param_groups [0 ]['params' ][0 ].device
114+
115+ grad_norm = get_global_gradient_norm (self .param_groups , device ).add_ (self .perturb_eps )
116+
100117 for group in self .param_groups :
101118 scale = group ['rho' ] / grad_norm
102119
@@ -109,6 +126,7 @@ def first_step(self, zero_grad: bool = False):
109126 centralize_gradient (grad , gc_conv_only = False )
110127
111128 self .state [p ]['old_p' ] = p .clone ()
129+
112130 e_w = (torch .pow (p , 2 ) if group ['adaptive' ] else 1.0 ) * grad * scale .to (p )
113131
114132 p .add_ (e_w )
@@ -142,20 +160,6 @@ def step(self, closure: CLOSURE = None):
142160
143161 self .second_step ()
144162
145- def grad_norm (self ) -> torch .Tensor :
146- shared_device = self .param_groups [0 ]['params' ][0 ].device
147- return torch .norm (
148- torch .stack (
149- [
150- ((torch .abs (p ) if group ['adaptive' ] else 1.0 ) * p .grad ).norm (p = 2 ).to (shared_device )
151- for group in self .param_groups
152- for p in group ['params' ]
153- if p .grad is not None
154- ]
155- ),
156- p = 2 ,
157- )
158-
159163 def load_state_dict (self , state_dict : Dict ):
160164 super ().load_state_dict (state_dict )
161165 self .base_optimizer .param_groups = self .param_groups
@@ -218,24 +222,23 @@ def __init__(
218222 if hasattr (ReduceOp , 'AVG' ):
219223 self .grad_reduce = ReduceOp .AVG
220224 self .manual_average : bool = False
221- else : # PyTorch <= 1.11.0 does not have AVG, need to manually average across processes
225+ else :
222226 self .grad_reduce = ReduceOp .SUM
223227 self .manual_average : bool = True
224228
225229 self .base_optimizer = base_optimizer
226230 self .param_groups = self .base_optimizer .param_groups
227231
228- defaults : DEFAULTS = {'adaptive' : adaptive }
229- defaults . update ( kwargs )
232+ defaults : DEFAULTS = {'adaptive' : adaptive , ** kwargs }
233+
230234 super ().__init__ (params , defaults )
231235
232236 self .update_rho_t ()
233237
234238 def __str__ (self ) -> str :
235239 return 'GSAM'
236240
237- @torch .no_grad ()
238- def init_group (self ):
241+ def init_group (self , group : GROUP , ** kwargs ) -> None :
239242 pass
240243
241244 @torch .no_grad ()
@@ -414,8 +417,7 @@ def __init__(
414417
415418 alpha : float = gamma / (1.0 - gamma )
416419
417- defaults : DEFAULTS = {'rho' : rho , 'alpha' : alpha , 'adaptive' : adaptive , 'sam_eps' : eps }
418- defaults .update (kwargs )
420+ defaults : DEFAULTS = {'rho' : rho , 'alpha' : alpha , 'adaptive' : adaptive , 'sam_eps' : eps , ** kwargs }
419421
420422 super ().__init__ (params , defaults )
421423
@@ -425,13 +427,15 @@ def __init__(
425427 def __str__ (self ) -> str :
426428 return 'WSAM'
427429
428- @torch .no_grad ()
429- def init_group (self ):
430+ def init_group (self , group : GROUP , ** kwargs ) -> None :
430431 pass
431432
432433 @torch .no_grad ()
433434 def first_step (self , zero_grad : bool = False ):
434- grad_norm = self .grad_norm ()
435+ device = self .param_groups [0 ]['params' ][0 ].device
436+
437+ grad_norm = get_global_gradient_norm (self .param_groups , device )
438+
435439 for group in self .param_groups :
436440 scale = group ['rho' ] / (grad_norm + group ['sam_eps' ])
437441
@@ -516,21 +520,6 @@ def step(self, closure: CLOSURE = None):
516520
517521 return loss
518522
519- def grad_norm (self ) -> torch .Tensor :
520- shared_device = self .param_groups [0 ]['params' ][0 ].device
521-
522- return torch .norm (
523- torch .stack (
524- [
525- ((torch .abs (p ) if group ['adaptive' ] else 1.0 ) * p .grad ).norm (p = 2 ).to (shared_device )
526- for group in self .param_groups
527- for p in group ['params' ]
528- if p .grad is not None
529- ]
530- ),
531- p = 2 ,
532- )
533-
534523 def load_state_dict (self , state_dict : Dict ):
535524 super ().load_state_dict (state_dict )
536525 self .base_optimizer .param_groups = self .param_groups
@@ -591,8 +580,14 @@ def __init__(
591580 self .num_data = num_data
592581 self .damping = damping
593582
594- defaults : DEFAULTS = {'lr' : lr , 'betas' : betas , 'weight_decay' : weight_decay , 'rho' : rho , 'adaptive' : adaptive }
595- defaults .update (kwargs )
583+ defaults : DEFAULTS = {
584+ 'lr' : lr ,
585+ 'betas' : betas ,
586+ 'weight_decay' : weight_decay ,
587+ 'rho' : rho ,
588+ 'adaptive' : adaptive ,
589+ ** kwargs ,
590+ }
596591
597592 super ().__init__ (params , defaults )
598593
@@ -768,8 +763,7 @@ def __init__(
768763 def __str__ (self ) -> str :
769764 return 'LookSAM'
770765
771- @torch .no_grad ()
772- def init_group (self ):
766+ def init_group (self , group : GROUP , ** kwargs ) -> None :
773767 pass
774768
775769 def get_step (self ):
@@ -784,7 +778,10 @@ def first_step(self, zero_grad: bool = False) -> None:
784778 if self .get_step () % self .k != 0 :
785779 return
786780
787- grad_norm = self .grad_norm ().add_ (self .perturb_eps )
781+ device = self .param_groups [0 ]['params' ][0 ].device
782+
783+ grad_norm = get_global_gradient_norm (self .param_groups , device ).add_ (self .perturb_eps )
784+
788785 for group in self .param_groups :
789786 scale = group ['rho' ] / grad_norm
790787
@@ -800,6 +797,7 @@ def first_step(self, zero_grad: bool = False) -> None:
800797 self .state [f'old_grad_p_{ i } ' ]['old_grad_p' ] = grad .clone ()
801798
802799 e_w = (torch .pow (p , 2 ) if group ['adaptive' ] else 1.0 ) * grad * scale .to (p )
800+
803801 p .add_ (e_w )
804802
805803 if zero_grad :
@@ -849,20 +847,6 @@ def step(self, closure: CLOSURE = None):
849847
850848 self .second_step ()
851849
852- def grad_norm (self ) -> torch .Tensor :
853- shared_device = self .param_groups [0 ]['params' ][0 ].device
854- return torch .norm (
855- torch .stack (
856- [
857- ((torch .abs (p ) if group ['adaptive' ] else 1.0 ) * p .grad ).norm (p = 2 ).to (shared_device )
858- for group in self .param_groups
859- for p in group ['params' ]
860- if p .grad is not None
861- ]
862- ),
863- p = 2 ,
864- )
865-
866850 def load_state_dict (self , state_dict : Dict ):
867851 super ().load_state_dict (state_dict )
868852 self .base_optimizer .param_groups = self .param_groups
0 commit comments