Skip to content

Commit 7b1e0e5

Browse files
committed
added compatibility with n-dimensional function domain
1 parent c084704 commit 7b1e0e5

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

adaptive/learner/skopt_learner.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import numpy as np
44
from skopt import Optimizer
5+
from collections import OrderedDict
56

67
from adaptive.learner.base_learner import BaseLearner
78
from adaptive.notebook_integration import ensure_holoviews
@@ -25,19 +26,24 @@ class SKOptLearner(Optimizer, BaseLearner):
2526

2627
def __init__(self, function, **kwargs):
2728
self.function = function
28-
self.pending_points = set()
29-
self.data = {}
29+
self.pending_points = list()
30+
self.data = OrderedDict()
3031
super().__init__(**kwargs)
3132

3233
def tell(self, x, y, fit=True):
33-
self.pending_points.discard(x)
34-
self.data[x] = y
35-
super().tell([x], y, fit)
34+
if x in self.pending_points:
35+
self.pending_points.remove(x)
36+
if hasattr(x, '__iter__'):
37+
self.data[tuple(x)] = y
38+
super().tell(x, y, fit)
39+
else:
40+
self.data[x] = y
41+
super().tell([x], y, fit)
3642

3743
def tell_pending(self, x):
3844
# 'skopt.Optimizer' takes care of points we
3945
# have not got results for.
40-
self.pending_points.add(x)
46+
self.pending_points.append(x)
4147

4248
def remove_unfinished(self):
4349
pass

0 commit comments

Comments
 (0)