Skip to content

Commit 8ae1122

Browse files
committed
Change test_optimizer to pytest discoverable function naming
1 parent 5fdc6af commit 8ae1122

File tree

1 file changed

+26
-14
lines changed

1 file changed

+26
-14
lines changed

tests/test_optimizer.py

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from manify.optimizers.radan import RiemannianAdan
88
from manify.manifolds import ProductManifold
99

10+
1011
def get_product_manifold_and_target(device_str: str):
1112
"""
1213
Construct a product manifold R^2 x R^2 and a target point.
@@ -17,25 +18,32 @@ def get_product_manifold_and_target(device_str: str):
1718
target_point_tensor = torch.tensor([1.0, 1.0, -1.0, -1.0], dtype=torch.float32)
1819
return product_manifold, target_point_tensor
1920

21+
2022
def objective_function(point, target_point, manifold):
    """Squared distance from *point* to *target_point* on *manifold*.

    Serves as the loss minimized in the optimizer-comparison test:
    ``manifold.dist(point, target_point) ** 2``.
    """
    distance = manifold.dist(point, target_point)
    return distance**2
2527

26-
def optimize_and_compare(manifold, target_point_tensor, optimizer_class, optimizer_params,
27-
initial_point_tensor, num_iterations=200, lr=0.1, tol=1e-5):
28+
29+
def optimize_and_compare(
30+
manifold,
31+
target_point_tensor,
32+
optimizer_class,
33+
optimizer_params,
34+
initial_point_tensor,
35+
num_iterations=200,
36+
lr=0.1,
37+
tol=1e-5,
38+
):
2839
"""
2940
Optimize the initial point using the specified Riemannian optimizer.
3041
"""
31-
point_to_optimize = ManifoldParameter(
32-
initial_point_tensor.clone().requires_grad_(True),
33-
manifold=manifold
34-
)
42+
point_to_optimize = ManifoldParameter(initial_point_tensor.clone().requires_grad_(True), manifold=manifold)
3543

3644
if optimizer_class.__name__ == "RiemannianAdan":
3745
current_optimizer_params = optimizer_params.copy()
38-
current_optimizer_params.setdefault('betas', (0.92, 0.98, 0.99))
46+
current_optimizer_params.setdefault("betas", (0.92, 0.98, 0.99))
3947
optimizer = optimizer_class([point_to_optimize], lr=lr, **current_optimizer_params)
4048
else:
4149
optimizer = optimizer_class([point_to_optimize], lr=lr, **optimizer_params)
@@ -52,7 +60,8 @@ def optimize_and_compare(manifold, target_point_tensor, optimizer_class, optimiz
5260

5361
return losses[-1], point_to_optimize.data.cpu().numpy(), losses
5462

55-
if __name__ == "__main__":
63+
64+
def test_radan_vs_adam():
5665
device_str = "cuda" if torch.cuda.is_available() else "cpu"
5766
product_manifold, target_point_tensor = get_product_manifold_and_target(device_str)
5867
target_point_tensor = target_point_tensor.to(device_str)
@@ -70,29 +79,32 @@ def optimize_and_compare(manifold, target_point_tensor, optimizer_class, optimiz
7079
initial_point_tensor.clone(),
7180
num_iterations=num_iterations,
7281
lr=learning_rate,
73-
tol=tolerance
82+
tol=tolerance,
7483
)
7584

7685
loss_radan, point_radan, _ = optimize_and_compare(
7786
product_manifold,
7887
target_point_tensor,
7988
RiemannianAdan,
80-
{'betas': [0.7, 0.999, 0.999]},
89+
{"betas": [0.7, 0.999, 0.999]},
8190
initial_point_tensor.clone(),
8291
num_iterations=num_iterations,
8392
lr=learning_rate,
84-
tol=tolerance
93+
tol=tolerance,
8594
)
8695

8796
print("\n--- Comparison Results ---")
8897
print(f"Target Point: {target_point_tensor.cpu().numpy()}")
8998
print(f"Initial Point: {initial_point_tensor.cpu().numpy()}")
9099
print(f"Adam Final Point: {point_adam} | Final Loss: {loss_adam:.6f}")
91100
print(f"Adan Final Point: {point_radan} | Final Loss: {loss_radan:.6f}")
92-
final_loss_radam = objective_function(torch.from_numpy(point_adam), target_point_tensor.cpu(), product_manifold).item()
93-
final_loss_radan = objective_function(torch.from_numpy(point_radan), target_point_tensor.cpu(), product_manifold).item()
101+
final_loss_radam = objective_function(
102+
torch.from_numpy(point_adam), target_point_tensor.cpu(), product_manifold
103+
).item()
104+
final_loss_radan = objective_function(
105+
torch.from_numpy(point_radan), target_point_tensor.cpu(), product_manifold
106+
).item()
94107

95108
assert final_loss_radam < 1e-3, "Adam did not converge close enough to the target"
96109
assert final_loss_radan < 1e-3, "Adan did not converge close enough to the target"
97110
print("\n✅ Optimization test passed: Both Adam and Adan reached the target with low loss.")
98-

0 commit comments

Comments
 (0)