@@ -215,7 +215,7 @@ def get_adanorm_gradient(
215215 return grad * exp_grad_norm / grad_norm if exp_grad_norm > grad_norm else grad
216216
217217 @staticmethod
218- def validate_range (x : float , name : str , low : float , high : float , range_type : str = '[)' ):
218+ def validate_range (x : float , name : str , low : float , high : float , range_type : str = '[)' ) -> None :
219219 if range_type == '[)' and not low <= x < high :
220220 raise ValueError (f'[-] { name } must be in the range [{ low } , { high } )' )
221221 if range_type == '[]' and not low <= x <= high :
@@ -226,40 +226,42 @@ def validate_range(x: float, name: str, low: float, high: float, range_type: str
226226 raise ValueError (f'[-] { name } must be in the range ({ low } , { high } )' )
227227
228228 @staticmethod
229- def validate_non_negative (x : Optional [float ], name : str ):
229+ def validate_non_negative (x : Optional [float ], name : str ) -> None :
230230 if x is not None and x < 0.0 :
231231 raise ValueError (f'[-] { name } must be non-negative' )
232232
233233 @staticmethod
234- def validate_positive (x : Union [float , int ], name : str ):
234+ def validate_positive (x : Union [float , int ], name : str ) -> None :
235235 if x <= 0 :
236236 raise ValueError (f'[-] { name } must be positive' )
237237
238238 @staticmethod
239- def validate_boundary (constant : float , boundary : float , bound_type : str = 'upper' ):
239+ def validate_boundary (constant : float , boundary : float , bound_type : str = 'upper' ) -> None :
240240 if bound_type == 'upper' and constant > boundary :
241241 raise ValueError (f'[-] constant { constant } must be in a range of (-inf, { boundary } ]' )
242242 if bound_type == 'lower' and constant < boundary :
243243 raise ValueError (f'[-] constant { constant } must be in a range of [{ boundary } , inf)' )
244244
245245 @staticmethod
246- def validate_step (step : int , step_type : str ):
246+ def validate_step (step : int , step_type : str ) -> None :
247247 if step < 1 :
248248 raise NegativeStepError (step , step_type = step_type )
249249
250250 @staticmethod
251- def validate_options (x : str , name : str , options : List [str ]):
251+ def validate_options (x : str , name : str , options : List [str ]) -> None :
252252 if x not in options :
253253 opts : str = ' or ' .join ([f'\' { option } \' ' for option in options ]).strip ()
254254 raise ValueError (f'[-] { name } { x } must be one of ({ opts } )' )
255255
256256 @staticmethod
257- def validate_learning_rate (learning_rate : Optional [float ]):
257+ def validate_learning_rate (learning_rate : Optional [float ]) -> None :
258258 if learning_rate is not None and learning_rate < 0.0 :
259259 raise NegativeLRError (learning_rate )
260260
261- def validate_betas (self , betas : BETAS ):
262- self .validate_range (betas [0 ], 'beta1' , 0.0 , 1.0 , range_type = '[]' )
261+ def validate_betas (self , betas : BETAS ) -> None :
262+ if betas [0 ] is not None :
263+ self .validate_range (betas [0 ], 'beta1' , 0.0 , 1.0 , range_type = '[]' )
264+
263265 self .validate_range (betas [1 ], 'beta2' , 0.0 , 1.0 , range_type = '[]' )
264266
265267 if len (betas ) < 3 :
@@ -268,7 +270,7 @@ def validate_betas(self, betas: BETAS):
268270 if betas [2 ] is not None :
269271 self .validate_range (betas [2 ], 'beta3' , 0.0 , 1.0 , range_type = '[]' )
270272
271- def validate_nus (self , nus : Union [float , Tuple [float , float ]]):
273+ def validate_nus (self , nus : Union [float , Tuple [float , float ]]) -> None :
272274 if isinstance (nus , float ):
273275 self .validate_range (nus , 'nu' , 0.0 , 1.0 , range_type = '[]' )
274276 else :
0 commit comments