Skip to content

Commit 1757433

Browse files
committed
add opt. package tests
1 parent 8d36eb0 commit 1757433

File tree

1 file changed

+75
-0
lines changed

1 file changed

+75
-0
lines changed

tests/test_optimization.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,3 +148,78 @@ def test_scipy_integration():
148148
assert result.success
149149
assert abs(result.fun) < 0.01
150150
assert np.allclose(result.x, [0, 0], atol=0.1)
151+
152+
153+
############################################################
154+
# Optuna integration test
155+
156+
157+
def test_optuna_integration():
158+
"""Test that Surfaces functions work with Optuna."""
159+
import optuna
160+
161+
optuna.logging.set_verbosity(optuna.logging.WARNING)
162+
163+
func = SphereFunction(n_dim=2)
164+
165+
def objective(trial):
166+
x0 = trial.suggest_float("x0", -5, 5)
167+
x1 = trial.suggest_float("x1", -5, 5)
168+
# Surfaces accepts dict input
169+
return func({"x0": x0, "x1": x1})
170+
171+
study = optuna.create_study(direction="minimize")
172+
study.optimize(objective, n_trials=50, show_progress_bar=False)
173+
174+
# Should find minimum near [0, 0]
175+
assert study.best_value < 0.5
176+
assert abs(study.best_params["x0"]) < 1.0
177+
assert abs(study.best_params["x1"]) < 1.0
178+
179+
180+
def test_optuna_integration_rastrigin():
181+
"""Test Optuna with a more complex function."""
182+
import optuna
183+
184+
optuna.logging.set_verbosity(optuna.logging.WARNING)
185+
186+
func = RastriginFunction(n_dim=2)
187+
188+
def objective(trial):
189+
params = {f"x{i}": trial.suggest_float(f"x{i}", -5, 5) for i in range(2)}
190+
return func(params)
191+
192+
study = optuna.create_study(direction="minimize")
193+
study.optimize(objective, n_trials=100, show_progress_bar=False)
194+
195+
# Rastrigin has global minimum at origin with value 0
196+
# With 100 trials, we should get reasonably close
197+
assert study.best_value < 5.0
198+
199+
200+
############################################################
201+
# scikit-optimize integration test
202+
203+
204+
def test_skopt_integration():
205+
"""Test that Surfaces functions work with scikit-optimize."""
206+
pytest.importorskip("skopt")
207+
from skopt import gp_minimize
208+
209+
func = SphereFunction(n_dim=2)
210+
211+
# skopt passes list of values
212+
def objective(x):
213+
return func(x)
214+
215+
result = gp_minimize(
216+
objective,
217+
[(-5.0, 5.0), (-5.0, 5.0)],
218+
n_calls=30,
219+
random_state=42,
220+
)
221+
222+
# Should find minimum near [0, 0]
223+
assert result.fun < 0.5
224+
assert abs(result.x[0]) < 1.0
225+
assert abs(result.x[1]) < 1.0

0 commit comments

Comments
 (0)