@@ -16,17 +16,22 @@ class BaseOptimizer(ABC):
     def set_hessian(param_groups: PARAMETERS, state: STATE, hessian: List[torch.Tensor]):
         r"""Set hessian to state from external source. Generally useful when using functorch as a base.
 
-        Example usage:
-        ```
-        # Hutchinsons Estimator using HVP
-        noise = tree_map(lambda v: torch.randn_like(v), params)
-        loss_, hvp_est = jvp(grad(run_model_fn), (params,), (noise,))
-        hessian_diag_est = tree_map(lambda a, b: a * b, hvp_est, noise)
-
-        optimizer.set_hessian(hessian_diag_est)
-        # OR
-        optimizer.step(hessian=hessian_diag_est)
-        ````
+        Example:
+        -------
+        Here's an example::
+
+            # Hutchinson's Estimator using HVP
+            noise = tree_map(lambda v: torch.randn_like(v), params)
+            loss_, hvp_est = jvp(grad(run_model_fn), (params,), (noise,))
+            hessian_diag_est = tree_map(lambda a, b: a * b, hvp_est, noise)
+
+            optimizer.set_hessian(hessian_diag_est)
+            # OR
+            optimizer.step(hessian=hessian_diag_est)
+
+        :param param_groups: PARAMETERS. parameter groups.
+        :param state: STATE. optimizer state.
+        :param hessian: List[torch.Tensor]. sequence of hessian to set.
         """
         i: int = 0
         for group in param_groups:
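`set_hessian` walks `param_groups` in order and stores the i-th entry of `hessian` for the i-th parameter it encounters, so the external list must follow the same flat ordering and shapes as the parameters. A minimal sketch of that contract; the import path and the dict-based stand-in for the optimizer state are assumptions, not part of this change:

```python
import torch
from collections import defaultdict

from pytorch_optimizer.base.optimizer import BaseOptimizer  # import path assumed

w = torch.randn(3, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)

param_groups = [{'params': [w, b]}]
state = defaultdict(dict)  # mirrors torch.optim.Optimizer.state
hessian = [torch.zeros_like(w), torch.zeros_like(b)]  # same order, same shapes as the params

BaseOptimizer.set_hessian(param_groups, state, hessian)

assert torch.equal(state[w]['hessian'], hessian[0])
assert torch.equal(state[b]['hessian'], hessian[1])
```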
@@ -39,31 +44,48 @@ def set_hessian(param_groups: PARAMETERS, state: STATE, hessian: List[torch.Tens
                 state[p]['hessian'] = hessian[i]
                 i += 1
 
+    @staticmethod
+    def zero_hessian(param_groups: PARAMETERS, state: STATE, pre_zero: bool = True):
+        r"""Zero-out hessian.
+
+        :param param_groups: PARAMETERS. parameter groups.
+        :param state: STATE. optimizer state.
+        :param pre_zero: bool. zero-out the hessian before (re-)computing it.
+        """
+        for group in param_groups:
+            for p in group['params']:
+                if p.requires_grad and p.grad is not None and not p.grad.is_sparse:
+                    if 'hessian' not in state[p]:
+                        state[p]['hessian'] = torch.zeros_like(p)
+                    elif pre_zero:
+                        state[p]['hessian'].zero_()
+
     @staticmethod
     @torch.no_grad()
     def compute_hutchinson_hessian(
         param_groups: PARAMETERS,
         state: STATE,
         num_samples: int = 1,
-        pre_zero: bool = True,
         alpha: float = 1.0,
         distribution: HUTCHINSON_G = 'gaussian',
     ):
-        r"""Hutchinson's approximate hessian, added to the state under key `hessian`."""
+        r"""Hutchinson's approximate hessian, added to the state under key `hessian`.
+
+        :param param_groups: PARAMETERS. parameter groups.
+        :param state: STATE. optimizer state.
+        :param num_samples: int. number of times to sample `z` for the approximation of the hessian trace.
+        :param alpha: float. scale applied to the accumulated hessian estimate.
+        :param distribution: HUTCHINSON_G. sampling distribution, `gaussian` or `rademacher`.
+        """
         if distribution not in ('gaussian', 'rademacher'):
             raise NotImplementedError(f'[-] Hessian with distribution {distribution} is not implemented.')
 
-        params = []
-        for group in param_groups:
-            for p in group['params']:
-                if p.requires_grad and p.grad is not None and not p.grad.is_sparse:
-                    if 'hessian' not in state[p]:
-                        state[p]['hessian'] = torch.zeros_like(p)
-                    elif pre_zero:
-                        state[p]['hessian'].zero_()
-
-                    params.append(p)
-
+        params: List[torch.Tensor] = [
+            p
+            for group in param_groups
+            for p in group['params']
+            if p.requires_grad and p.grad is not None and not p.grad.is_sparse
+        ]
         if len(params) == 0:
             return
 
@@ -77,7 +99,7 @@ def compute_hutchinson_hessian(
 
             h_zs = torch.autograd.grad(grads, params, grad_outputs=zs, retain_graph=i < num_samples - 1)
             for h_z, z, p in zip(h_zs, zs, params):
-                state[p]['hessian'].add_(h_z * z, alpha=(1 / num_samples) * alpha)
+                state[p]['hessian'].add_(h_z * z, alpha=alpha / num_samples)
 
     @staticmethod
     def apply_weight_decay(
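With `pre_zero` moved out of `compute_hutchinson_hessian` and into `zero_hessian`, the expected call pattern becomes zero-then-accumulate. A rough end-to-end sketch of that pattern, assuming the usual second-order backward pass; the model, data, and import path are illustrative only:

```python
import torch
import torch.nn.functional as F

from pytorch_optimizer.base.optimizer import BaseOptimizer  # import path assumed

model = torch.nn.Linear(8, 1)
x, y = torch.randn(32, 8), torch.randn(32, 1)

# A second-order graph is needed so the hessian-vector product inside
# compute_hutchinson_hessian can differentiate through the gradients.
loss = F.mse_loss(model(x), y)
loss.backward(create_graph=True)

param_groups = [{'params': list(model.parameters())}]
state = {p: {} for p in model.parameters()}

BaseOptimizer.zero_hessian(param_groups, state)
BaseOptimizer.compute_hutchinson_hessian(param_groups, state, num_samples=1, distribution='rademacher')

for p in model.parameters():
    # per-parameter diagonal estimate, same shape as the parameter
    assert state[p]['hessian'].shape == p.shape
```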
@@ -89,7 +111,16 @@ def apply_weight_decay(
         fixed_decay: bool,
         ratio: Optional[float] = None,
     ):
-        r"""Apply weight decay."""
+        r"""Apply weight decay.
+
+        :param p: torch.Tensor. parameter.
+        :param grad: torch.Tensor. gradient.
+        :param lr: float. learning rate.
+        :param weight_decay: float. weight decay (L2 penalty).
+        :param weight_decouple: bool. use decoupled weight decay as in AdamW.
+        :param fixed_decay: bool. fix weight decay (do not scale the decay factor by lr).
+        :param ratio: Optional[float]. scale applied to the weight decay.
+        """
         if weight_decouple:
             p.mul_(1.0 - weight_decay * (1.0 if fixed_decay else lr) * (ratio if ratio is not None else 1.0))
         elif weight_decay > 0.0 and grad is not None:
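For the decoupled branch, the multiplier is fully determined by the `p.mul_` line above: the parameter shrinks by `weight_decay * lr`, or by `weight_decay` alone when `fixed_decay` is set. A quick numeric check of just that branch (import path assumed):

```python
import torch

from pytorch_optimizer.base.optimizer import BaseOptimizer  # import path assumed

lr, weight_decay = 0.1, 0.01
grad = torch.zeros(3)  # untouched by the decoupled branch

p = torch.ones(3)
BaseOptimizer.apply_weight_decay(p, grad, lr=lr, weight_decay=weight_decay,
                                 weight_decouple=True, fixed_decay=False)
assert torch.allclose(p, torch.full((3,), 1.0 - lr * weight_decay))  # decay scaled by lr

p = torch.ones(3)
BaseOptimizer.apply_weight_decay(p, grad, lr=lr, weight_decay=weight_decay,
                                 weight_decouple=True, fixed_decay=True)
assert torch.allclose(p, torch.full((3,), 1.0 - weight_decay))  # lr ignored when fixed_decay
```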
@@ -99,7 +130,13 @@ def apply_weight_decay(
     def apply_ams_bound(
         ams_bound: bool, exp_avg_sq: torch.Tensor, max_exp_avg_sq: Optional[torch.Tensor], eps: float
     ) -> torch.Tensor:
-        r"""Apply AMSBound variant."""
+        r"""Apply AMSBound variant.
+
+        :param ams_bound: bool. whether to apply AMSBound.
+        :param exp_avg_sq: torch.Tensor. exponential moving average of squared gradients.
+        :param max_exp_avg_sq: Optional[torch.Tensor]. running maximum of exp_avg_sq.
+        :param eps: float. epsilon.
+        """
         if ams_bound:
             torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
             de_nom = max_exp_avg_sq.add(eps)
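Only the running-max side effect of the AMSBound branch is visible in this hunk, and it behaves like AMSGrad: `max_exp_avg_sq` keeps the element-wise maximum of every second moment it has seen. A small check of that side effect alone (import path assumed, return value ignored):

```python
import torch

from pytorch_optimizer.base.optimizer import BaseOptimizer  # import path assumed

exp_avg_sq = torch.tensor([0.04, 0.25])
max_exp_avg_sq = torch.tensor([0.09, 0.16])

BaseOptimizer.apply_ams_bound(True, exp_avg_sq, max_exp_avg_sq, eps=1e-8)

# element-wise maximum, updated in place via the out= argument
assert torch.allclose(max_exp_avg_sq, torch.tensor([0.09, 0.25]))
```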
@@ -110,7 +147,12 @@ def apply_ams_bound(
 
     @staticmethod
     def apply_adam_debias(adam_debias: bool, step_size: float, bias_correction1: float) -> float:
-        r"""Apply AdamD variant."""
+        r"""Apply AdamD variant.
+
+        :param adam_debias: bool. whether to apply AdamD.
+        :param step_size: float. step size.
+        :param bias_correction1: float. first-moment bias correction term.
+        """
         return step_size if adam_debias else step_size / bias_correction1
 
     @staticmethod
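`apply_adam_debias` is small enough to pin down exactly: with AdamD the raw step size is used as-is, otherwise it is divided by the first-moment bias correction (`1 - beta1 ** step` for Adam-style rules). For example (import path assumed):

```python
from pytorch_optimizer.base.optimizer import BaseOptimizer  # import path assumed

lr, beta1, step = 1e-3, 0.9, 10
bias_correction1 = 1.0 - beta1 ** step

plain = BaseOptimizer.apply_adam_debias(False, step_size=lr, bias_correction1=bias_correction1)
adam_d = BaseOptimizer.apply_adam_debias(True, step_size=lr, bias_correction1=bias_correction1)

assert plain == lr / bias_correction1  # classic Adam debiasing
assert adam_d == lr                    # AdamD: no division by the bias correction
```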
@@ -122,7 +164,15 @@ def get_rectify_step_size(
         n_sma_threshold: int,
         degenerated_to_sgd: bool,
     ) -> Tuple[float, float]:
-        r"""Get step size for rectify optimizer."""
+        r"""Get step size for rectify optimizer.
+
+        :param is_rectify: bool. whether to apply the rectify variant.
+        :param step: int. number of steps.
+        :param lr: float. learning rate.
+        :param beta2: float. beta2.
+        :param n_sma_threshold: int. SMA threshold.
+        :param degenerated_to_sgd: bool. whether to degenerate to SGD.
+        """
         step_size: float = lr
         n_sma: float = 0.0
 
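For readers who want the math behind `n_sma`, the rectification follows RAdam: `n_sma` is the length of the approximated simple moving average, and the step is rectified only once it clears the threshold. The sketch below is the textbook formula under those assumptions, not a line-by-line copy of this method's body:

```python
import math

def radam_step_size(step: int, lr: float, beta2: float, n_sma_threshold: int = 5) -> tuple:
    """Textbook RAdam rectification; mirrors the (step_size, n_sma) pair returned here."""
    n_sma_max = 2.0 / (1.0 - beta2) - 1.0
    beta2_t = beta2 ** step
    n_sma = n_sma_max - 2.0 * step * beta2_t / (1.0 - beta2_t)

    if n_sma >= n_sma_threshold:
        # variance of the adaptive learning rate is tractable: rectify the step
        rectify = math.sqrt(
            (n_sma - 4.0) * (n_sma - 2.0) * n_sma_max
            / ((n_sma_max - 4.0) * (n_sma_max - 2.0) * n_sma)
        )
        return lr * rectify, n_sma

    return lr, n_sma  # early steps: fall back to an un-rectified (SGD-like) step

print(radam_step_size(step=100, lr=1e-3, beta2=0.999))
```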
@@ -148,7 +198,13 @@ def get_rectify_step_size(
     def get_adanorm_gradient(
         grad: torch.Tensor, adanorm: bool, exp_grad_norm: Optional[torch.Tensor] = None, r: Optional[float] = 0.95
     ) -> torch.Tensor:
-        r"""Get AdaNorm gradient."""
+        r"""Get AdaNorm gradient.
+
+        :param grad: torch.Tensor. gradient.
+        :param adanorm: bool. whether to apply AdaNorm.
+        :param exp_grad_norm: Optional[torch.Tensor]. exponential moving average of the gradient norm.
+        :param r: Optional[float]. momentum (ratio) for the gradient-norm EMA.
+        """
         if not adanorm:
             return grad
 
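The `adanorm` fast path above returns the gradient untouched; the interesting branch is the norm correction, whose body lies outside this hunk. The sketch below therefore only illustrates the AdaNorm idea: keep an exponential moving average of the gradient norm and lift a gradient whose norm falls below that average back toward it. The function name and exact update rule are assumptions, not this method's code:

```python
import torch

def adanorm_gradient(grad: torch.Tensor, exp_grad_norm: torch.Tensor, r: float = 0.95) -> torch.Tensor:
    """Illustrative AdaNorm-style norm correction (assumed behaviour, not the library's exact body)."""
    grad_norm = torch.linalg.norm(grad)
    exp_grad_norm.mul_(r).add_(grad_norm, alpha=1.0 - r)  # in-place EMA of the gradient norm

    if exp_grad_norm > grad_norm:
        return grad * exp_grad_norm / grad_norm  # rescale up toward the running norm
    return grad

g = torch.tensor([0.3, 0.4])  # current norm 0.5
ema = torch.tensor(2.0)       # running norm is larger, so g gets scaled up
print(adanorm_gradient(g, ema))
```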