Skip to content

Commit 37d2e03

Browse files
committed
docs: docstring
1 parent b4be50b commit 37d2e03

File tree

10 files changed

+61
-23
lines changed

10 files changed

+61
-23
lines changed

pytorch_optimizer/base/optimizer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88

99
class BaseOptimizer(ABC):
10+
r"""Base optimizer class."""
11+
1012
@staticmethod
1113
def validate_learning_rate(learning_rate: float):
1214
if learning_rate < 0.0:

pytorch_optimizer/base/scheduler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66

77

88
class BaseLinearWarmupScheduler(ABC):
9-
r"""BaseLinearWarmupScheduler class. The LR Scheduler class based on this class has linear warmup strategy.
9+
r"""BaseLinearWarmupScheduler class.
10+
11+
The LR Scheduler class based on this class has linear warmup strategy.
1012
1113
:param optimizer: Optimizer. OPTIMIZER. It will set learning rate to all trainable parameters in optimizer.
1214
:param t_max: int. total steps to train.

pytorch_optimizer/lr_scheduler/chebyshev.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33

44
def chebyshev_steps(small_m: float, big_m: float, num_epochs: int) -> np.ndarray:
5-
r"""chebyshev_steps.
5+
r"""Chebyshev steps.
66
77
:param small_m: float. stands for 'm' notation.
88
:param big_m: float. stands for 'M' notation.
@@ -16,13 +16,15 @@ def chebyshev_steps(small_m: float, big_m: float, num_epochs: int) -> np.ndarray
1616

1717

1818
def chebyshev_perm(num_epochs: int) -> np.ndarray:
19+
r"""Chebyshev permutation."""
1920
perm = np.array([0])
2021
while len(perm) < num_epochs:
2122
perm = np.vstack([perm, 2 * len(perm) - 1 - perm]).T.flatten()
2223
return perm
2324

2425

2526
def get_chebyshev_schedule(num_epochs: int) -> np.ndarray:
27+
r"""Get Chebyshev schedules."""
2628
steps: np.ndarray = chebyshev_steps(0.1, 1, num_epochs - 2)
2729
perm: np.ndarray = chebyshev_perm(num_epochs - 2)
2830
return steps[perm]

pytorch_optimizer/lr_scheduler/linear_warmup.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,17 @@
66

77

88
class LinearScheduler(BaseLinearWarmupScheduler):
9+
r"""Linear LR Scheduler w/ linear warmup."""
10+
911
def _step(self) -> float:
1012
return self.max_lr + (self.min_lr - self.max_lr) * (self.step_t - self.warmup_steps) / (
1113
self.total_steps - self.warmup_steps
1214
)
1315

1416

1517
class CosineScheduler(BaseLinearWarmupScheduler):
18+
r"""Cosine LR Scheduler w/ linear warmup."""
19+
1620
def _step(self) -> float:
1721
phase: float = (self.step_t - self.warmup_steps) / (self.total_steps - self.warmup_steps) * math.pi
1822
return self.min_lr + (self.max_lr - self.min_lr) * (np.cos(phase) + 1.0) / 2.0

pytorch_optimizer/lr_scheduler/proportion.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33

44
class ProportionScheduler:
5-
r"""ProportionScheduler (Rho Scheduler of GSAM)
5+
r"""ProportionScheduler (Rho Scheduler of GSAM).
6+
67
This scheduler outputs a value that evolves proportional to lr_scheduler.
78
89
:param lr_scheduler: learning rate scheduler.

pytorch_optimizer/optimizer/fp16.py

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
class DynamicLossScaler:
1212
r"""Dynamically adjusts the loss scaling factor.
13+
1314
Dynamic loss scalers are important in mixed-precision training.
1415
They help us avoid underflows and overflows in low-precision gradients.
1516
@@ -50,8 +51,9 @@ def __init__(
5051

5152
def update_scale(self, overflow: bool):
5253
r"""Update the loss scale.
53-
If overflow exceeds our tolerance, we decrease the loss scale. If the number of
54-
iterations since the last overflow exceeds the scale window, we increase the loss scale.
54+
55+
If overflow exceeds our tolerance, we decrease the loss scale.
56+
If the number of iterations since the last overflow exceeds the scale window, we increase the loss scale.
5557
5658
:param overflow: bool. adjust scales to prevent overflow.
5759
"""
@@ -79,17 +81,31 @@ def update_scale(self, overflow: bool):
7981

8082
def decrease_loss_scale(self):
8183
r"""Decrease the loss scale by self.scale_factor.
82-
NOTE: the loss_scale will not go below self.threshold.
84+
85+
NOTE: the loss_scale will not go below `self.threshold`.
8386
"""
8487
self.loss_scale /= self.scale_factor
8588
if self.threshold is not None:
8689
self.loss_scale = max(self.loss_scale, self.threshold)
8790

8891

8992
class SafeFP16Optimizer(Optimizer):
90-
def __init__(self, optimizer: OPTIMIZER, aggregate_g_norms: bool = False):
93+
r"""Safe FP16 Optimizer.
94+
95+
:param optimizer: OPTIMIZER.
96+
:param aggregate_g_norms: bool. aggregate_g_norms.
97+
:param min_loss_scale: float. min_loss_scale.
98+
"""
99+
100+
def __init__(
101+
self,
102+
optimizer: OPTIMIZER,
103+
aggregate_g_norms: bool = False,
104+
min_loss_scale: float = 2 ** -5,
105+
): # fmt: skip
91106
self.optimizer = optimizer
92107
self.aggregate_g_norms = aggregate_g_norms
108+
self.min_loss_scale = min_loss_scale
93109

94110
self.fp16_params = self.get_parameters(optimizer)
95111
self.fp32_params = self.build_fp32_params(self.fp16_params, flatten=False)
@@ -104,7 +120,6 @@ def __init__(self, optimizer: OPTIMIZER, aggregate_g_norms: bool = False):
104120
optimizer.param_groups[0]['params'] = self.fp32_params
105121

106122
self.scaler: DynamicLossScaler = DynamicLossScaler(2.0 ** 15) # fmt: skip
107-
self.min_loss_scale: float = 2 ** -5 # fmt: skip
108123
self.needs_sync: bool = True
109124

110125
@classmethod
@@ -151,6 +166,7 @@ def state_dict(self) -> Dict:
151166

152167
def load_state_dict(self, state_dict: Dict):
153168
r"""Load an optimizer state dict.
169+
154170
In general, we should prefer the configuration of the existing optimizer instance
155171
(e.g., learning rate) over that found in the state_dict. This allows us to
156172
resume training from a checkpoint using a new set of optimizer args.
@@ -162,9 +178,13 @@ def load_state_dict(self, state_dict: Dict):
162178
self.optimizer.load_state_dict(state_dict)
163179

164180
def backward(self, loss, update_main_grads: bool = False):
165-
r"""Computes the sum of gradients of the given tensor w.r.t. graph leaves.
166-
Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this function
167-
additionally dynamically scales the loss to avoid gradient underflow.
181+
r"""Compute the sum of gradients of the given tensor w.r.t. graph leaves.
182+
183+
Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this function
184+
additionally dynamically scales the loss to avoid gradient underflow.
185+
186+
:param loss: float. loss.
187+
:param update_main_grads: bool. update main gradient.
168188
"""
169189
if self.scaler is not None:
170190
loss = loss * self.scaler.loss_scale
@@ -176,6 +196,7 @@ def backward(self, loss, update_main_grads: bool = False):
176196
self.update_main_grads()
177197

178198
def sync_fp16_grads_to_fp32(self, multiply_grads: float = 1.0):
199+
r"""Sync fp16 to fp32 gradients."""
179200
if self.needs_sync:
180201
if self.scaler is not None:
181202
# correct for dynamic loss scaler
@@ -195,7 +216,7 @@ def sync_fp16_grads_to_fp32(self, multiply_grads: float = 1.0):
195216
self.needs_sync = False
196217

197218
def multiply_grads(self, c: float):
198-
r"""Multiplies grads by a constant c."""
219+
r"""Multiply grads by a constant c."""
199220
if self.needs_sync:
200221
self.sync_fp16_grads_to_fp32(c)
201222
else:
@@ -206,7 +227,7 @@ def update_main_grads(self):
206227
self.sync_fp16_grads_to_fp32()
207228

208229
def clip_main_grads(self, max_norm: float):
209-
r"""Clips gradient norm and updates dynamic loss scaler."""
230+
r"""Clip gradient norm and updates dynamic loss scaler."""
210231
self.sync_fp16_grads_to_fp32()
211232

212233
grad_norm = clip_grad_norm(self.fp32_params, max_norm, sync=self.aggregate_g_norms)
@@ -221,8 +242,8 @@ def clip_main_grads(self, max_norm: float):
221242
if overflow:
222243
self.zero_grad()
223244
if self.scaler.loss_scale <= self.min_loss_scale:
224-
# Use FloatingPointError as an uncommon error that parent
225-
# functions can safely catch to stop training.
245+
# Use FloatingPointError as an uncommon error
246+
# that parent functions can safely catch to stop training.
226247
self.scaler.loss_scale = prev_scale
227248

228249
raise FloatingPointError(
@@ -235,7 +256,7 @@ def clip_main_grads(self, max_norm: float):
235256
return grad_norm
236257

237258
def step(self, closure: CLOSURE = None):
238-
r"""Performs a single optimization step."""
259+
r"""Perform a single optimization step."""
239260
self.sync_fp16_grads_to_fp32()
240261
self.optimizer.step(closure)
241262

@@ -246,17 +267,19 @@ def step(self, closure: CLOSURE = None):
246267
p.data.copy_(p32)
247268

248269
def zero_grad(self):
249-
r"""Clears the gradients of all optimized parameters."""
270+
r"""Clear the gradients of all optimized parameters."""
250271
for p in self.fp16_params:
251272
p.grad = None
252273
for p32 in self.fp32_params:
253274
p32.grad.zero_()
254275
self.needs_sync = False
255276

256277
def get_lr(self) -> float:
278+
r"""Get learning rate."""
257279
return self.optimizer.get_lr()
258280

259281
def set_lr(self, lr: float):
282+
r"""Set learning rate."""
260283
self.optimizer.set_lr(lr)
261284

262285
@property

pytorch_optimizer/optimizer/gsam.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -178,10 +178,11 @@ def maybe_no_sync(self):
178178

179179
@torch.no_grad()
180180
def set_closure(self, loss_fn: nn.Module, inputs: torch.Tensor, targets: torch.Tensor, **kwargs):
181-
r"""set closure
182-
create self.forward_backward_func, which is a function such that self.forward_backward_func() automatically
183-
performs forward and backward passes. This function does not take any arguments, and the inputs and
184-
targets data should be pre-set in the definition of partial-function.
181+
r"""Set closure.
182+
183+
Create `self.forward_backward_func`, which is a function such that `self.forward_backward_func()`
184+
automatically performs forward and backward passes. This function does not take any arguments,
185+
and the inputs and targets data should be pre-set in the definition of partial-function.
185186
186187
:param loss_fn: nn.Module. loss function.
187188
:param inputs: torch.Tensor. inputs.

pytorch_optimizer/optimizer/lamb.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99

1010

1111
class Lamb(Optimizer, BaseOptimizer):
12-
r"""Large Batch Optimization for Deep Learning. This Lamb implementation is based on the paper v3,
13-
which does not use de-biasing.
12+
r"""Large Batch Optimization for Deep Learning.
13+
14+
This Lamb implementation is based on the paper v3, which does not use de-biasing.
1415
1516
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
1617
:param lr: float. learning rate.

pytorch_optimizer/optimizer/ranger21.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
class Ranger21(Optimizer, BaseOptimizer):
1717
r"""Integrating the latest deep learning components into a single optimizer.
18+
1819
Here's the components
1920
* uses the AdamW optimizer as its core (or, optionally, MadGrad)
2021
* Adaptive gradient clipping

pytorch_optimizer/optimizer/shampoo.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
class Shampoo(Optimizer, BaseOptimizer):
1111
r"""Preconditioned Stochastic Tensor Optimization.
12+
1213
Reference : https://github.com/google-research/google-research/blob/master/scalable_shampoo/pytorch/shampoo.py.
1314
1415
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

0 commit comments

Comments
 (0)