77from manify .optimizers .radan import RiemannianAdan
88from manify .manifolds import ProductManifold
99
10+
1011def get_product_manifold_and_target (device_str : str ):
1112 """
1213 Construct a product manifold R^2 x R^2 and a target point.
@@ -17,25 +18,32 @@ def get_product_manifold_and_target(device_str: str):
1718 target_point_tensor = torch .tensor ([1.0 , 1.0 , - 1.0 , - 1.0 ], dtype = torch .float32 )
1819 return product_manifold , target_point_tensor
1920
21+
2022def objective_function (point , target_point , manifold ):
2123 """
2224 Objective function: squared distance to the target point.
2325 """
2426 return manifold .dist (point , target_point ) ** 2
2527
26- def optimize_and_compare (manifold , target_point_tensor , optimizer_class , optimizer_params ,
27- initial_point_tensor , num_iterations = 200 , lr = 0.1 , tol = 1e-5 ):
28+
29+ def optimize_and_compare (
30+ manifold ,
31+ target_point_tensor ,
32+ optimizer_class ,
33+ optimizer_params ,
34+ initial_point_tensor ,
35+ num_iterations = 200 ,
36+ lr = 0.1 ,
37+ tol = 1e-5 ,
38+ ):
2839 """
2940 Optimize the initial point using the specified Riemannian optimizer.
3041 """
31- point_to_optimize = ManifoldParameter (
32- initial_point_tensor .clone ().requires_grad_ (True ),
33- manifold = manifold
34- )
42+ point_to_optimize = ManifoldParameter (initial_point_tensor .clone ().requires_grad_ (True ), manifold = manifold )
3543
3644 if optimizer_class .__name__ == "RiemannianAdan" :
3745 current_optimizer_params = optimizer_params .copy ()
38- current_optimizer_params .setdefault (' betas' , (0.92 , 0.98 , 0.99 ))
46+ current_optimizer_params .setdefault (" betas" , (0.92 , 0.98 , 0.99 ))
3947 optimizer = optimizer_class ([point_to_optimize ], lr = lr , ** current_optimizer_params )
4048 else :
4149 optimizer = optimizer_class ([point_to_optimize ], lr = lr , ** optimizer_params )
@@ -52,7 +60,8 @@ def optimize_and_compare(manifold, target_point_tensor, optimizer_class, optimiz
5260
5361 return losses [- 1 ], point_to_optimize .data .cpu ().numpy (), losses
5462
55- if __name__ == "__main__" :
63+
64+ def test_radan_vs_adam ():
5665 device_str = "cuda" if torch .cuda .is_available () else "cpu"
5766 product_manifold , target_point_tensor = get_product_manifold_and_target (device_str )
5867 target_point_tensor = target_point_tensor .to (device_str )
@@ -70,29 +79,32 @@ def optimize_and_compare(manifold, target_point_tensor, optimizer_class, optimiz
7079 initial_point_tensor .clone (),
7180 num_iterations = num_iterations ,
7281 lr = learning_rate ,
73- tol = tolerance
82+ tol = tolerance ,
7483 )
7584
7685 loss_radan , point_radan , _ = optimize_and_compare (
7786 product_manifold ,
7887 target_point_tensor ,
7988 RiemannianAdan ,
80- {' betas' : [0.7 , 0.999 , 0.999 ]},
89+ {" betas" : [0.7 , 0.999 , 0.999 ]},
8190 initial_point_tensor .clone (),
8291 num_iterations = num_iterations ,
8392 lr = learning_rate ,
84- tol = tolerance
93+ tol = tolerance ,
8594 )
8695
8796 print ("\n --- Comparison Results ---" )
8897 print (f"Target Point: { target_point_tensor .cpu ().numpy ()} " )
8998 print (f"Initial Point: { initial_point_tensor .cpu ().numpy ()} " )
9099 print (f"Adam Final Point: { point_adam } | Final Loss: { loss_adam :.6f} " )
91100 print (f"Adan Final Point: { point_radan } | Final Loss: { loss_radan :.6f} " )
92- final_loss_radam = objective_function (torch .from_numpy (point_adam ), target_point_tensor .cpu (), product_manifold ).item ()
93- final_loss_radan = objective_function (torch .from_numpy (point_radan ), target_point_tensor .cpu (), product_manifold ).item ()
101+ final_loss_radam = objective_function (
102+ torch .from_numpy (point_adam ), target_point_tensor .cpu (), product_manifold
103+ ).item ()
104+ final_loss_radan = objective_function (
105+ torch .from_numpy (point_radan ), target_point_tensor .cpu (), product_manifold
106+ ).item ()
94107
95108 assert final_loss_radam < 1e-3 , "Adam did not converge close enough to the target"
96109 assert final_loss_radan < 1e-3 , "Adan did not converge close enough to the target"
97110 print ("\n ✅ Optimization test passed: Both Adam and Adan reached the target with low loss." )
98-
0 commit comments