
Commit c56b36a

Merge pull request #180 from kozistr/refactor/modules
[Refactor] Hessian methods
2 parents 5deede6 + ae2ec8a commit c56b36a

File tree: 8 files changed (+125, -40 lines)

README.rst
docs/base_api.rst
docs/index.rst
pytorch_optimizer/base/optimizer.py
pytorch_optimizer/loss/__init__.py
pytorch_optimizer/optimizer/adahessian.py
pytorch_optimizer/optimizer/sgd.py
pytorch_optimizer/optimizer/sophia.py

README.rst

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ Getting Started
 
 For more, see the `documentation <https://pytorch-optimizers.readthedocs.io/en/latest/>`__.
 
-Most optimizers are under MIT or Apache 2.0 license, but a few optimizers like `Fromage` have BY-NC-SA 4.0 license, which is non-commercial.
+Most optimizers are under MIT or Apache 2.0 license, but a few optimizers like `Fromage`, `Nero` have BY-NC-SA 4.0 license, which is non-commercial.
 So, please double-check the license before using it at your work.
 
 Installation

docs/base_api.rst

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+BaseOptimizer
+=============
+
+.. _BaseOptimizer:
+
+BaseOptimizer
+-------------
+
+.. autoclass:: pytorch_optimizer.base.optimizer.BaseOptimizer
+    :members:

docs/index.rst

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ Contents
 .. toctree::
     :maxdepth: 2
 
+    base_api
     optimizer_api
     scheduler_api
     util_api

pytorch_optimizer/base/optimizer.py

Lines changed: 86 additions & 30 deletions
@@ -16,17 +16,22 @@ class BaseOptimizer(ABC):
     def set_hessian(param_groups: PARAMETERS, state: STATE, hessian: List[torch.Tensor]):
         r"""Set hessian to state from external source. Generally useful when using functorch as a base.
 
-        Example usage:
-        ```
-        # Hutchinsons Estimator using HVP
-        noise = tree_map(lambda v: torch.randn_like(v), params)
-        loss_, hvp_est = jvp(grad(run_model_fn), (params,), (noise,))
-        hessian_diag_est = tree_map(lambda a, b: a * b, hvp_est, noise)
-
-        optimizer.set_hessian(hessian_diag_est)
-        # OR
-        optimizer.step(hessian=hessian_diag_est)
-        ````
+        Example:
+        -------
+        Here's an example::
+
+            # Hutchinson's Estimator using HVP
+            noise = tree_map(lambda v: torch.randn_like(v), params)
+            loss_, hvp_est = jvp(grad(run_model_fn), (params,), (noise,))
+            hessian_diag_est = tree_map(lambda a, b: a * b, hvp_est, noise)
+
+            optimizer.set_hessian(hessian_diag_est)
+            # OR
+            optimizer.step(hessian=hessian_diag_est)
+
+        :param param_groups: PARAMETERS. parameter groups.
+        :param state: STATE. optimizer state.
+        :param hessian: List[torch.Tensor]. sequence of hessian to set.
         """
         i: int = 0
         for group in param_groups:
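For readers following the new docstring above, here is a slightly fuller sketch of the external-Hessian path using the torch.func equivalents of the functorch calls it mentions. `model`, `x`, `y`, `loss_fn`, and `optimizer` are placeholders (assumptions, not from this commit); only `set_hessian` / `step(hessian=...)` come from the code above, and the usual backward pass that fills the gradients is omitted.

```python
import torch
from torch.func import functional_call, grad, jvp

# placeholders: `model`, `x`, `y`, `loss_fn`, and `optimizer` (an AdaHessian/SophiaH-style
# instance) are assumed to exist already.
params = dict(model.named_parameters())

def run_model_fn(p):
    # scalar loss as a pure function of the parameter dict
    return loss_fn(functional_call(model, p, (x,)), y)

# Hutchinson's estimator via an HVP: E[z * (Hz)] is the diagonal of the Hessian H
noise = {k: torch.randn_like(v) for k, v in params.items()}
grads_, hvp_est = jvp(grad(run_model_fn), (params,), (noise,))   # (gradients, H @ noise)
hessian_diag_est = {k: hvp_est[k] * noise[k] for k in params}

# hand the per-parameter estimates to the optimizer, in the optimizer's parameter order
# (the backward pass that populates p.grad is omitted in this sketch)
optimizer.step(hessian=list(hessian_diag_est.values()))
```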
@@ -39,31 +44,48 @@ def set_hessian(param_groups: PARAMETERS, state: STATE, hessian: List[torch.Tensor]):
                 state[p]['hessian'] = hessian[i]
                 i += 1
 
+    @staticmethod
+    def zero_hessian(param_groups: PARAMETERS, state: STATE, pre_zero: bool = True):
+        r"""Zero-out hessian.
+
+        :param param_groups: PARAMETERS. parameter groups.
+        :param state: STATE. optimizer state.
+        :param pre_zero: bool. zero-out hessian before computing the hessian.
+        """
+        for group in param_groups:
+            for p in group['params']:
+                if p.requires_grad and p.grad is not None and not p.grad.is_sparse:
+                    if 'hessian' not in state[p]:
+                        state[p]['hessian'] = torch.zeros_like(p)
+                    elif pre_zero:
+                        state[p]['hessian'].zero_()
+
     @staticmethod
     @torch.no_grad()
     def compute_hutchinson_hessian(
         param_groups: PARAMETERS,
         state: STATE,
         num_samples: int = 1,
-        pre_zero: bool = True,
         alpha: float = 1.0,
         distribution: HUTCHINSON_G = 'gaussian',
     ):
-        r"""Hutchinson's approximate hessian, added to the state under key `hessian`."""
+        r"""Hutchinson's approximate hessian, added to the state under key `hessian`.
+
+        :param param_groups: PARAMETERS. parameter groups.
+        :param state: STATE. optimizer state.
+        :param num_samples: int. number of times to sample `z` for the approximation of the hessian trace.
+        :param alpha: float. alpha.
+        :param distribution: HUTCHINSON_G. type of distribution.
+        """
         if distribution not in ('gaussian', 'rademacher'):
             raise NotImplementedError(f'[-] Hessian with distribution {distribution} is not implemented.')
 
-        params = []
-        for group in param_groups:
-            for p in group['params']:
-                if p.requires_grad and p.grad is not None and not p.grad.is_sparse:
-                    if 'hessian' not in state[p]:
-                        state[p]['hessian'] = torch.zeros_like(p)
-                    elif pre_zero:
-                        state[p]['hessian'].zero_()
-
-                    params.append(p)
-
+        params: List[torch.Tensor] = [
+            p
+            for group in param_groups
+            for p in group['params']
+            if p.requires_grad and p.grad is not None and not p.grad.is_sparse
+        ]
        if len(params) == 0:
            return
 
@@ -77,7 +99,7 @@ def compute_hutchinson_hessian(
 
             h_zs = torch.autograd.grad(grads, params, grad_outputs=zs, retain_graph=i < num_samples - 1)
             for h_z, z, p in zip(h_zs, zs, params):
-                state[p]['hessian'].add_(h_z * z, alpha=(1 / num_samples) * alpha)
+                state[p]['hessian'].add_(h_z * z, alpha=alpha / num_samples)
 
     @staticmethod
     def apply_weight_decay(
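Taken together, `zero_hessian` and `compute_hutchinson_hessian` implement Hutchinson's diagonal estimator: allocate or zero a per-parameter buffer, then accumulate z * (Hz) over a few random probes. Below is a self-contained toy sketch of that same computation; the quadratic objective, `num_samples`, and `alpha` are illustrative values, not the library's defaults.

```python
import torch

w = torch.randn(5, requires_grad=True)
loss = (w ** 4).sum()  # toy objective whose exact Hessian diagonal is 12 * w**2

grads = torch.autograd.grad(loss, w, create_graph=True)   # keep the graph for the HVP
hessian_est = torch.zeros_like(w)                         # what zero_hessian prepares

num_samples, alpha = 4, 1.0
for i in range(num_samples):
    z = torch.randint_like(w, 2) * 2.0 - 1.0              # Rademacher noise in {-1, +1}
    h_z = torch.autograd.grad(grads, w, grad_outputs=(z,), retain_graph=i < num_samples - 1)[0]
    hessian_est.add_(h_z * z, alpha=alpha / num_samples)   # accumulate z * (Hz)

print(hessian_est)   # matches 12 * w.detach() ** 2 here, since this Hessian is diagonal
```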
@@ -89,7 +111,16 @@ def apply_weight_decay(
         fixed_decay: bool,
         ratio: Optional[float] = None,
     ):
-        r"""Apply weight decay."""
+        r"""Apply weight decay.
+
+        :param p: torch.Tensor. parameter.
+        :param grad: torch.Tensor. gradient.
+        :param lr: float. learning rate.
+        :param weight_decay: float. weight decay (L2 penalty).
+        :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
+        :param fixed_decay: bool. fix weight decay.
+        :param ratio: Optional[float]. scale weight decay.
+        """
         if weight_decouple:
             p.mul_(1.0 - weight_decay * (1.0 if fixed_decay else lr) * (ratio if ratio is not None else 1.0))
         elif weight_decay > 0.0 and grad is not None:
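The two branches of `apply_weight_decay` correspond to decoupled (AdamW-style) decay applied directly to the parameter and classic coupled L2 decay folded into the gradient. A toy sketch of the difference, using made-up values rather than anything from the library:

```python
import torch

p, grad = torch.ones(3), torch.zeros(3)
lr, weight_decay = 0.1, 0.01

# decoupled (AdamW-style): shrink the parameter directly, gradient untouched
p_decoupled = p * (1.0 - lr * weight_decay)

# coupled (classic L2): fold the penalty into the gradient before the update
grad_coupled = grad + weight_decay * p

print(p_decoupled, grad_coupled)
```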
@@ -99,7 +130,13 @@
     def apply_ams_bound(
         ams_bound: bool, exp_avg_sq: torch.Tensor, max_exp_avg_sq: Optional[torch.Tensor], eps: float
     ) -> torch.Tensor:
-        r"""Apply AMSBound variant."""
+        r"""Apply AMSBound variant.
+
+        :param ams_bound: bool. whether to apply AMSBound.
+        :param exp_avg_sq: torch.Tensor. exp_avg_sq.
+        :param max_exp_avg_sq: Optional[torch.Tensor]. max_exp_avg_sq.
+        :param eps: float. epsilon.
+        """
         if ams_bound:
             torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
             de_nom = max_exp_avg_sq.add(eps)
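The AMSBound branch keeps the element-wise running maximum of the second moment, so the denominator can only grow over time. A tiny illustration with made-up numbers, covering only the branch visible in this hunk:

```python
import torch

exp_avg_sq = torch.tensor([0.20, 0.05])      # current second-moment estimate
max_exp_avg_sq = torch.tensor([0.10, 0.30])  # running maximum so far
eps = 1e-8

torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)   # -> [0.20, 0.30]
de_nom = max_exp_avg_sq.add(eps)                            # denominator never shrinks
print(de_nom)
```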
@@ -110,7 +147,12 @@ def apply_ams_bound(
 
     @staticmethod
     def apply_adam_debias(adam_debias: bool, step_size: float, bias_correction1: float) -> float:
-        r"""Apply AdamD variant."""
+        r"""Apply AdamD variant.
+
+        :param adam_debias: bool. whether to apply AdamD.
+        :param step_size: float. step size.
+        :param bias_correction1: float. bias_correction.
+        """
         return step_size if adam_debias else step_size / bias_correction1
 
     @staticmethod
@@ -122,7 +164,15 @@ def get_rectify_step_size(
         n_sma_threshold: int,
         degenerated_to_sgd: bool,
     ) -> Tuple[float, float]:
-        r"""Get step size for rectify optimizer."""
+        r"""Get step size for rectify optimizer.
+
+        :param is_rectify: bool. whether to apply rectify-variant.
+        :param step: int. number of steps.
+        :param lr: float. learning rate.
+        :param beta2: float. beta2.
+        :param n_sma_threshold: float. SMA threshold.
+        :param degenerated_to_sgd: bool. degenerated to SGD.
+        """
         step_size: float = lr
         n_sma: float = 0.0
 
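For reference, the rectification this helper is named after is the standard RAdam schedule: estimate the SMA length rho_t from beta2 and the step count, and only scale the learning rate once rho_t passes the threshold. A sketch of that formula follows; it is illustrative only, and the helper's exact return convention is not shown in this diff.

```python
import math

def radam_rectified_step_size(lr: float, beta2: float, step: int, n_sma_threshold: int = 5):
    # step counts from 1; rho_infinity is the maximum SMA length for this beta2
    n_sma_max = 2.0 / (1.0 - beta2) - 1.0
    beta2_t = beta2 ** step
    n_sma = n_sma_max - 2.0 * step * beta2_t / (1.0 - beta2_t)   # rho_t, the "SMA length"

    if n_sma < n_sma_threshold:
        # variance not yet tractable; the caller may fall back to SGD (degenerated_to_sgd)
        return lr, n_sma

    rectify = math.sqrt(
        ((n_sma - 4.0) * (n_sma - 2.0) * n_sma_max)
        / ((n_sma_max - 4.0) * (n_sma_max - 2.0) * n_sma)
    )
    return lr * rectify, n_sma

print(radam_rectified_step_size(lr=1e-3, beta2=0.999, step=100))
```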

@@ -148,7 +198,13 @@ def get_rectify_step_size(
     def get_adanorm_gradient(
         grad: torch.Tensor, adanorm: bool, exp_grad_norm: Optional[torch.Tensor] = None, r: Optional[float] = 0.95
     ) -> torch.Tensor:
-        r"""Get AdaNorm gradient."""
+        r"""Get AdaNorm gradient.
+
+        :param grad: torch.Tensor. gradient.
+        :param adanorm: bool. whether to apply AdaNorm.
+        :param exp_grad_norm: Optional[torch.Tensor]. exp_grad_norm.
+        :param r: Optional[float]. momentum (ratio).
+        """
         if not adanorm:
             return grad
 
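As a rough, unofficial reading of what an AdaNorm gradient helper does (the general idea, not necessarily the library's exact rule): keep an EMA of gradient norms and scale up gradients whose norm has fallen below that EMA. A sketch of that idea:

```python
import torch

def adanorm_gradient(grad: torch.Tensor, exp_grad_norm: torch.Tensor, r: float = 0.95) -> torch.Tensor:
    # illustrative assumption of the AdaNorm rule, not the library's implementation
    grad_norm = torch.linalg.norm(grad)
    exp_grad_norm.mul_(r).add_(grad_norm, alpha=1.0 - r)    # EMA of gradient norms
    if exp_grad_norm > grad_norm:
        return grad * (exp_grad_norm / grad_norm)           # scale small gradients up
    return grad

state_norm = torch.zeros(1)
g = torch.randn(10) * 0.01
print(adanorm_gradient(g, state_norm))
```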

pytorch_optimizer/loss/__init__.py

Whitespace-only changes.

pytorch_optimizer/optimizer/adahessian.py

Lines changed: 5 additions & 1 deletion
@@ -88,8 +88,12 @@ def step(self, closure: CLOSURE = None, hessian: Optional[List[torch.Tensor]] =
         if hessian is not None:
             self.set_hessian(self.param_groups, self.state, hessian)
         elif step % self.update_period == 0:
+            self.zero_hessian(self.param_groups, self.state)
             self.compute_hutchinson_hessian(
-                self.param_groups, self.state, self.num_samples, distribution=self.distribution
+                param_groups=self.param_groups,
+                state=self.state,
+                num_samples=self.num_samples,
+                distribution=self.distribution,
             )
 
         for group in self.param_groups:
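A minimal usage sketch for the built-in Hutchinson path this hunk wires up (the SophiaH change below follows the same pattern). The key assumption, worth checking against the docs, is that the backward pass must keep the graph so the optimizer can take its own Hessian-vector products.

```python
import torch
from pytorch_optimizer import AdaHessian

model = torch.nn.Linear(10, 1)
optimizer = AdaHessian(model.parameters(), lr=1e-1)

x, y = torch.randn(32, 10), torch.randn(32, 1)

loss = torch.nn.functional.mse_loss(model(x), y)
optimizer.zero_grad()
loss.backward(create_graph=True)   # keep the graph for the internal HVP (assumed requirement)
optimizer.step()                   # zero_hessian + compute_hutchinson_hessian run here
                                   # every update_period steps
```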

pytorch_optimizer/optimizer/sgd.py

Lines changed: 17 additions & 7 deletions
@@ -50,6 +50,7 @@ def __str__(self) -> str:
     @torch.no_grad()
     def reset(self):
         for group in self.param_groups:
+            group['step'] = 0
             for p in group['params']:
                 state = self.state[p]
 
@@ -86,8 +87,14 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 if len(state) == 0:
                     state['momentum_buffer'] = p.clone()
 
-                if group['weight_decay'] > 0.0:
-                    grad.add_(p, alpha=group['weight_decay'])
+                self.apply_weight_decay(
+                    p,
+                    grad,
+                    lr=group['lr'],
+                    weight_decay=group['weight_decay'],
+                    weight_decouple=False,
+                    fixed_decay=False,
+                )
 
                 buf = state['momentum_buffer']
                 buf.mul_((1.0 / beta) - 1.0).add_(grad, alpha=-large_lr).add_(p).mul_(beta)
@@ -177,11 +184,14 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 else:
                     grad = buf
 
-                if group['weight_decay'] > 0.0:
-                    if group['weight_decouple']:
-                        p.mul_(1.0 - group['lr'] * group['weight_decay'])
-                    else:
-                        grad.add_(p, alpha=group['weight_decay'])
+                self.apply_weight_decay(
+                    p,
+                    grad,
+                    lr=group['lr'],
+                    weight_decay=group['weight_decay'],
+                    weight_decouple=group['weight_decouple'],
+                    fixed_decay=False,
+                )
 
                 p.add_(grad, alpha=-group['lr'])
 

pytorch_optimizer/optimizer/sophia.py

Lines changed: 5 additions & 1 deletion
@@ -88,8 +88,12 @@ def step(self, closure: CLOSURE = None, hessian: Optional[List[torch.Tensor]] =
         if hessian is not None:
             self.set_hessian(self.param_groups, self.state, hessian)
         elif step % self.update_period == 0:
+            self.zero_hessian(self.param_groups, self.state)
             self.compute_hutchinson_hessian(
-                self.param_groups, self.state, self.num_samples, distribution=self.distribution
+                param_groups=self.param_groups,
+                state=self.state,
+                num_samples=self.num_samples,
+                distribution=self.distribution,
            )
 
         for group in self.param_groups:
