
Commit baa2fd2

[Docs] Convert to google-style docstring (#449)
* docs: google-style docstring
* build(docs): google-style docstring
* build(deps): mkdocstrings-python
* docs: google-style docstring
* docs: google-style docstring
* docs: google-style docstring
* docs: google-style docstring
* docs: v3.8.2 changelog
* docs: docstring
* docs: docstring
* update: test cases
* docs: docstring
1 parent 0283cf5 commit baa2fd2


122 files changed (+2366, -2154 lines changed)

docs/changelogs/v3.8.2.md

Lines changed: 4 additions & 0 deletions
@@ -7,3 +7,7 @@
 ### Update
 
 * Refactor the type hints. (#448)
+
+### Docs
+
+* Convert the docstring style from reST to google-style docstring. (#449)

mkdocs.yml

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ plugins:
           allow_inspection: true
           show_bases: true
           show_source: true
-          docstring_style: sphinx
+          docstring_style: google
   - git-revision-date-localized:
       enabled: !ENV [ DEPLOY, false ]
       enable_creation_date: true
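For context, the conversion this commit applies looks roughly like the following. The `scale` function is a hypothetical example, not part of pytorch_optimizer; it only contrasts the reST style that `docstring_style: sphinx` parsed with the Google style that `docstring_style: google` now expects.

    # Before: reST/Sphinx-style docstring (docstring_style: sphinx)
    def scale(x: float, factor: float = 2.0) -> float:
        """Scale a value.

        :param x: float. value to scale.
        :param factor: float. scaling factor (default 2.0).
        """
        return x * factor

    # After: Google-style docstring (docstring_style: google)
    def scale(x: float, factor: float = 2.0) -> float:
        """Scale a value.

        Args:
            x (float): The value to scale.
            factor (float): The scaling factor (default is 2.0).
        """
        return x * factor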

pytorch_optimizer/base/exception.py

Lines changed: 8 additions & 6 deletions
@@ -1,8 +1,9 @@
 class NoSparseGradientError(Exception):
-    """Raised when the gradient is sparse gradient.
+    r"""Raised when the gradient is sparse gradient.
 
-    :param optimizer_name: str. optimizer name.
-    :param note: str. special conditions to note (default '').
+    Args:
+        optimizer_name (str): The name of the optimizer where the error occurred.
+        note (str): Additional special conditions or notes (default is an empty string).
     """
 
     def __init__(self, optimizer_name: str, note: str = ''):
@@ -46,10 +47,11 @@ def __init__(self, num_steps: int, step_type: str = ''):
 
 
 class NoComplexParameterError(Exception):
-    """Raised when the dtype of the parameter is complex.
+    r"""Raised when the dtype of the parameter is complex.
 
-    :param optimizer_name: str. optimizer name.
-    :param note: str. special conditions to note (default '').
+    Args:
+        optimizer_name (str): The name of the optimizer where the error occurred.
+        note (str): Additional special conditions or notes (default is an empty string).
     """
 
     def __init__(self, optimizer_name: str, note: str = ''):
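As a rough usage sketch grounded in the constructor signature above (the `reject_sparse_grad` helper and its note string are illustrative, not part of the library):

    import torch

    from pytorch_optimizer.base.exception import NoSparseGradientError

    def reject_sparse_grad(optimizer_name: str, grad: torch.Tensor) -> None:
        # Raise the library's error when an optimizer that cannot handle
        # sparse gradients receives one.
        if grad.is_sparse:
            raise NoSparseGradientError(optimizer_name, note='sparse gradients are not supported')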

pytorch_optimizer/base/optimizer.py

Lines changed: 85 additions & 74 deletions
@@ -20,14 +20,14 @@
 
 
 class BaseOptimizer(ABC, Optimizer):
-    r"""Base optimizer class. Provides common functionalities for the optimizers."""
+    """Base optimizer class. Provides common functionalities for the optimizers."""
 
     def __init__(self, params: Parameters, defaults: Defaults) -> None:
         super().__init__(params, defaults)
 
     @staticmethod
     def load_optimizer(optimizer: OPTIMIZER_INSTANCE_OR_CLASS, **kwargs) -> Optimizer:
-        r"""Build torch.optim.Optimizer class."""
+        """Build torch.optim.Optimizer class."""
         if isinstance(optimizer, Optimizer):
             return optimizer
 
@@ -40,22 +40,22 @@ def load_optimizer(optimizer: OPTIMIZER_INSTANCE_OR_CLASS, **kwargs) -> Optimize
     @staticmethod
     @torch.no_grad()
     def set_hessian(param_groups: Parameters, state: State, hessian: List[torch.Tensor]) -> None:
-        r"""Set hessian to state from external source. Generally useful when using functorch as a base.
+        """Set hessian to state from external source. Generally useful when using functorch as a base.
+
+        Args:
+            param_groups: PARAMETERS. Parameter groups from optimizer.
+            state: STATE. Optimizer state dictionary.
+            hessian: List[torch.Tensor]. Sequence of Hessian tensors to set.
 
         Example:
-        -------
-            # Hutchinson's Estimator using HVP
-            noise = tree_map(lambda v: torch.randn_like(v), params)
-            loss_, hvp_est = jvp(grad(run_model_fn), (params,), (noise,))
-            hessian_diag_est = tree_map(lambda a, b: a * b, hvp_est, noise)
+            # Hutchinson's Estimator using Hessian-vector product (HVP)
+            >>> noise = tree_map(lambda v: torch.randn_like(v), params)
+            >>> loss_, hvp_est = jvp(grad(run_model_fn), (params,), (noise,))
+            >>> hessian_diag_est = tree_map(lambda a, b: a * b, hvp_est, noise)
 
-            optimizer.set_hessian(hessian_diag_est)
+            >>> optimizer.set_hessian(hessian_diag_est)
             # OR
-            optimizer.step(hessian=hessian_diag_est)
-
-        :param param_groups: PARAMETERS. parameter groups.
-        :param state: STATE. optimizer state.
-        :param hessian: List[torch.Tensor]. sequence of hessian to set.
+            >>> optimizer.step(hessian=hessian_diag_est)
         """
         i: int = 0
         for group in param_groups:
@@ -70,11 +70,12 @@ def set_hessian(param_groups: Parameters, state: State, hessian: List[torch.Tens
 
     @staticmethod
     def zero_hessian(param_groups: Parameters, state: State, pre_zero: bool = True) -> None:
-        r"""Zero-out hessian.
+        """Zero-out Hessian.
 
-        :param param_groups: PARAMETERS. parameter groups.
-        :param state: STATE. optimizer state.
-        :param pre_zero: bool. zero-out hessian before computing the hessian.
+        Args:
+            param_groups (Parameters): Parameter groups from the optimizer.
+            state (State): Optimizer state dictionary.
+            pre_zero (bool): If True, zero-out the Hessian before computing/updating it.
         """
         for group in param_groups:
             for p in group['params']:
@@ -93,13 +94,14 @@ def compute_hutchinson_hessian(
         alpha: float = 1.0,
         distribution: HUTCHINSON_G = 'gaussian',
     ) -> None:
-        r"""Hutchinson's approximate hessian, added to the state under key `hessian`.
-
-        :param param_groups: PARAMETERS. parameter groups.
-        :param state: STATE. optimizer state.
-        :param num_samples: int. number of times to sample `z` for the approximation of the hessian trace.
-        :param alpha: float. alpha.
-        :param distribution: HUTCHINSON_G. type of distribution.
+        r"""Hutchinson's approximate Hessian, added to the state under key `hessian`.
+
+        Args:
+            param_groups (Parameters): Parameter groups from the optimizer.
+            state (State): Optimizer state dictionary.
+            num_samples (int): Number of times to sample noise vector `z` for the trace approximation.
+            alpha (float): Scaling factor for the Hessian estimate.
+            distribution (HUTCHINSON_G): Type of noise distribution used (e.g., Rademacher).
         """
         if distribution not in ('gaussian', 'rademacher'):
             raise NotImplementedError(f'hessian with distribution {distribution} is not implemented.')
@@ -135,15 +137,16 @@ def apply_weight_decay(
         fixed_decay: bool,
         ratio: Optional[float] = None,
     ) -> None:
-        r"""Apply weight decay.
-
-        :param p: torch.Tensor. parameter.
-        :param grad: torch.Tensor. gradient.
-        :param lr: float. learning rate.
-        :param weight_decay: float. weight decay (L2 penalty).
-        :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
-        :param fixed_decay: bool. fix weight decay.
-        :param ratio: Optional[float]. scale weight decay.
+        """Apply weight decay.
+
+        Args:
+            p (torch.Tensor): Parameter tensor to apply weight decay to.
+            grad (torch.Tensor): Gradient tensor of parameter p.
+            lr (float): Learning rate to scale the update.
+            weight_decay (float): Weight decay coefficient (L2 penalty).
+            weight_decouple (bool): If True, applies decoupled weight decay as in AdamW.
+            fixed_decay (bool): If True, fixes weight decay to not depend on learning rate.
+            ratio (Optional[float]): Optional scaling factor for weight decay.
         """
         if weight_decouple:
             p.mul_(1.0 - weight_decay * (1.0 if fixed_decay else lr) * (ratio if ratio is not None else 1.0))
@@ -158,13 +161,14 @@ def apply_ams_bound(
         eps: float,
         exp_avg_sq_eps: float = 1e-15,
     ) -> torch.Tensor:
-        r"""Apply AMSBound variant.
-
-        :param ams_bound: bool. whether to apply AMSBound.
-        :param exp_avg_sq: torch.Tensor. exp_avg_sq.
-        :param max_exp_avg_sq: Optional[torch.Tensor]. max_exp_avg_sq.
-        :param eps: float. epsilon.
-        :param exp_avg_sq_eps: float. eps value for numerical stability for exp_avg_sq.
+        """Apply AMSBound variant.
+
+        Args:
+            ams_bound (bool): Whether to apply the AMSBound variant.
+            exp_avg_sq (torch.Tensor): Exponential moving average of squared gradients.
+            max_exp_avg_sq (Optional[torch.Tensor]): Maximum of all exp_avg_sq elements, for AMSBound.
+            eps (float): Small epsilon value for numerical stability.
+            exp_avg_sq_eps (float): Epsilon used specifically for numerical stability in exp_avg_sq computations.
         """
         if ams_bound:
             if torch.is_complex(max_exp_avg_sq):
@@ -179,10 +183,11 @@ def apply_ams_bound(
 
     @staticmethod
     def debias(beta: float, step: int) -> float:
-        r"""Adam-style debias correction. Returns `1.0 - beta ** step`.
+        """Adam-style debias correction.
 
-        :param beta: float. beta.
-        :param step: int. number of step.
+        Args:
+            beta (float): Exponential decay rate for moment estimates.
+            step (int): Current optimization step number.
         """
         return 1.0 - math.pow(beta, step)  # fmt: skip
 
@@ -192,19 +197,21 @@ def debias_beta(beta: float, step: int) -> float:
 
         Simplified version of `\^{beta} = beta * (1.0 - beta ** (step - 1)) / (1.0 - beta ** step)`
 
-        :param beta: float. beta.
-        :param step: int. number of step.
+        Args:
+            beta (float): The original beta decay rate.
+            step (int): Current optimization step number.
         """
         beta_n: float = math.pow(beta, step)
         return (beta_n - beta) / (beta_n - 1.0)  # fmt: skip
 
     @staticmethod
     def apply_adam_debias(adam_debias: bool, step_size: float, bias_correction1: float) -> float:
-        r"""Apply AdamD variant.
+        """Apply AdamD variant.
 
-        :param adam_debias: bool. Only correct the denominator to avoid inflating step sizes early in training.
-        :param step_size: float. step size.
-        :param bias_correction1: float. bias_correction.
+        Args:
+            adam_debias (bool): If True, only corrects the denominator to avoid inflating step sizes early in training.
+            step_size (float): The step size for the update.
+            bias_correction1 (float): The bias correction factor for the first moment.
         """
         return step_size if adam_debias else step_size / bias_correction1
 
@@ -217,14 +224,15 @@ def get_rectify_step_size(
         n_sma_threshold: int,
         degenerated_to_sgd: bool,
     ) -> Tuple[float, float]:
-        r"""Get step size for rectify optimizer.
-
-        :param is_rectify: bool. whether to apply rectify-variant.
-        :param step: int. number of steps.
-        :param lr: float. learning rate.
-        :param beta2: float. beta2.
-        :param n_sma_threshold: float. SMA threshold.
-        :param degenerated_to_sgd: bool. degenerated to SGD.
+        """Get step size for rectify optimizer.
+
+        Args:
+            is_rectify (bool): Whether to apply the rectify variant.
+            step (int): Current step number.
+            lr (float): Base learning rate.
+            beta2 (float): Beta2 parameter from optimizer (momentum term).
+            n_sma_threshold (float): Simple Moving Average (SMA) threshold for rectification.
+            degenerated_to_sgd (bool): Whether to degenerate to SGD if below threshold.
         """
         step_size: float = lr
         n_sma: float = 0.0
@@ -253,10 +261,11 @@ def get_adanorm_gradient(
     ) -> torch.Tensor:
         r"""Get AdaNorm gradient.
 
-        :param grad: torch.Tensor. gradient.
-        :param adanorm: bool. whether to use the AdaNorm variant.
-        :param exp_grad_norm: Optional[torch.Tensor]. exp_grad_norm.
-        :param r: Optional[float]. EMA factor. between 0.9 ~ 0.99 is preferred.
+        Args:
+            grad (torch.Tensor): Gradient.
+            adanorm (bool): Whether to use the AdaNorm variant.
+            exp_grad_norm (Optional[torch.Tensor]): Exponential moving average of gradient norm.
+            r (Optional[float]): EMA factor; between 0.9 and 0.99 is preferred.
         """
         if not adanorm or exp_grad_norm is None:
             return grad
@@ -272,7 +281,7 @@ def get_adanorm_gradient(
 
     @staticmethod
     def get_rms(x: torch.Tensor) -> torch.Tensor:
-        r"""Get RMS."""
+        """Get RMS."""
        return x.norm(2) / math.sqrt(x.numel())
 
     @staticmethod
@@ -281,29 +290,31 @@ def approximate_sq_grad(
         exp_avg_sq_col: torch.Tensor,
         output: torch.Tensor,
     ) -> None:
-        r"""Get approximation of EMA of squared gradient."""
+        """Get approximation of EMA of squared gradient."""
         r_factor: torch.Tensor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
         c_factor: torch.Tensor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
         torch.mul(r_factor, c_factor, out=output)
 
     @staticmethod
     def apply_cautious(update: torch.Tensor, grad: torch.Tensor) -> None:
-        r"""Apply the Cautious Optimizer feature.
+        """Apply the Cautious Optimizer feature.
 
-        :param update: torch.Tensor. update. it'll be masked in in-place manner.
-        :param grad: torch.Tensor. gradient.
+        Args:
+            update (torch.Tensor): update. it'll be masked in in-place manner.
+            grad (torch.Tensor): gradient.
         """
         mask = (update * grad > 0).to(grad.dtype)
         mask.mul_(mask.numel() / (mask.sum() + 1))
         update.mul_(mask)
 
     @staticmethod
     def get_stable_adamw_rms(grad: torch.Tensor, exp_avg_sq: torch.Tensor, eps: float = 1e-16) -> float:
-        r"""Get StableAdamW RMS.
+        """Get StableAdamW RMS.
 
-        :param grad: torch.Tensor. gradient.
-        :param exp_avg_sq: torch.Tensor. exp_avg_sq.
-        :param eps: float. epsilon.
+        Args:
+            grad (torch.Tensor): gradient.
+            exp_avg_sq (torch.Tensor): Exponential moving average of squared gradient.
+            eps (float): Small value to prevent division by zero.
         """
         return grad.pow(2).div_(exp_avg_sq.clip(min=eps)).mean().sqrt_().clip_(min=1.0).item()
 
@@ -382,12 +393,12 @@ def validate_nus(self, nus: Union[float, Tuple[float, float]]) -> None:
 
     @abstractmethod
     def init_group(self, group: ParamGroup, **kwargs) -> None:  # pragma: no cover
-        r"""Initialize the group of the optimizer and return is_complex."""
+        """Initialize the group of the optimizer and return is_complex."""
         return
 
     @staticmethod
     def view_as_real(param, *state_and_grads) -> tuple:
-        r"""View imaginary tensors as real tensors."""
+        """View imaginary tensors as real tensors."""
         if torch.is_complex(param):
             param = torch.view_as_real(param)
         state_and_grads = tuple(
@@ -399,7 +410,7 @@ def view_as_real(param, *state_and_grads) -> tuple:
 
     @staticmethod
     def maximize_gradient(grad: torch.Tensor, maximize: bool = False) -> None:
-        r"""Maximize the objective with respect to the params, instead of minimizing."""
+        """Maximize the objective with respect to the params, instead of minimizing."""
         if maximize:
             grad.neg_()
 
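Since the `debias` and `debias_beta` docstrings above state their formulas explicitly, they can be sanity-checked with a small standalone snippet; this is illustrative only and simply re-evaluates both expressions outside the library.

    import math

    def debias(beta: float, step: int) -> float:
        # Adam-style bias correction: 1 - beta**step.
        return 1.0 - math.pow(beta, step)

    def debias_beta(beta: float, step: int) -> float:
        # Simplified form of beta * (1 - beta**(step - 1)) / (1 - beta**step).
        beta_n = math.pow(beta, step)
        return (beta_n - beta) / (beta_n - 1.0)

    beta, step = 0.9, 10
    assert math.isclose(debias_beta(beta, step), beta * (1.0 - beta ** (step - 1)) / (1.0 - beta ** step))
    print(debias(beta, step), debias_beta(beta, step))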

pytorch_optimizer/base/scheduler.py

Lines changed: 12 additions & 10 deletions
@@ -7,16 +7,18 @@
 
 
 class BaseLinearWarmupScheduler(ABC):
-    r"""BaseLinearWarmupScheduler class.
-
-    The LR Scheduler class based on this class has linear warmup strategy.
-
-    :param optimizer: Optimizer. It will set learning rate to all trainable parameters in optimizer.
-    :param t_max: int. total steps to train.
-    :param max_lr: float. maximum lr.
-    :param min_lr: float. minimum lr.
-    :param init_lr: float. initial lr.
-    :param warmup_steps: int. steps to warm-up.
+    """BaseLinearWarmupScheduler class.
+
+    A learning rate scheduler class that implements a linear warmup strategy.
+
+    Args:
+        optimizer (Optimizer): The optimizer whose learning rate will be scheduled.
+            It will set the learning rate to all trainable parameters in the optimizer.
+        t_max (int): Total number of training steps (epochs or iterations).
+        max_lr (float): The maximum learning rate after warmup.
+        min_lr (float): The minimum learning rate to decay to (or start from if warmup).
+        init_lr (float): Initial learning rate at the start of warmup.
+        warmup_steps (int): Number of steps to warm up linearly from init_lr to max_lr.
     """
 
     def __init__(
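The warmup described in the Args block can be summarized by a small standalone helper; `linear_warmup_lr` below is hypothetical and only illustrates the schedule shape during the warmup phase, not the scheduler's actual implementation.

    def linear_warmup_lr(step: int, warmup_steps: int, init_lr: float, max_lr: float) -> float:
        # Linearly interpolate from init_lr to max_lr over warmup_steps;
        # after warmup, a concrete subclass decides how to decay toward min_lr.
        if step >= warmup_steps:
            return max_lr
        return init_lr + (max_lr - init_lr) * step / max(1, warmup_steps)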

pytorch_optimizer/base/type.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@
 
 Defaults = Dict[str, Any]
 ParamGroup = Dict[str, Any]
-State = Dict[str, Any]
+State = Dict
 Parameters = Optional[Union[Iterable[torch.Tensor], Iterable[ParamGroup]]]
 
 Closure = Optional[Callable[[], float]]
