
Commit 5fec6f8

refactor: AdaHessian
1 parent 59b2dd1

File tree

1 file changed: +12, -12 lines

pytorch_optimizer/adahessian.py

Lines changed: 12 additions & 12 deletions
@@ -34,17 +34,17 @@ def __init__(
         average_conv_kernel: bool = False,
         adamd_debias_term: bool = False,
         eps: float = 1e-8,
-        seed: int = 2147483647,
+        seed: int = 1337,
     ):
-        """
+        """AdaHessian
         :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
-        :param lr: float. learning rate.
+        :param lr: float. learning rate
         :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
         :param weight_decay: float. weight decay (L2 penalty)
         :param hessian_power: float. exponent of the hessian trace
         :param update_each: int. compute the hessian trace approximation only after *this* number of steps
         :param num_samples: int. how many times to sample `z` for the approximation of the hessian trace
-        :param average_conv_kernel: bool. average out the hessian traces of convolutional kernels as in the paper.
+        :param average_conv_kernel: bool. average out the hessian traces of convolutional kernels as in the paper
         :param adamd_debias_term: bool. Only correct the denominator to avoid inflating step sizes early in training
         :param eps: float. term added to the denominator to improve numerical stability
         :param seed: int.
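
For orientation, here is a minimal usage sketch of the constructor touched above. The import path and the training-loop details (in particular backward(create_graph=True) so that Hessian-vector products can be taken later) are assumptions based on this file, not something this commit shows.

# Minimal usage sketch; the top-level import of AdaHessian from the
# pytorch_optimizer package is an assumption, not confirmed by this diff.
import torch
from pytorch_optimizer import AdaHessian  # assumed import path

model = torch.nn.Linear(10, 1)
optimizer = AdaHessian(model.parameters(), lr=1e-1, seed=1337)

x, y = torch.randn(8, 10), torch.randn(8, 1)
loss = torch.nn.functional.mse_loss(model(x), y)

# create_graph=True keeps the graph so set_hessian can differentiate the
# gradients again for the Hutchinson step (assumption about the intended usage).
loss.backward(create_graph=True)
optimizer.step()
optimizer.zero_grad()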
@@ -103,16 +103,17 @@ def zero_hessian(self):
             if not isinstance(p.hess, float) and self.state[p]['hessian_step'] % self.update_each == 0:
                 p.hess.zero_()

-    @torch.no_grad()
     def set_hessian(self):
-        """Computes the Hutchinson approximation of the hessian trace
-        and accumulates it for each trainable parameter
-        """
+        """Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter"""
         params = []
-        for p in filter(lambda param: param.grad is not None, self.get_params()):
+        for p in self.get_params():
+            if p.grad is None:
+                continue
+
             # compute the trace only each `update_each` step
             if self.state[p]['hessian_step'] % self.update_each == 0:
                 params.append(p)
+
             self.state[p]['hessian_step'] += 1

         if len(params) == 0:
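
A tiny, self-contained sketch of the parameter-selection refactor in this hunk: the filter(...) call and the explicit loop with continue select the same parameters. The toy objects below are hypothetical; only their .grad attribute matters.

# Toy illustration of the refactor (hypothetical objects, not optimizer state).
from types import SimpleNamespace

params = [SimpleNamespace(grad=None), SimpleNamespace(grad=1.0)]

# before: filter-based selection
selected_old = list(filter(lambda p: p.grad is not None, params))

# after: explicit loop, skipping parameters without gradients
selected_new = []
for p in params:
    if p.grad is None:
        continue
    selected_new.append(p)

assert selected_old == selected_new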
@@ -126,7 +127,7 @@ def set_hessian(self):

         for i in range(self.num_samples):
             # Rademacher distribution {-1.0, 1.0}
-            zs = [torch.randint(0, 2, p.size(), generator=self.generator, device=p.device) * 2.0 - 1.0 for p in params]
+            zs = [2.0 * torch.randint(0, 2, p.size()).float().requires_grad_(True) - 1.0 for p in params]

             # note that, possible memory leak due to retrain_graph=True
             h_zs = torch.autograd.grad(
@@ -141,7 +142,6 @@ def set_hessian(self):
                 # approximate the expected values of z * (H@z)
                 p.hess += h_z * z / self.num_samples

-    @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:
         loss: LOSS = None
         if closure is not None:
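
The two hunks above are Hutchinson's trace estimator: draw Rademacher vectors z, form Hessian-vector products H @ z with torch.autograd.grad, and average z * (H @ z), whose expectation is the diagonal of H. A standalone sketch on a toy quadratic (illustrative names only, not the optimizer's own code path):

# Standalone sketch of Hutchinson's Hessian-trace estimator on a toy quadratic.
import torch

torch.manual_seed(1337)

w = torch.randn(5, requires_grad=True)
loss = (w ** 2 * torch.arange(1.0, 6.0)).sum()  # Hessian is diag(2, 4, 6, 8, 10)

grad = torch.autograd.grad(loss, w, create_graph=True)[0]

num_samples = 1000
trace_estimate = 0.0
for _ in range(num_samples):
    # Rademacher distribution {-1.0, 1.0}
    z = 2.0 * torch.randint(0, 2, w.size()).float() - 1.0
    # Hessian-vector product H @ z via a second differentiation of the gradient
    h_z = torch.autograd.grad(grad, w, grad_outputs=z, retain_graph=True)[0]
    # E[z * (H @ z)] recovers the diagonal of H; its sum estimates the trace
    trace_estimate += (h_z * z).sum().item() / num_samples

print(trace_estimate)  # approximately 30.0 = trace(diag(2, 4, 6, 8, 10))

For this diagonal toy Hessian the estimate is exact for any number of samples; for general Hessians more samples reduce the variance, which is what num_samples controls in set_hessian above.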
@@ -156,7 +156,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
156156
continue
157157

158158
if self.average_conv_kernel and p.dim() == 4:
159-
p.hess = torch.abs(p.hess).mean(dim=[2, 3], keepdim=True).expand_as(p.hess).clone()
159+
p.hess = torch.abs(p.hess).mean(dim=(2, 3), keepdim=True).expand_as(p.hess).clone()
160160

161161
# Perform correct step-weight decay as in AdamW
162162
p.mul_(1.0 - group['lr'] * group['weight_decay'])
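
The last hunk only swaps dim=[2, 3] for dim=(2, 3); both forms are accepted by Tensor.mean, so behaviour is unchanged. A small sketch of what the two statements do, on toy tensors rather than optimizer state:

# Toy sketch of the two operations in the hunk above (illustrative values only).
import torch

lr, weight_decay = 1e-1, 1e-2

p = torch.randn(4, 3, 3, 3)     # conv weight: (out_ch, in_ch, kH, kW)
hess = torch.randn_like(p)      # stand-in for the accumulated hessian trace

# average_conv_kernel: average |hess| over the spatial dims of each kernel,
# so every element of a kernel shares one smoothed curvature estimate
hess = torch.abs(hess).mean(dim=(2, 3), keepdim=True).expand_as(hess).clone()

# decoupled (AdamW-style) weight decay: shrink the weights directly,
# independently of the gradient-based update
p.mul_(1.0 - lr * weight_decay)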
