Skip to content

Commit b88b803

Browse files
committed
Implement DataSaver.load_dataframe
1 parent 02362b9 commit b88b803

File tree

2 files changed

+34
-4
lines changed

2 files changed

+34
-4
lines changed

adaptive/learner/data_saver.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313
with_pandas = False
1414

1515

16+
def _to_key(x):
17+
return tuple(x.values) if x.values.size > 1 else x.item()
18+
19+
1620
class DataSaver:
1721
"""Save extra data associated with the values that need to be learned.
1822
@@ -77,14 +81,30 @@ def to_dataframe(
7781
raise ImportError("pandas is not installed.")
7882
df = self.learner.to_dataframe(**kwargs)
7983

80-
def to_key(x):
81-
return tuple(x.values) if x.values.size > 1 else x.item()
82-
8384
df[extra_data_name] = [
84-
self.extra_data[to_key(x)] for _, x in df[df.attrs["inputs"]].iterrows()
85+
self.extra_data[_to_key(x)] for _, x in df[df.attrs["inputs"]].iterrows()
8586
]
8687
return df
8788

89+
def load_dataframe(
90+
self, df: pandas.DataFrame, extra_data_name: str = "extra_data", **kwargs
91+
):
92+
"""Load the data from a `pandas.DataFrame` into the learner.
93+
94+
Parameters
95+
----------
96+
df : pandas.DataFrame
97+
DataFrame with the data to load.
98+
extra_data_name : str, optional
99+
The ``extra_data_name`` used in `to_dataframe`, by default "extra_data".
100+
**kwargs : dict
101+
Keyword arguments passed to each ``child_learner.load_dataframe(**kwargs)``.
102+
"""
103+
self.learner.load_dataframe(df, **kwargs)
104+
for _, x in df[df.attrs["inputs"] + [extra_data_name]].iterrows():
105+
key = _to_key(x[:-1])
106+
self.extra_data[key] = x[-1]
107+
88108
def _get_data(self):
89109
return self.learner._get_data(), self.extra_data
90110

adaptive/tests/test_learners.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -783,3 +783,13 @@ def test_to_dataframe(learner_type, f, learner_kwargs):
783783
assert len(df) == data_saver.nsamples
784784
else:
785785
assert len(df) == data_saver.npoints
786+
787+
# Test loading from a DataFrame into a new DataSaver
788+
learner2 = learner_type(learner.function, **learner_kwargs)
789+
data_saver2 = DataSaver(learner2, operator.itemgetter("result"))
790+
data_saver2.load_dataframe(df, **kw)
791+
assert data_saver2.extra_data.keys() == data_saver.extra_data.keys()
792+
assert all(
793+
data_saver2.extra_data[k] == data_saver.extra_data[k]
794+
for k in data_saver.extra_data.keys()
795+
)

0 commit comments

Comments
 (0)