 from typing import Dict, List, Optional, Union
 
 import torch
+from torch import nn
 from torch.optim import Optimizer
 
-from pytorch_optimizer.types import CLOSURE
+from pytorch_optimizer.types import CLOSURE, PARAMETERS
 from pytorch_optimizer.utils import clip_grad_norm, has_overflow
 
 __AUTHOR__ = 'Facebook'
@@ -114,26 +115,29 @@ def get_parameters(cls, optimizer: Optimizer):
         return params
 
     @classmethod
-    def build_fp32_params(cls, parameters, flatten: bool = True) -> Union[torch.Tensor, List[torch.Tensor]]:
+    def build_fp32_params(
+        cls, parameters: PARAMETERS, flatten: bool = True
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
         # create FP32 copy of parameters and grads
         if flatten:
-            total_param_size = sum(p.data.numel() for p in parameters)
+            total_param_size: int = sum(p.numel() for p in parameters)
             fp32_params = torch.zeros(total_param_size, dtype=torch.float, device=parameters[0].device)
 
             offset: int = 0
             for p in parameters:
-                p_num_el = p.data.numel()
-                fp32_params[offset : offset + p_num_el].copy_(p.data.view(-1))
+                p_num_el = p.numel()
+                fp32_params[offset : offset + p_num_el].copy_(p.view(-1))
                 offset += p_num_el
 
-            fp32_params = torch.nn.Parameter(fp32_params)
-            fp32_params.grad = fp32_params.data.new(total_param_size)
+            fp32_params = nn.Parameter(fp32_params)
+            fp32_params.grad = fp32_params.new(total_param_size)
+
             return fp32_params
 
         fp32_params = []
         for p in parameters:
-            p32 = torch.nn.Parameter(p.data.float())
-            p32.grad = torch.zeros_like(p32.data)
+            p32 = nn.Parameter(p.float())
+            p32.grad = torch.zeros_like(p32)
             fp32_params.append(p32)
 
         return fp32_params
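
As a side note on what the flattened branch above does: it packs every parameter into one contiguous FP32 master tensor and pre-allocates a matching flat gradient buffer. A minimal standalone sketch of that idea (not part of this commit; the two example tensors are made up):

import torch
from torch import nn

# two made-up FP16 parameters standing in for a model's weights
fp16_params = [
    nn.Parameter(torch.ones(3, dtype=torch.half)),
    nn.Parameter(torch.zeros(2, dtype=torch.half)),
]

total_param_size = sum(p.numel() for p in fp16_params)  # 5 elements in total
flat = torch.zeros(total_param_size, dtype=torch.float)

offset = 0
for p in fp16_params:
    n = p.numel()
    # detach() keeps the copy out of autograd; copy_() upcasts FP16 -> FP32
    flat[offset : offset + n].copy_(p.detach().view(-1))
    offset += n

master = nn.Parameter(flat)                  # single flat FP32 master parameter
master.grad = master.new(total_param_size)   # pre-allocated flat gradient buffer

The non-flatten branch instead keeps one FP32 Parameter per FP16 parameter, which is simpler but allocates many small tensors.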
@@ -181,25 +185,25 @@ def sync_fp16_grads_to_fp32(self, multiply_grads: float = 1.0):
                     continue
 
                 if p.grad is not None:
-                    p32.grad.data.copy_(p.grad.data)
-                    p32.grad.data.mul_(multiply_grads)
+                    p32.grad.copy_(p.grad)
+                    p32.grad.mul_(multiply_grads)
                 else:
-                    p32.grad = torch.zeros_like(p.data, dtype=torch.float)
+                    p32.grad = torch.zeros_like(p, dtype=torch.float)
 
             self.needs_sync = False
 
-    def multiply_grads(self, c):
+    def multiply_grads(self, c: float):
         """Multiplies grads by a constant c."""
         if self.needs_sync:
             self.sync_fp16_grads_to_fp32(c)
         else:
             for p32 in self.fp32_params:
-                p32.grad.data.mul_(c)
+                p32.grad.mul_(c)
 
     def update_main_grads(self):
         self.sync_fp16_grads_to_fp32()
 
-    def clip_main_grads(self, max_norm):
+    def clip_main_grads(self, max_norm: float):
         """Clips gradient norm and updates dynamic loss scaler."""
         self.sync_fp16_grads_to_fp32()
 
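
For context on the sync/multiply pair above: gradients computed in FP16 are copied into the FP32 master copies and rescaled there (with a dynamic loss scaler, the multiplier folds in 1 / loss_scale). A rough standalone sketch of that step, with made-up tensors and a made-up scale value:

import torch
from torch import nn

p = nn.Parameter(torch.ones(4, dtype=torch.half))   # FP16 model parameter
p.grad = torch.full((4,), 2.0, dtype=torch.half)    # pretend backward() already ran
p32 = nn.Parameter(p.detach().float())              # FP32 master copy
p32.grad = torch.zeros_like(p32)

loss_scale = 128.0                                  # made-up dynamic loss scale
multiplier = 1.0 / loss_scale                       # undo the scaling applied to the loss

p32.grad.copy_(p.grad)     # FP16 -> FP32 upcast
p32.grad.mul_(multiplier)  # rescale in FP32, where small values do not underflow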
@@ -208,8 +212,10 @@ def clip_main_grads(self, max_norm):
         # detect overflow and adjust loss scale
         if self.scaler is not None:
             overflow: bool = has_overflow(grad_norm)
-            prev_scale = self.scaler.loss_scale
+            prev_scale: float = self.scaler.loss_scale
+
             self.scaler.update_scale(overflow)
+
             if overflow:
                 self.zero_grad()
                 if self.scaler.loss_scale <= self.min_loss_scale:
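
The scaler interaction above (check for inf/NaN, remember the previous scale, then update_scale) follows the usual dynamic loss-scaling recipe. A simplified sketch of the backoff/growth rule such a scaler typically implements; this is an assumption for illustration, not this library's actual DynamicLossScaler:

class DynamicLossScalerSketch:
    """Halve the scale on overflow, grow it after a stretch of stable steps."""

    def __init__(self, init_scale: float = 2.0 ** 15, scale_factor: float = 2.0, scale_window: int = 2000):
        self.loss_scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.good_steps = 0

    def update_scale(self, overflow: bool):
        if overflow:
            # inf/NaN gradients: back off and restart the stability window
            self.loss_scale /= self.scale_factor
            self.good_steps = 0
        else:
            self.good_steps += 1
            if self.good_steps % self.scale_window == 0:
                # long run without overflow: safe to scale more aggressively again
                self.loss_scale *= self.scale_factor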
@@ -235,7 +241,7 @@ def step(self, closure: CLOSURE = None):
         for p, p32 in zip(self.fp16_params, self.fp32_params):
             if not p.requires_grad:
                 continue
-            p.data.copy_(p32.data)
+            p.data.copy_(p32)
 
     def zero_grad(self):
         """Clears the gradients of all optimized parameters."""
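
The step() change above copies the updated FP32 master weights back into the FP16 parameters. The round trip looks roughly like this (standalone sketch with made-up tensors, not the wrapper's public API):

import torch
from torch import nn

p = nn.Parameter(torch.zeros(4, dtype=torch.half))  # FP16 parameter the model computes with
p32 = nn.Parameter(torch.full((4,), 0.1))           # FP32 master copy after the inner optimizer step

p.data.copy_(p32)   # downcast FP32 -> FP16 in place; the model now sees the updated weights
print(p.dtype)      # still torch.float16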