@@ -34,17 +34,17 @@ def __init__(
         average_conv_kernel: bool = False,
         adamd_debias_term: bool = False,
         eps: float = 1e-8,
-        seed: int = 2147483647,
+        seed: int = 1337,
     ):
-        """
+        """AdaHessian
         :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
-        :param lr: float. learning rate.
+        :param lr: float. learning rate
         :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
         :param weight_decay: float. weight decay (L2 penalty)
         :param hessian_power: float. exponent of the hessian trace
         :param update_each: int. compute the hessian trace approximation only after *this* number of steps
         :param num_samples: int. how many times to sample `z` for the approximation of the hessian trace
-        :param average_conv_kernel: bool. average out the hessian traces of convolutional kernels as in the paper.
+        :param average_conv_kernel: bool. average out the hessian traces of convolutional kernels as in the paper
         :param adamd_debias_term: bool. Only correct the denominator to avoid inflating step sizes early in training
         :param eps: float. term added to the denominator to improve numerical stability
         :param seed: int.
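For context, a minimal usage sketch of the optimizer this constructor belongs to (the model, batch, and hyper-parameter values below are placeholders, not part of the commit). Because `set_hessian` differentiates through the gradients, the backward pass has to keep the graph alive with `create_graph=True`:

```python
import torch

# toy model and batch, for illustration only
model = torch.nn.Linear(10, 1)
optimizer = AdaHessian(model.parameters(), lr=1e-1, update_each=1, num_samples=1)

x, y = torch.randn(32, 10), torch.randn(32, 1)
loss = torch.nn.functional.mse_loss(model(x), y)

loss.backward(create_graph=True)  # keep the graph so Hessian-vector products can be taken in step()
optimizer.step()
optimizer.zero_grad()
```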
@@ -103,16 +103,17 @@ def zero_hessian(self):
             if not isinstance(p.hess, float) and self.state[p]['hessian_step'] % self.update_each == 0:
                 p.hess.zero_()

-    @torch.no_grad()
     def set_hessian(self):
-        """Computes the Hutchinson approximation of the hessian trace
-        and accumulates it for each trainable parameter
-        """
+        """Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter"""
         params = []
-        for p in filter(lambda param: param.grad is not None, self.get_params()):
+        for p in self.get_params():
+            if p.grad is None:
+                continue
+
             # compute the trace only each `update_each` step
             if self.state[p]['hessian_step'] % self.update_each == 0:
                 params.append(p)
+
             self.state[p]['hessian_step'] += 1

         if len(params) == 0:
@@ -126,7 +127,7 @@ def set_hessian(self):

         for i in range(self.num_samples):
             # Rademacher distribution {-1.0, 1.0}
-            zs = [torch.randint(0, 2, p.size(), generator=self.generator, device=p.device) * 2.0 - 1.0 for p in params]
+            zs = [2.0 * torch.randint(0, 2, p.size()).float().requires_grad_(True) - 1.0 for p in params]

             # note: possible memory leak due to retain_graph=True
             h_zs = torch.autograd.grad(
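As a standalone reference, here is a sketch of the Hutchinson estimator this loop implements, assuming the gradients were produced with `create_graph=True`; the helper name and signature are illustrative, not part of the library:

```python
import torch


def hutchinson_hessian_diag(params, num_samples: int = 1):
    # gradients must have been produced with loss.backward(create_graph=True)
    grads = [p.grad for p in params]
    estimates = [torch.zeros_like(p) for p in params]

    for _ in range(num_samples):
        # Rademacher vectors: entries drawn uniformly from {-1.0, +1.0}
        zs = [2.0 * torch.randint(0, 2, p.size(), device=p.device).float() - 1.0 for p in params]

        # Hessian-vector products H @ z, obtained by differentiating the gradients
        h_zs = torch.autograd.grad(grads, params, grad_outputs=zs, retain_graph=True)

        for est, h_z, z in zip(estimates, h_zs, zs):
            est.add_(h_z * z / num_samples)  # E[z * (H @ z)] approximates diag(H)

    return estimates
```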
@@ -141,7 +142,6 @@ def set_hessian(self):
                 # approximate the expected values of z * (H@z)
                 p.hess += h_z * z / self.num_samples

-    @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:
         loss: LOSS = None
         if closure is not None:
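For reference, the accumulation `p.hess += h_z * z / self.num_samples` above is the elementwise form of Hutchinson's identity (a standard result, not specific to this commit): for a random vector $z$ with i.i.d. entries in $\{-1, +1\}$,

$$
\mathbb{E}_z\bigl[z \odot (Hz)\bigr] = \operatorname{diag}(H),
$$

so averaging over `num_samples` draws gives an unbiased estimate of the Hessian diagonal stored in `p.hess`.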
@@ -156,7 +156,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     continue

                 if self.average_conv_kernel and p.dim() == 4:
-                    p.hess = torch.abs(p.hess).mean(dim=[2, 3], keepdim=True).expand_as(p.hess).clone()
+                    p.hess = torch.abs(p.hess).mean(dim=(2, 3), keepdim=True).expand_as(p.hess).clone()

                 # Perform correct step-weight decay as in AdamW
                 p.mul_(1.0 - group['lr'] * group['weight_decay'])
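The `p.mul_(1.0 - group['lr'] * group['weight_decay'])` line is decoupled weight decay in the AdamW style: the weights shrink directly instead of the decay being folded into the gradient as an L2 penalty and then rescaled by the adaptive denominator. A minimal sketch of the distinction (function names are illustrative):

```python
import torch


def decoupled_weight_decay(param: torch.Tensor, lr: float, weight_decay: float) -> None:
    # AdamW-style: shrink the weights directly, independent of the adaptive gradient scaling
    param.data.mul_(1.0 - lr * weight_decay)


def l2_penalty(grad: torch.Tensor, param: torch.Tensor, weight_decay: float) -> torch.Tensor:
    # classic L2 regularization: the decay term flows through the optimizer's update rule
    return grad.add(param, alpha=weight_decay)
```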