Skip to content

Commit 02362b9

Browse files
committed
Implement DataSaver.to_dataframe
1 parent 222b84f commit 02362b9

File tree

8 files changed

+82
-0
lines changed

8 files changed

+82
-0
lines changed

adaptive/learner/average_learner.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ def to_dataframe(
116116
if not with_pandas:
117117
raise ImportError("pandas is not installed.")
118118
df = pandas.DataFrame(sorted(self.data.items()), columns=[seed_name, y_name])
119+
df.attrs["inputs"] = [seed_name]
120+
df.attrs["output"] = y_name
119121
if with_default_function_args:
120122
assign_defaults(self.function, df, function_prefix)
121123
return df

adaptive/learner/average_learner1D.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,8 @@ def to_dataframe(
196196
]
197197
columns = [seed_name, x_name, y_name]
198198
df = pandas.DataFrame(data, columns=columns)
199+
df.attrs["inputs"] = [seed_name, x_name]
200+
df.attrs["output"] = y_name
199201
if with_default_function_args:
200202
assign_defaults(self.function, df, function_prefix)
201203
return df

adaptive/learner/data_saver.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,14 @@
44
from adaptive.learner.base_learner import BaseLearner
55
from adaptive.utils import copy_docstring_from
66

7+
try:
8+
import pandas
9+
10+
with_pandas = True
11+
12+
except ModuleNotFoundError:
13+
with_pandas = False
14+
715

816
class DataSaver:
917
"""Save extra data associated with the values that need to be learned.
@@ -44,6 +52,39 @@ def tell(self, x, result):
4452
def tell_pending(self, x):
4553
self.learner.tell_pending(x)
4654

55+
def to_dataframe(
56+
self, extra_data_name: str = "extra_data", **kwargs
57+
) -> pandas.DataFrame:
58+
"""Return the data as a concatenated `pandas.DataFrame` from child learners.
59+
60+
Parameters
61+
----------
62+
extra_data_name : str, optional
63+
The name of the column containing the extra data, by default "extra_data".
64+
**kwargs : dict
65+
Keyword arguments passed to the ``child_learner.to_dataframe(**kwargs)``.
66+
67+
Returns
68+
-------
69+
pandas.DataFrame
70+
71+
Raises
72+
------
73+
ImportError
74+
If `pandas` is not installed.
75+
"""
76+
if not with_pandas:
77+
raise ImportError("pandas is not installed.")
78+
df = self.learner.to_dataframe(**kwargs)
79+
80+
def to_key(x):
81+
return tuple(x.values) if x.values.size > 1 else x.item()
82+
83+
df[extra_data_name] = [
84+
self.extra_data[to_key(x)] for _, x in df[df.attrs["inputs"]].iterrows()
85+
]
86+
return df
87+
4788
def _get_data(self):
4889
return self.learner._get_data(), self.extra_data
4990

adaptive/learner/learner1D.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,8 @@ def to_dataframe(
369369
xs, ys = zip(*sorted(self.data.items())) if self.data else ([], [])
370370
df = pandas.DataFrame(xs, columns=[x_name])
371371
df[y_name] = ys
372+
df.attrs["inputs"] = [x_name]
373+
df.attrs["output"] = y_name
372374
if with_default_function_args:
373375
assign_defaults(self.function, df, function_prefix)
374376
return df

adaptive/learner/learner2D.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,7 @@ def to_dataframe(
439439
raise ImportError("pandas is not installed.")
440440
data = sorted((x, y, z) for (x, y), z in self.data.items())
441441
df = pandas.DataFrame(data, columns=[x_name, y_name, z_name])
442+
df.attrs["inputs"] = [x_name, y_name]
442443
if with_default_function_args:
443444
assign_defaults(self.function, df, function_prefix)
444445
return df

adaptive/learner/learnerND.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,7 @@ def to_dataframe(
444444
)
445445
data = list((*x, y) for x, y in self.data.items())
446446
df = pandas.DataFrame(data, columns=[*point_names, value_name])
447+
df.attrs["inputs"] = list(point_names)
447448
if with_default_function_args:
448449
assign_defaults(self.function, df, function_prefix)
449450
return df

adaptive/learner/sequence_learner.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,8 @@ def to_dataframe(
172172
df = pandas.DataFrame(indices, columns=[index_name])
173173
df[x_name] = sequence
174174
df[y_name] = ys
175+
df.attrs["inputs"] = [index_name]
176+
df.attrs["output"] = y_name
175177
if with_default_function_args:
176178
assign_defaults(self.function, df, function_prefix)
177179
return df

adaptive/tests/test_learners.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import random
99
import shutil
1010
import tempfile
11+
import time
1112

1213
import flaky
1314
import numpy as np
@@ -697,13 +698,25 @@ def test_learner_subdomain(learner_type, f, learner_kwargs):
697698
raise NotImplementedError()
698699

699700

701+
def add_time(f):
702+
@ft.wraps(f)
703+
def wrapper(*args, **kwargs):
704+
t0 = time.time()
705+
result = f(*args, **kwargs)
706+
return {"result": result, "time": time.time() - t0}
707+
708+
return wrapper
709+
710+
700711
@run_with(
701712
Learner1D,
702713
Learner2D,
703714
LearnerND,
704715
AverageLearner,
705716
AverageLearner1D,
706717
SequenceLearner,
718+
IntegratorLearner,
719+
with_all_loss_functions=False,
707720
)
708721
def test_to_dataframe(learner_type, f, learner_kwargs):
709722
if learner_type is LearnerND:
@@ -752,3 +765,21 @@ def test_to_dataframe(learner_type, f, learner_kwargs):
752765
bal_learner2 = BalancingLearner(learners2)
753766
bal_learner2.load_dataframe(df_bal, **kw)
754767
assert bal_learner2.npoints == bal_learner.npoints
768+
769+
if learner_type is SequenceLearner:
770+
# We do not test the DataSaver with the SequenceLearner
771+
# because the DataSaver is not compatible with the SequenceLearner.
772+
return
773+
774+
# Test with DataSaver
775+
learner = learner_type(
776+
add_time(generate_random_parametrization(f)), **learner_kwargs
777+
)
778+
data_saver = DataSaver(learner, operator.itemgetter("result"))
779+
df = data_saver.to_dataframe(**kw) # test if empty dataframe works
780+
simple_run(data_saver, 100)
781+
df = data_saver.to_dataframe(**kw)
782+
if learner_type is AverageLearner1D:
783+
assert len(df) == data_saver.nsamples
784+
else:
785+
assert len(df) == data_saver.npoints

0 commit comments

Comments
 (0)