Skip to content

Commit de380d5

Browse files
committed
Add SequenceLearner.to_dataframe
1 parent 178a497 commit de380d5

File tree

2 files changed

+32
-2
lines changed

2 files changed

+32
-2
lines changed

adaptive/learner/sequence_learner.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,15 @@
44
from sortedcontainers import SortedDict, SortedSet
55

66
from adaptive.learner.base_learner import BaseLearner
7+
from adaptive.utils import assign_defaults
8+
9+
try:
10+
import pandas
11+
12+
with_pandas = True
13+
14+
except ModuleNotFoundError:
15+
with_pandas = False
716

817

918
class _IgnoreFirstArgument:
@@ -120,6 +129,25 @@ def result(self):
120129
def npoints(self):
121130
return len(self.data)
122131

132+
def to_dataframe(
133+
self,
134+
with_default_function_args: bool = True,
135+
function_prefix: str = "function.",
136+
index_name: str = "i",
137+
x_name: str = "x",
138+
y_name: str = "y",
139+
) -> pandas.DataFrame:
140+
if not with_pandas:
141+
raise ImportError("pandas is not installed.")
142+
indices, ys = zip(*self.data.items()) if self.data else ([], [])
143+
sequence = [self.sequence[i] for i in indices]
144+
df = pandas.DataFrame(indices, columns=[index_name])
145+
df[x_name] = sequence
146+
df[y_name] = ys
147+
if with_default_function_args:
148+
assign_defaults(self.function, df, function_prefix)
149+
return df
150+
123151
def _get_data(self):
124152
return self.data
125153

adaptive/tests/test_learners.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -703,11 +703,11 @@ def test_learner_subdomain(learner_type, f, learner_kwargs):
703703
LearnerND,
704704
AverageLearner,
705705
AverageLearner1D,
706-
# SequenceLearner, # TODO: implement this
706+
SequenceLearner,
707707
)
708708
def test_to_dataframe(learner_type, f, learner_kwargs):
709709
if learner_type is LearnerND:
710-
kw = {"point_names": list("xyz")[: len(learner_kwargs["bounds"])]}
710+
kw = {"point_names": tuple("xyz")[: len(learner_kwargs["bounds"])]}
711711
else:
712712
kw = {}
713713

@@ -734,6 +734,8 @@ def test_to_dataframe(learner_type, f, learner_kwargs):
734734
learner2.tell_many(df["seed"].values, df["y"])
735735
elif learner_type is AverageLearner1D:
736736
learner2.tell_many(df[["seed", "x"]].values, df["y"])
737+
elif learner_type is SequenceLearner:
738+
learner2.tell_many(df[["i", "x"]].values, df["y"])
737739
else:
738740
raise NotImplementedError()
739741

0 commit comments

Comments
 (0)