 from torch.optim.optimizer import Optimizer, required
 
 
-class SGD_MC(Optimizer):
-    r"""Implements stochastic gradient descent (optionally with momentum).
+class SgdMaxChange(Optimizer):
+    r"""Implements stochastic gradient descent (optionally with momentum and max
+    change).
     Nesterov momentum is based on the formula from
     `On the importance of initialization and momentum in deep learning`__.
     Args:
@@ -14,6 +15,10 @@ class SGD_MC(Optimizer):
         weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
         dampening (float, optional): dampening for momentum (default: 0)
         nesterov (bool, optional): enables Nesterov momentum (default: False)
+        max_change_per_layer (float, optional): maximum change allowed in the
+            parameters of any given layer, on any given batch, measured in
+            l2 norm
+        max_change (float, optional): maximum change allowed in the parameters
+            of the whole model on any given batch, after the per-layer
+            constraint has been applied
     Example:
         >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
         >>> optimizer.zero_grad()
@@ -49,17 +54,21 @@ def __init__(self, params, lr=required, momentum=0, dampening=0,
             raise ValueError("Invalid momentum value: {}".format(momentum))
         if weight_decay < 0.0:
             raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+        if max_change_per_layer < 0.01:
+            raise ValueError("Invalid max_change_per_layer value: {}".format(max_change_per_layer))
+        if max_change < 0.01:
+            raise ValueError("Invalid max_change value: {}".format(max_change))
 
         defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                         weight_decay=weight_decay, nesterov=nesterov,
                         max_change_per_layer=max_change_per_layer,
                         max_change=max_change)
         if nesterov and (momentum <= 0 or dampening != 0):
             raise ValueError("Nesterov momentum requires a momentum and zero dampening")
-        super(SGD_MC, self).__init__(params, defaults)
+        super(SgdMaxChange, self).__init__(params, defaults)
 
     def __setstate__(self, state):
-        super(SGD_MC, self).__setstate__(state)
+        super(SgdMaxChange, self).__setstate__(state)
         for group in self.param_groups:
             group.setdefault('nesterov', False)
 
@@ -107,7 +116,7 @@ def step(self, closure=None):
                         d_p = buf
                 norm = d_p.norm(2).item()
                 if norm * group['lr'] > max_change_per_layer:
-                    d_p.mul_(max_change_per_layer / norm)
+                    d_p.mul_(max_change_per_layer / (norm * group['lr']))
                 delta.append(d_p)
                 total_norm += d_p.norm(2).item() ** 2.
 
@@ -118,7 +127,7 @@ def step(self, closure=None):
                 if p.grad is None:
                     continue
                 if total_norm * group['lr'] > max_change:
-                    p.add_(delta[i], alpha=-group['lr'] * max_change / total_norm)
+                    p.add_(delta[i], alpha=-group['lr'] * max_change / (total_norm * group['lr']))
                 else:
                     p.add_(delta[i], alpha=-group['lr'])
 
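
For reference, the two fixed expressions implement a two-level max-change rule. Before the fix, the per-layer update was divided by norm alone, so the change actually applied, lr * norm, could still exceed the cap; dividing by norm * group['lr'] pins the applied change at exactly max_change_per_layer when clipping triggers, and the global branch caps the whole step at max_change the same way. Below is a minimal standalone sketch of that rule, with a hypothetical helper name and without the momentum machinery; the square root that turns total_norm from a sum of squares into an l2 norm is assumed, since that line falls outside the hunks shown.

import torch

def sgd_step_with_max_change(params, lr, max_change_per_layer, max_change):
    """Hypothetical standalone helper illustrating the two-level constraint;
    the real code lives inside SgdMaxChange.step() alongside momentum etc."""
    params_with_grad, delta = [], []
    total_norm = 0.
    for p in params:
        if p.grad is None:
            continue
        d_p = p.grad.clone()
        norm = d_p.norm(2).item()
        # Per-layer cap: the change applied to this layer is lr * ||d_p||,
        # so dividing by (norm * lr) pins that product at
        # max_change_per_layer whenever the cap would be exceeded.
        if norm * lr > max_change_per_layer:
            d_p.mul_(max_change_per_layer / (norm * lr))
        params_with_grad.append(p)
        delta.append(d_p)
        total_norm += d_p.norm(2).item() ** 2.
    # Assumed (outside the hunks shown): total_norm becomes the l2 norm of
    # the whole update, i.e. the square root of the accumulated squares.
    total_norm = total_norm ** 0.5
    with torch.no_grad():
        for i, p in enumerate(params_with_grad):
            if total_norm * lr > max_change:
                # alpha = -lr * max_change / (total_norm * lr), which
                # simplifies to -max_change / total_norm and caps the
                # global step norm at max_change.
                p.add_(delta[i], alpha=-lr * max_change / (total_norm * lr))
            else:
                p.add_(delta[i], alpha=-lr)

Note that -lr * max_change / (total_norm * lr) simplifies to -max_change / total_norm; the diff keeps the unsimplified form, presumably to mirror the per-layer branch.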
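A usage sketch in the style of the docstring's Example block (the module path and the 0.75/1.5 values are illustrative assumptions, not defaults confirmed by the diff):

>>> # assuming the file is importable as sgd_max_change
>>> from sgd_max_change import SgdMaxChange
>>> optimizer = SgdMaxChange(model.parameters(), lr=0.1, momentum=0.9,
...                          max_change_per_layer=0.75, max_change=1.5)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()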