Skip to content

Commit de82800

Browse files
committed
add TestSurrogateNamespace
1 parent 0b87f9a commit de82800

File tree

2 files changed

+82
-0
lines changed

2 files changed

+82
-0
lines changed

tests/full/custom_test_function/__init__.py

Whitespace-only changes.
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Author: Simon Blanke
2+
# Email: simon.blanke@yahoo.com
3+
# License: MIT License
4+
5+
"""Tests for CustomTestFunction surrogate namespace (requires scikit-learn)."""
6+
7+
import numpy as np
8+
import pytest
9+
10+
from surfaces.custom_test_function import CustomTestFunction
11+
12+
13+
@pytest.fixture
14+
def sphere_func():
15+
"""Simple sphere function for testing."""
16+
17+
def sphere(params):
18+
return sum(v**2 for v in params.values())
19+
20+
return CustomTestFunction(
21+
objective_fn=sphere,
22+
search_space={"x": (-5, 5), "y": (-5, 5)},
23+
)
24+
25+
26+
@pytest.fixture
27+
def sphere_func_with_data(sphere_func):
28+
"""Sphere function with evaluation data."""
29+
np.random.seed(42)
30+
for _ in range(50):
31+
x, y = np.random.uniform(-5, 5, 2)
32+
sphere_func({"x": x, "y": y})
33+
return sphere_func
34+
35+
36+
class TestSurrogateNamespace:
37+
"""Test surrogate namespace methods."""
38+
39+
def test_fit_random_forest(self, sphere_func_with_data):
40+
"""Test Random Forest surrogate fitting."""
41+
sphere_func_with_data.surrogate.fit(method="random_forest")
42+
43+
assert sphere_func_with_data.surrogate.is_fitted
44+
assert sphere_func_with_data.surrogate.method == "random_forest"
45+
46+
def test_predict(self, sphere_func_with_data):
47+
"""Test surrogate prediction."""
48+
sphere_func_with_data.surrogate.fit(method="random_forest")
49+
50+
# Predict at origin (should be near 0)
51+
pred = sphere_func_with_data.surrogate.predict({"x": 0, "y": 0})
52+
assert pred < 5 # Should be small
53+
54+
def test_predict_array(self, sphere_func_with_data):
55+
"""Test surrogate prediction with array input."""
56+
sphere_func_with_data.surrogate.fit(method="random_forest")
57+
58+
X = np.array([[0, 0], [1, 1], [2, 2]])
59+
preds = sphere_func_with_data.surrogate.predict(X)
60+
61+
assert len(preds) == 3
62+
63+
def test_suggest_next(self, sphere_func_with_data):
64+
"""Test next point suggestion."""
65+
sphere_func_with_data.surrogate.fit(method="random_forest")
66+
67+
suggestions = sphere_func_with_data.surrogate.suggest_next(n_suggestions=3)
68+
69+
assert len(suggestions) == 3
70+
assert all("x" in s and "y" in s for s in suggestions)
71+
72+
def test_score(self, sphere_func_with_data):
73+
"""Test surrogate R^2 score."""
74+
sphere_func_with_data.surrogate.fit(method="random_forest")
75+
76+
score = sphere_func_with_data.surrogate.score()
77+
assert 0 < score <= 1 # R^2 should be positive for good fit
78+
79+
def test_not_fitted_error(self, sphere_func_with_data):
80+
"""Test error when predicting without fitting."""
81+
with pytest.raises(RuntimeError, match="No surrogate model fitted"):
82+
sphere_func_with_data.surrogate.predict({"x": 0, "y": 0})

0 commit comments

Comments
 (0)