@@ -30,11 +30,11 @@ def __init__(
         tolerance: float = 0.00,
         threshold: Optional[float] = None,
     ):
-        """
-        :param init_scale: Initial loss scale.
-        :param scale_factor: Factor by which to increase or decrease loss scale.
+        """Dynamic Loss Scaler for fp16 training
+        :param init_scale: Initial loss scale
+        :param scale_factor: Factor by which to increase or decrease loss scale
         :param scale_window: If we do not experience overflow in scale_window iterations,
-            loss scale will increase by scale_factor.
+            loss scale will increase by scale_factor
         :param tolerance: Pct of iterations that have overflowed after which we must decrease the loss scale
         :param threshold: If not None, loss scale will not decrease below this threshold
         """
@@ -122,9 +122,9 @@ def build_fp32_params(cls, parameters, flatten: bool = True):

         offset: int = 0
         for p in parameters:
-            numel = p.data.numel()
-            fp32_params[offset : offset + numel].copy_(p.data.view(-1))
-            offset += numel
+            p_num_el = p.data.numel()
+            fp32_params[offset : offset + p_num_el].copy_(p.data.view(-1))
+            offset += p_num_el

         fp32_params = torch.nn.Parameter(fp32_params)
         fp32_params.grad = fp32_params.data.new(total_param_size)
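The renamed loop above packs every fp16 parameter end-to-end into a single flat fp32 master buffer. A standalone sketch of the same pattern, with invented tensor shapes (copy_ performs the half-to-float cast):

import torch

params = [torch.randn(2, 3).half(), torch.randn(4).half()]
total_param_size = sum(p.data.numel() for p in params)
fp32_params = torch.empty(total_param_size, dtype=torch.float32)

offset = 0
for p in params:
    p_num_el = p.data.numel()
    # copy the flattened fp16 data into its slice of the fp32 buffer
    fp32_params[offset : offset + p_num_el].copy_(p.data.view(-1))
    offset += p_num_el

assert offset == total_param_size

Flattening like this lets the optimizer step over one contiguous master tensor instead of many small ones.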
@@ -139,15 +139,15 @@ def build_fp32_params(cls, parameters, flatten: bool = True):
         return fp32_params

     def state_dict(self) -> Dict:
-        """Return the optimizer's state dict."""
+        """Return the optimizer state dict."""
         state_dict = self.optimizer.state_dict()
         if self.scaler is not None:
             state_dict['loss_scaler'] = self.scaler.loss_scale
         return state_dict

     def load_state_dict(self, state_dict: Dict):
         """Load an optimizer state dict.
-        In general we should prefer the configuration of the existing optimizer instance
+        In general, we should prefer the configuration of the existing optimizer instance
         (e.g., learning rate) over that found in the state_dict. This allows us to
         resume training from a checkpoint using a new set of optimizer args.
         """