|
4 | 4 |
|
5 | 5 | import torch |
6 | 6 |
|
7 | | -from pytorch_optimizer.base.exception import NegativeLRError, NegativeStepError, NoSparseGradientError |
8 | | -from pytorch_optimizer.base.types import BETAS, HUTCHINSON_G |
| 7 | +from pytorch_optimizer.base.exception import NegativeLRError, NegativeStepError |
| 8 | +from pytorch_optimizer.base.types import BETAS, HUTCHINSON_G, PARAMETERS, STATE |
9 | 9 |
|
10 | 10 |
|
11 | 11 | class BaseOptimizer(ABC): |
12 | 12 | r"""Base optimizer class.""" |
13 | 13 |
|
| 14 | + @staticmethod |
14 | 15 | @torch.no_grad() |
15 | | - def set_hessian(self, hessian): |
16 | | - """ |
17 | | - Helper function to set hessian state from external source |
18 | | - Generally useful when using functorch as a base |
| 16 | + def set_hessian(param_groups: PARAMETERS, state: STATE, hessian: List[torch.Tensor]): |
| 17 | +    r"""Set the Hessian state from an external source. Generally useful when using functorch as a base. |
19 | 18 |
|
20 | 19 | Example usage: |
21 | 20 | ``` |
22 | 21 | # Hutchinson's estimator using HVP |
23 | 22 | noise = tree_map(lambda v: torch.randn_like(v), params) |
24 | 23 | loss_, hvp_est = jvp(grad(run_model_fn), (params,), (noise,)) |
25 | | - hessian_diag_est = tree_map(lambda a, b: a*b, hvp_est, noise) |
| 24 | + hessian_diag_est = tree_map(lambda a, b: a * b, hvp_est, noise) |
26 | 25 |
|
27 | 26 | optimizer.set_hessian(hessian_diag_est) |
28 | 27 | # OR |
29 | 28 | optimizer.step(hessian=hessian_diag_est) |
30 | 29 | ``` |
31 | | -
|
32 | 30 | """ |
33 | | - i = 0 |
34 | | - for group in self.param_groups: |
| 31 | + i: int = 0 |
| 32 | + for group in param_groups: |
35 | 33 | for p in group['params']: |
36 | | - assert p.shape == hessian[i].shape |
37 | | - self.state[p]['hessian'] = hessian[i] |
| 34 | + if p.size() != hessian[i].size(): |
| 35 | + raise ValueError( |
| 36 | + f'[-] the shape of parameter and hessian does not match. {p.size()} vs {hessian[i].size()}' |
| 37 | + ) |
| 38 | + |
| 39 | + state[p]['hessian'] = hessian[i] |
38 | 40 | i += 1 |
39 | 41 |
|
| 42 | + @staticmethod |
40 | 43 | @torch.no_grad() |
41 | | - def compute_hutchinson_hessian(self, nsamples: int = 1, pre_zero=True, alpha=1.0, distribution: HUTCHINSON_G = 'gaussian'): |
42 | | - """ |
43 | | - Hutchinsons approximate hessian, added to the state under key 'hessian' |
44 | | - """ |
45 | | - if distribution not in ['gaussian', 'rademacher']: |
46 | | - raise NotImplementedError(f"Hessian with distribution {distribution} is not implemented") |
| 44 | + def compute_hutchinson_hessian( |
| 45 | + param_groups: PARAMETERS, |
| 46 | + state: STATE, |
| 47 | + num_samples: int = 1, |
| 48 | + pre_zero: bool = True, |
| 49 | + alpha: float = 1.0, |
| 50 | + distribution: HUTCHINSON_G = 'gaussian', |
| 51 | + ): |
| 52 | + r"""Hutchinson's approximate hessian, added to the state under key `hessian`.""" |
| 53 | + if distribution not in ('gaussian', 'rademacher'): |
| 54 | + raise NotImplementedError(f'[-] Hessian with distribution {distribution} is not implemented.') |
47 | 55 |
|
48 | 56 | params = [] |
49 | | - for group in self.param_groups: |
| 57 | + for group in param_groups: |
50 | 58 | for p in group['params']: |
51 | | - if p.requires_grad and p.grad is not None: |
52 | | - if p.grad.is_sparse: |
53 | | - raise NoSparseGradientError(str(self)) |
54 | | - # Initialize Hessian state |
55 | | - if 'hessian' in self.state[p]: |
56 | | - if pre_zero: |
57 | | - self.state[p]['hessian'].zero_() |
58 | | - else: |
59 | | - self.state[p]['hessian'] = torch.zeros_like(p.data) |
| 59 | + if p.requires_grad and p.grad is not None and not p.grad.is_sparse: |
| 60 | + if 'hessian' not in state[p]: |
| 61 | + state[p]['hessian'] = torch.zeros_like(p) |
| 62 | + elif pre_zero: |
| 63 | + state[p]['hessian'].zero_() |
| 64 | + |
60 | 65 | params.append(p) |
61 | 66 |
|
62 | 67 | if len(params) == 0: |
63 | 68 | return |
64 | 69 |
|
65 | 70 | grads = [p.grad for p in params] |
66 | 71 |
|
67 | | - for i in range(nsamples): |
68 | | - if distribution == 'gaussian': |
69 | | - # Gaussian N(0,Id) |
70 | | - zs = [torch.randn(p.size(), device=p.device) for p in params] |
71 | | - elif distribution == 'rademacher': |
72 | | - # Rademacher distribution {-1.0, 1.0} |
73 | | - zs = [torch.randint(0, 2, p.size(), dtype=p.dtype, device=p.device) * 2.0 - 1.0 for p in params] |
| 72 | + for i in range(num_samples): |
| 73 | + if distribution == 'rademacher': |
| 74 | +            zs = [torch.randint_like(p, 0, 2) * 2.0 - 1.0 for p in params]  # Rademacher {-1.0, 1.0} |
| 75 | + else: |
| 76 | + zs = [torch.randn_like(p) for p in params] |
74 | 77 |
|
75 | | - h_zs = torch.autograd.grad(grads, params, grad_outputs=zs, retain_graph=i < nsamples - 1) |
| 78 | + h_zs = torch.autograd.grad(grads, params, grad_outputs=zs, retain_graph=i < num_samples - 1) |
76 | 79 | for h_z, z, p in zip(h_zs, zs, params): |
77 | | - # approximate the expected values of z*(H@z) |
78 | | - self.state[p]['hessian'].add_(h_z * z, alpha=(1/nsamples) * alpha) |
| 80 | +                state[p]['hessian'].add_(h_z * z, alpha=(1 / num_samples) * alpha)  # approximate the expected value of z * (H @ z) |
79 | 81 |
|
80 | 82 | @staticmethod |
81 | 83 | def apply_weight_decay( |
|
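For reference, below is a minimal sketch of driving the two refactored helpers directly, now that they are static methods taking `param_groups` and `state` explicitly instead of reading them off `self`. It is not part of the commit: the import path `pytorch_optimizer.base.optimizer`, the toy linear model, and the placeholder `external_estimate` are assumptions for illustration only.

```python
import torch

# Assumed import path; the commit only shows the class body, not the module name.
from pytorch_optimizer.base.optimizer import BaseOptimizer

# Toy setup standing in for a real model/optimizer pairing.
model = torch.nn.Linear(4, 1)
params = [p for p in model.parameters() if p.requires_grad]
param_groups = [{'params': params}]
state = {p: {} for p in params}

loss = model(torch.randn(8, 4)).pow(2).mean()
# create_graph=True keeps the graph alive so torch.autograd.grad can form
# the Hessian-vector products used by the Hutchinson estimator.
loss.backward(create_graph=True)

# E_z[z * (H @ z)] approximates diag(H) whenever E[z z^T] = I, which holds for
# both the 'gaussian' and 'rademacher' choices.
BaseOptimizer.compute_hutchinson_hessian(param_groups, state, num_samples=1, distribution='rademacher')

# Alternatively, push an externally computed estimate (e.g. from functorch):
external_estimate = [torch.ones_like(p) for p in params]  # placeholder values
BaseOptimizer.set_hessian(param_groups, state, external_estimate)

for p in params:
    print(p.shape, state[p]['hessian'].shape)  # estimates live under state[p]['hessian']
```

Because both helpers now operate on explicit `param_groups`/`state`, they can be exercised in isolation like this; concrete optimizers built on `BaseOptimizer` are expected to forward their own `self.param_groups` and `self.state` when calling them.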