Skip to content

Commit c1618d0

Browse files
committed
refactor: DynamicLossScaler
1 parent 7b5eb65 commit c1618d0

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

pytorch_optimizer/fp16.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@ def __init__(
3030
tolerance: float = 0.00,
3131
threshold: Optional[float] = None,
3232
):
33-
"""
34-
:param init_scale: Initial loss scale.
35-
:param scale_factor: Factor by which to increase or decrease loss scale.
33+
"""Dynamic Loss Scaler for fp16 training
34+
:param init_scale: Initial loss scale
35+
:param scale_factor: Factor by which to increase or decrease loss scale
3636
:param scale_window: If we do not experience overflow in scale_window iterations,
37-
loss scale will increase by scale_factor.
37+
loss scale will increase by scale_factor
3838
:param tolerance: Pct of iterations that have overflowed after which we must decrease the loss scale
3939
:param threshold: If not None, loss scale will decrease below this threshold
4040
"""
@@ -122,9 +122,9 @@ def build_fp32_params(cls, parameters, flatten: bool = True):
122122

123123
offset: int = 0
124124
for p in parameters:
125-
numel = p.data.numel()
126-
fp32_params[offset : offset + numel].copy_(p.data.view(-1))
127-
offset += numel
125+
p_num_el = p.data.numel()
126+
fp32_params[offset : offset + p_num_el].copy_(p.data.view(-1))
127+
offset += p_num_el
128128

129129
fp32_params = torch.nn.Parameter(fp32_params)
130130
fp32_params.grad = fp32_params.data.new(total_param_size)
@@ -139,15 +139,15 @@ def build_fp32_params(cls, parameters, flatten: bool = True):
139139
return fp32_params
140140

141141
def state_dict(self) -> Dict:
142-
"""Return the optimizer's state dict."""
142+
"""Return the optimizer state dict."""
143143
state_dict = self.optimizer.state_dict()
144144
if self.scaler is not None:
145145
state_dict['loss_scaler'] = self.scaler.loss_scale
146146
return state_dict
147147

148148
def load_state_dict(self, state_dict: Dict):
149149
"""Load an optimizer state dict.
150-
In general we should prefer the configuration of the existing optimizer instance
150+
In general, we should prefer the configuration of the existing optimizer instance
151151
(e.g., learning rate) over that found in the state_dict. This allows us to
152152
resume training from a checkpoint using a new set of optimizer args.
153153
"""

0 commit comments

Comments (0)