Skip to content

Commit 34dd608

Browse files
committed
comply with adaptive structure, fixed 1D test, added 4D test
1 parent 7b1e0e5 commit 34dd608

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

adaptive/learner/skopt_learner.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,24 +26,24 @@ class SKOptLearner(Optimizer, BaseLearner):
2626

2727
def __init__(self, function, **kwargs):
2828
self.function = function
29-
self.pending_points = list()
29+
self.pending_points = set()
3030
self.data = OrderedDict()
3131
super().__init__(**kwargs)
3232

3333
def tell(self, x, y, fit=True):
34-
if x in self.pending_points:
35-
self.pending_points.remove(x)
3634
if hasattr(x, '__iter__'):
35+
self.pending_points.discard(tuple(x))
3736
self.data[tuple(x)] = y
3837
super().tell(x, y, fit)
3938
else:
39+
self.pending_points.discard(x)
4040
self.data[x] = y
4141
super().tell([x], y, fit)
4242

4343
def tell_pending(self, x):
4444
# 'skopt.Optimizer' takes care of points we
4545
# have not got results for.
46-
self.pending_points.append(x)
46+
self.pending_points.add(tuple(x))
4747

4848
def remove_unfinished(self):
4949
pass

adaptive/tests/test_skopt_learner.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,23 @@ def g(x, noise_level=0.1):
2525
for _ in range(11):
2626
(x,), _ = learner.ask(1)
2727
learner.tell(x, learner.function(x))
28+
29+
30+
@pytest.mark.skipif(not with_scikit_optimize, reason="scikit-optimize is not installed")
31+
def test_skopt_learner_4D_runs():
32+
"""The SKOptLearner provides very few guarantees about its
33+
behaviour, so we only test the most basic usage
34+
In this case we test also for 2D domain
35+
"""
36+
37+
def g(x, noise_level=0.1):
38+
return np.sin(5 * (x[0] + x[1] + x[2] + x[3])) * (
39+
1 - np.tanh(x[0] ** 2 + x[1] ** 2 + x[2] ** 2 + x[3] ** 2)
40+
) + np.random.randn() * noise_level
41+
42+
learner = SKOptLearner(g, dimensions=[(-2.0, 2.0), (-2.0, 2.0),
43+
(-2.0, 2.0), (-2.0, 2.0)])
44+
45+
for _ in range(11):
46+
(x,), _ = learner.ask(1)
47+
learner.tell(x, learner.function(x))

0 commit comments

Comments
 (0)