Skip to content

Commit 178a497

Browse files
committed
Write tests where we readd the data
1 parent 191f781 commit 178a497

File tree

3 files changed

+31
-5
lines changed

3 files changed

+31
-5
lines changed

adaptive/learner/average_learner1D.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def to_numpy(self, mean: bool = True) -> np.ndarray:
150150

151151
def to_dataframe(
152152
self,
153-
mean: bool = True,
153+
mean: bool = False,
154154
with_default_function_args: bool = True,
155155
function_prefix: str = "function.",
156156
seed_name: str = "seed",

adaptive/learner/learnerND.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,7 @@ def to_dataframe(
401401
with_default_function_args: bool = True,
402402
function_prefix: str = "function.",
403403
point_names: tuple[str, ...] = ("x", "y", "z"),
404-
value_name: str = "y",
404+
value_name: str = "value",
405405
) -> pandas.DataFrame:
406406
if not with_pandas:
407407
raise ImportError("pandas is not installed.")
@@ -410,7 +410,7 @@ def to_dataframe(
410410
f"point_names ({point_names}) should have the"
411411
f" same length as learner.ndims ({self.ndim})"
412412
)
413-
data = sorted((*x, y) for x, y in self.data.items())
413+
data = list((*x, y) for x, y in self.data.items())
414414
df = pandas.DataFrame(data, columns=[*point_names, value_name])
415415
if with_default_function_args:
416416
assign_defaults(self.function, df, function_prefix)

adaptive/tests/test_learners.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -710,12 +710,34 @@ def test_to_dataframe(learner_type, f, learner_kwargs):
710710
kw = {"point_names": list("xyz")[: len(learner_kwargs["bounds"])]}
711711
else:
712712
kw = {}
713+
713714
learner = learner_type(generate_random_parametrization(f), **learner_kwargs)
714715
simple_run(learner, 100)
715716
df = learner.to_dataframe(**kw)
716717
assert isinstance(df, pandas.DataFrame)
717-
assert len(df) == learner.npoints
718+
if learner_type is AverageLearner1D:
719+
assert len(df) == learner.nsamples
720+
else:
721+
assert len(df) == learner.npoints
722+
723+
# Add points from the DataFrame to a new empty learner
724+
learner2 = learner_type(generate_random_parametrization(f), **learner_kwargs)
725+
726+
if learner_type is Learner1D:
727+
learner2.tell_many(df["x"], df["y"])
728+
elif learner_type is Learner2D:
729+
learner2.tell_many(df[["x", "y"]].values, df["z"])
730+
elif learner_type is LearnerND:
731+
point_names = list(kw["point_names"])
732+
learner2.tell_many(df[point_names].values, df["value"])
733+
elif learner_type is AverageLearner:
734+
learner2.tell_many(df["seed"].values, df["y"])
735+
elif learner_type is AverageLearner1D:
736+
learner2.tell_many(df[["seed", "x"]].values, df["y"])
737+
else:
738+
raise NotImplementedError()
718739

740+
# Test this for a learner in a BalancingLearner
719741
learners = [
720742
learner_type(generate_random_parametrization(f), **learner_kwargs)
721743
for _ in range(2)
@@ -724,4 +746,8 @@ def test_to_dataframe(learner_type, f, learner_kwargs):
724746
simple_run(learner, 100)
725747
df = learner.to_dataframe(**kw)
726748
assert isinstance(df, pandas.DataFrame)
727-
assert len(df) == learner.npoints
749+
750+
if learner_type is not AverageLearner1D:
751+
assert len(df) == learner.npoints
752+
753+
# TODO: Test this for a learner in a DataSaver

0 commit comments

Comments
 (0)