import math
-import warnings
from functools import partial
from pathlib import Path
-from typing import Callable, Dict, Tuple, Union
+from typing import Callable, Dict, List, Tuple, Union
+from warnings import filterwarnings

import numpy as np
import torch
from pytorch_optimizer.optimizer import OPTIMIZERS
from pytorch_optimizer.optimizer.alig import l2_projection

-warnings.filterwarnings('ignore', category=UserWarning)
+filterwarnings('ignore', category=UserWarning)

OPTIMIZERS_IGNORE = ('lomo', 'adalomo', 'demo', 'a2grad', 'alig')  # BUG: fix `alig`, invalid .__name__
OPTIMIZERS_MODEL_INPUT_NEEDED = ('lomo', 'adalomo', 'adammini')
OPTIMIZERS_GRAPH_NEEDED = ('adahessian', 'sophiah')
OPTIMIZERS_CLOSURE_NEEDED = ('alig', 'bsam')
-EVAL_PER_HYPYPERPARAM = 540
-OPTIMIZATION_STEPS = 300
-TESTING_OPTIMIZATION_STEPS = 650
-DIFFICULT_RASTRIGIN = False
-USE_AVERAGE_LOSS_PENALTY = True
-AVERAGE_LOSS_PENALTY_FACTOR = 1.0
-SEARCH_SEED = 42
-LOSS_MIN_TRESH = 0
-
-default_search_space = {'lr': hp.uniform('lr', 0, 2)}
-special_search_spaces = {
+EVAL_PER_HYPERPARAM: int = 540
+OPTIMIZATION_STEPS: int = 300
+TESTING_OPTIMIZATION_STEPS: int = 650
+DIFFICULT_RASTRIGIN: bool = False
+USE_AVERAGE_LOSS_PENALTY: bool = True
+AVERAGE_LOSS_PENALTY_FACTOR: float = 1.0
+SEARCH_SEED: int = 42
+LOSS_MIN_THRESHOLD: float = 0.0
+
+DEFAULT_SEARCH_SPACES = {'lr': hp.uniform('lr', 0, 2)}
+SPECIAL_SEARCH_SPACES = {
    'adafactor': {'lr': hp.uniform('lr', 0, 10)},
    'adams': {'lr': hp.uniform('lr', 0, 10)},
    'dadaptadagrad': {'lr': hp.uniform('lr', 0, 10)},
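For context (not part of the diff): these search spaces are hyperopt expressions, so `hp.uniform('lr', 0, 2)` places a uniform prior on the learning rate that TPE samples from. A minimal standalone sketch of drawing raw samples from the default space, to see the kind of configurations the tuner will evaluate:

```python
from hyperopt import hp
from hyperopt.pyll.stochastic import sample

# Default search space used for most optimizers: lr drawn uniformly from [0, 2].
space = {'lr': hp.uniform('lr', 0, 2)}

# Draw a few raw (prior) samples; TPE later biases sampling toward low-loss regions.
for _ in range(3):
    print(sample(space))  # e.g. {'lr': 0.87...}
```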
@@ -170,7 +170,7 @@ def execute_steps(
    optimizer_class: torch.optim.Optimizer,
    optimizer_config: Dict,
    num_iters: int = 500,
-) -> torch.Tensor:
+) -> Tuple[torch.Tensor, List[float]]:
    """
    Execute optimization steps for a given configuration.

@@ -201,7 +201,6 @@ def closure() -> float:

        return closure

-    # Initialize the model and optimizer
    model = Model(func, initial_state)
    parameters = list(model.parameters())
    optimizer_name: str = optimizer_class.__name__.lower()
@@ -218,30 +217,25 @@ def closure() -> float:
    elif optimizer_name == 'bsam':
        optimizer_config['num_data'] = 1

-    # Special initialization for memory-efficient optimizers
    if optimizer_name in OPTIMIZERS_MODEL_INPUT_NEEDED:
        optimizer = optimizer_class(model, **optimizer_config)
    else:
        optimizer = optimizer_class(parameters, **optimizer_config)

-    # Track optimization path
-    losses = []
    steps = torch.zeros((2, num_iters + 1), dtype=torch.float32)
    steps[:, 0] = model.x.detach()

+    losses = []
    for i in range(1, num_iters + 1):
        optimizer.zero_grad()
+
        loss = model()
        losses.append(loss.item())

-        # Special handling for second-order optimizers
-        create_graph = optimizer_name in OPTIMIZERS_GRAPH_NEEDED
-        loss.backward(create_graph=create_graph)
+        loss.backward(create_graph=optimizer_name in OPTIMIZERS_GRAPH_NEEDED)

-        # Gradient clipping for stability
        nn.utils.clip_grad_norm_(parameters, 1.0)

-        # Closure required for certain optimizers
        closure = create_closure(loss) if optimizer_name in OPTIMIZERS_CLOSURE_NEEDED else None
        optimizer.step(closure)

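`create_closure(loss)` is called above but its body falls outside the hunks shown here; the hunk headers only reveal a nested `def closure() -> float:`. A plausible minimal sketch of such a closure factory, matching the pattern that optimizers in `OPTIMIZERS_CLOSURE_NEEDED` expect from `optimizer.step(closure)` (the repository's actual implementation may differ):

```python
from typing import Callable

import torch


def create_closure(loss: torch.Tensor) -> Callable[[], torch.Tensor]:
    # Wrap an already-computed loss so that optimizers which call closure()
    # inside .step() (e.g. AliG- or BSAM-style optimizers) can re-read it.
    def closure() -> torch.Tensor:
        return loss

    return closure
```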
@@ -279,25 +273,19 @@ def objective(
    - A penalty for boundary violations.
    - An optional penalty for higher average loss during optimization (if enabled).
    """
-    # Execute optimization steps and get losses
-    steps, losses = execute_steps(  # Modified to unpack losses
-        criterion, initial_state, optimizer_class, params, num_iters
-    )
+    steps, losses = execute_steps(criterion, initial_state, optimizer_class, params, num_iters)

-    # Calculate boundary violations
    x_min_violation = torch.clamp(x_bounds[0] - steps[0], min=0).max()
    x_max_violation = torch.clamp(steps[0] - x_bounds[1], min=0).max()
    y_min_violation = torch.clamp(y_bounds[0] - steps[1], min=0).max()
    y_max_violation = torch.clamp(steps[1] - y_bounds[1], min=0).max()
    total_violation = x_min_violation + x_max_violation + y_min_violation + y_max_violation

-    # Calculate average loss penalty
-    avg_loss = sum(losses) / len(losses) if losses else 0.0
    penalty = 75 * total_violation.item()
    if USE_AVERAGE_LOSS_PENALTY:
+        avg_loss: float = sum(losses) / len(losses) if losses else 0.0
        penalty += avg_loss * AVERAGE_LOSS_PENALTY_FACTOR

-    # Calculate final distance to minimum
    final_position = steps[:, -1]
    final_distance = ((final_position[0] - minimum[0]) ** 2 + (final_position[1] - minimum[1]) ** 2).item()

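The hunk ends before `objective` returns; per its docstring, the quantities computed above are presumably recombined into a single scalar for hyperopt to minimize. A hypothetical recombination, purely for illustration (the function name and the final `return` are assumptions, not shown in this diff):

```python
def tuning_loss(final_distance: float, total_violation: float, avg_loss: float) -> float:
    # Hypothetical recombination of the terms computed in `objective`:
    # distance to the known minimum, a boundary-violation penalty, and an
    # optional average-loss penalty scaled by AVERAGE_LOSS_PENALTY_FACTOR (1.0).
    penalty = 75.0 * total_violation
    penalty += avg_loss * 1.0  # applied only when USE_AVERAGE_LOSS_PENALTY is True
    return final_distance + penalty
```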
@@ -309,7 +297,7 @@ def plot_function(
    optimization_steps: torch.Tensor,
    output_path: Path,
    optimizer_name: str,
-    params: dict,
+    params: Dict,
    x_range: Tuple[float, float],
    y_range: Tuple[float, float],
    minimum: Tuple[float, float],
@@ -335,34 +323,29 @@ def plot_function(
    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(1, 1, 1)

-    # Plot function contours and optimization path
    ax.contour(x_grid.numpy(), y_grid.numpy(), z.numpy(), 20, cmap='jet')
    ax.plot(optimization_steps[0], optimization_steps[1], color='r', marker='x', markersize=3)

-    # Mark global minimum and final position
    plt.plot(*minimum, 'gD', label='Global Minimum')
    plt.plot(optimization_steps[0, -1], optimization_steps[1, -1], 'bD', label='Final Position')

-    ax.set_title(
-        f'{func.__name__} func: {optimizer_name} with {TESTING_OPTIMIZATION_STEPS} iterations\n{
-            ", ".join(f"{k}={round(v, 4)}" for k, v in params.items())
-        }'
-    )
+    config: str = ', '.join(f'{k}={round(v, 4)}' for k, v in params.items())
+    ax.set_title(f'{func.__name__} func: {optimizer_name} with {TESTING_OPTIMIZATION_STEPS} iterations\n{config}')
    plt.legend()
    plt.savefig(str(output_path))
    plt.close()


def execute_experiments(
-    optimizers: list,
+    optimizers: List,
    func: Callable,
    initial_state: Tuple[float, float],
    output_dir: Path,
    experiment_name: str,
    x_range: Tuple[float, float],
    y_range: Tuple[float, float],
    minimum: Tuple[float, float],
-    seed: int = 42,
+    seed: int = SEARCH_SEED,
) -> None:
    """
    Run optimization experiments for multiple optimizers.
@@ -382,15 +365,14 @@ def execute_experiments(
        optimizer_name = optimizer_class.__name__
        output_path = output_dir / f'{experiment_name}_{optimizer_name}.png'
        if output_path.exists():
-            continue  # Skip already generated plots
+            continue

        print(  # noqa: T201
            f'({i}/{len(optimizers)}) Processing {optimizer_name}... (Params to tune: {", ".join(search_space.keys())})'  # noqa: E501
        )

-        # Select hyperparameter search space
-        num_hyperparams = len(search_space)
-        max_evals = EVAL_PER_HYPYPERPARAM * num_hyperparams  # Scale evaluations based on hyperparameter count
+        num_hyperparams: int = len(search_space)
+        max_evals: int = EVAL_PER_HYPERPARAM * num_hyperparams

        objective_fn = partial(
            objective,
@@ -402,43 +384,38 @@ def execute_experiments(
            y_bounds=y_range,
            num_iters=OPTIMIZATION_STEPS,
        )
+
        try:
            best_params = fmin(
                fn=objective_fn,
                space=search_space,
                algo=tpe.suggest,
                max_evals=max_evals,
-                loss_threshold=LOSS_MIN_TRESH,
+                loss_threshold=LOSS_MIN_THRESHOLD,
                rstate=np.random.default_rng(seed),
            )
        except AllTrialsFailed:
            print(f'⚠️ {optimizer_name} failed to optimize {func.__name__}')  # noqa: T201
            continue

-        # Run final optimization with best parameters
-        steps, _ = execute_steps(  # Modified to ignore losses
-            func, initial_state, optimizer_class, best_params, TESTING_OPTIMIZATION_STEPS
-        )
+        steps, _ = execute_steps(func, initial_state, optimizer_class, best_params, TESTING_OPTIMIZATION_STEPS)

-        # Generate and save visualization
        plot_function(func, steps, output_path, optimizer_name, best_params, x_range, y_range, minimum)


def main():
-    """Main execution routine for optimization experiments."""
    np.random.seed(SEARCH_SEED)
    torch.manual_seed(SEARCH_SEED)
+
    output_dir = Path('.') / 'docs' / 'visualizations'
    output_dir.mkdir(parents=True, exist_ok=True)

-    # Prepare the list of optimizers and their search spaces
    optimizers = [
-        (optimizer, special_search_spaces.get(optimizer_name, default_search_space))
+        (optimizer, SPECIAL_SEARCH_SPACES.get(optimizer_name, DEFAULT_SEARCH_SPACES))
        for optimizer_name, optimizer in OPTIMIZERS.items()
        if optimizer_name not in OPTIMIZERS_IGNORE
    ]

-    # Run experiments for the Rastrigin function
    print('Executing Rastrigin experiments...')  # noqa: T201
    execute_experiments(
        optimizers,
@@ -452,7 +429,6 @@ def main():
        seed=SEARCH_SEED,
    )

-    # Run experiments for the Rosenbrock function
    print('Executing Rosenbrock experiments...')  # noqa: T201
    execute_experiments(
        optimizers,