1010
1111class DynamicLossScaler :
1212 r"""Dynamically adjusts the loss scaling factor.
13+
1314 Dynamic loss scalers are important in mixed-precision training.
1415 They help us avoid underflows and overflows in low-precision gradients.
1516
@@ -50,8 +51,9 @@ def __init__(
5051
5152 def update_scale (self , overflow : bool ):
5253 r"""Update the loss scale.
53- If overflow exceeds our tolerance, we decrease the loss scale. If the number of
54- iterations since the last overflow exceeds the scale window, we increase the loss scale.
54+
55+ If overflow exceeds our tolerance, we decrease the loss scale.
56+ If the number of iterations since the last overflow exceeds the scale window, we increase the loss scale.
5557
5658 :param overflow: bool. adjust scales to prevent overflow.
5759 """
@@ -79,17 +81,31 @@ def update_scale(self, overflow: bool):
7981
def decrease_loss_scale(self):
    r"""Shrink the loss scale by dividing it by `self.scale_factor`.

    When `self.threshold` is set, the result is clamped so that it never drops below that value.
    """
    self.loss_scale /= self.scale_factor

    floor = self.threshold
    if floor is None:
        # no lower bound configured; keep the divided value as-is.
        return

    self.loss_scale = max(self.loss_scale, floor)
8790
8891
8992class SafeFP16Optimizer (Optimizer ):
90- def __init__ (self , optimizer : OPTIMIZER , aggregate_g_norms : bool = False ):
93+ r"""Safe FP16 Optimizer.
94+
95+ :param optimizer: OPTIMIZER.
96+ :param aggregate_g_norms: bool. aggregate_g_norms.
97+ :param min_loss_scale: float. min_loss_scale.
98+ """
99+
100+ def __init__ (
101+ self ,
102+ optimizer : OPTIMIZER ,
103+ aggregate_g_norms : bool = False ,
104+ min_loss_scale : float = 2 ** - 5 ,
105+ ): # fmt: skip
91106 self .optimizer = optimizer
92107 self .aggregate_g_norms = aggregate_g_norms
108+ self .min_loss_scale = min_loss_scale
93109
94110 self .fp16_params = self .get_parameters (optimizer )
95111 self .fp32_params = self .build_fp32_params (self .fp16_params , flatten = False )
@@ -104,7 +120,6 @@ def __init__(self, optimizer: OPTIMIZER, aggregate_g_norms: bool = False):
104120 optimizer .param_groups [0 ]['params' ] = self .fp32_params
105121
106122 self .scaler : DynamicLossScaler = DynamicLossScaler (2.0 ** 15 ) # fmt: skip
107- self .min_loss_scale : float = 2 ** - 5 # fmt: skip
108123 self .needs_sync : bool = True
109124
110125 @classmethod
@@ -151,6 +166,7 @@ def state_dict(self) -> Dict:
151166
152167 def load_state_dict (self , state_dict : Dict ):
153168 r"""Load an optimizer state dict.
169+
154170 In general, we should prefer the configuration of the existing optimizer instance
155171 (e.g., learning rate) over that found in the state_dict. This allows us to
156172 resume training from a checkpoint using a new set of optimizer args.
@@ -162,9 +178,13 @@ def load_state_dict(self, state_dict: Dict):
162178 self .optimizer .load_state_dict (state_dict )
163179
164180 def backward (self , loss , update_main_grads : bool = False ):
165- r"""Computes the sum of gradients of the given tensor w.r.t. graph leaves.
166- Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this function
167- additionally dynamically scales the loss to avoid gradient underflow.
181+ r"""Compute the sum of gradients of the given tensor w.r.t. graph leaves.
182+
183+ Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this function
184+ additionally dynamically scales the loss to avoid gradient underflow.
185+
186+ :param loss: float. loss.
187+ :param update_main_grads: bool. update main gradient.
168188 """
169189 if self .scaler is not None :
170190 loss = loss * self .scaler .loss_scale
@@ -176,6 +196,7 @@ def backward(self, loss, update_main_grads: bool = False):
176196 self .update_main_grads ()
177197
178198 def sync_fp16_grads_to_fp32 (self , multiply_grads : float = 1.0 ):
199+ r"""Sync fp16 to fp32 gradients."""
179200 if self .needs_sync :
180201 if self .scaler is not None :
181202 # correct for dynamic loss scaler
@@ -195,7 +216,7 @@ def sync_fp16_grads_to_fp32(self, multiply_grads: float = 1.0):
195216 self .needs_sync = False
196217
197218 def multiply_grads (self , c : float ):
198- r"""Multiplies grads by a constant c."""
219+ r"""Multiply grads by a constant c."""
199220 if self .needs_sync :
200221 self .sync_fp16_grads_to_fp32 (c )
201222 else :
@@ -206,7 +227,7 @@ def update_main_grads(self):
206227 self .sync_fp16_grads_to_fp32 ()
207228
208229 def clip_main_grads (self , max_norm : float ):
209- r"""Clips gradient norm and updates dynamic loss scaler."""
230+ r"""Clip gradient norm and updates dynamic loss scaler."""
210231 self .sync_fp16_grads_to_fp32 ()
211232
212233 grad_norm = clip_grad_norm (self .fp32_params , max_norm , sync = self .aggregate_g_norms )
@@ -221,8 +242,8 @@ def clip_main_grads(self, max_norm: float):
221242 if overflow :
222243 self .zero_grad ()
223244 if self .scaler .loss_scale <= self .min_loss_scale :
224- # Use FloatingPointError as an uncommon error that parent
225- # functions can safely catch to stop training.
245+ # Use FloatingPointError as an uncommon error
246+ # that parent functions can safely catch to stop training.
226247 self .scaler .loss_scale = prev_scale
227248
228249 raise FloatingPointError (
@@ -235,7 +256,7 @@ def clip_main_grads(self, max_norm: float):
235256 return grad_norm
236257
237258 def step (self , closure : CLOSURE = None ):
238- r"""Performs a single optimization step."""
259+ r"""Perform a single optimization step."""
239260 self .sync_fp16_grads_to_fp32 ()
240261 self .optimizer .step (closure )
241262
@@ -246,17 +267,19 @@ def step(self, closure: CLOSURE = None):
246267 p .data .copy_ (p32 )
247268
def zero_grad(self):
    r"""Clear the gradients of all optimized parameters.

    Gradients of the fp16 parameters are released (set to ``None``), while gradients of the fp32
    parameters are zeroed in place. fp32 parameters that currently have no gradient are skipped.
    """
    for p in self.fp16_params:
        p.grad = None

    for p32 in self.fp32_params:
        # a missing gradient has nothing to zero; calling `zero_()` on None would raise AttributeError.
        if p32.grad is not None:
            p32.grad.zero_()

    # the fp16 gradients were just dropped, so there is nothing left to sync into the fp32 gradients.
    self.needs_sync = False
255276
def get_lr(self) -> float:
    r"""Return the learning rate reported by the wrapped optimizer."""
    wrapped = self.optimizer
    return wrapped.get_lr()
258280
def set_lr(self, lr: float):
    r"""Forward the new learning rate to the wrapped optimizer.

    :param lr: float. learning rate to set on the wrapped optimizer.
    """
    wrapped = self.optimizer
    wrapped.set_lr(lr)
261284
262285 @property
0 commit comments