Skip to content

Commit e46daf0

Browse files
committed
Test for attrs
1 parent b88b803 commit e46daf0

File tree

6 files changed

+20
-2
lines changed

6 files changed

+20
-2
lines changed

adaptive/learner/data_saver.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,11 @@ def to_dataframe(
8787
return df
8888

8989
def load_dataframe(
90-
self, df: pandas.DataFrame, extra_data_name: str = "extra_data", **kwargs
90+
self,
91+
df: pandas.DataFrame,
92+
extra_data_name: str = "extra_data",
93+
input_names: tuple[str] = (),
94+
**kwargs
9195
):
9296
"""Load the data from a `pandas.DataFrame` into the learner.
9397
@@ -97,11 +101,18 @@ def load_dataframe(
97101
DataFrame with the data to load.
98102
extra_data_name : str, optional
99103
The ``extra_data_name`` used in `to_dataframe`, by default "extra_data".
104+
input_names : tuple[str], optional
105+
The input names of the child learner. By default the input names are
106+
taken from ``df.attrs["inputs"]``, however, metadata is not preserved
107+
when saving/loading a DataFrame to/from a file. In that case, the input
108+
names can be passed explicitly. For example, for a 2D learner, this would
109+
be ``input_names=('x', 'y')``.
100110
**kwargs : dict
101111
Keyword arguments passed to each ``child_learner.load_dataframe(**kwargs)``.
102112
"""
103113
self.learner.load_dataframe(df, **kwargs)
104-
for _, x in df[df.attrs["inputs"] + [extra_data_name]].iterrows():
114+
keys = df.attrs.get("inputs", list(input_names))
115+
for _, x in df[keys + [extra_data_name]].iterrows():
105116
key = _to_key(x[:-1])
106117
self.extra_data[key] = x[-1]
107118

adaptive/learner/integrator_learner.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,8 @@ def to_dataframe(
594594
if not with_pandas:
595595
raise ImportError("pandas is not installed.")
596596
df = pandas.DataFrame(sorted(self.data.items()), columns=[x_name, y_name])
597+
df.attrs["inputs"] = [x_name]
598+
df.attrs["output"] = y_name
597599
if with_default_function_args:
598600
assign_defaults(self.function, df, function_prefix)
599601
return df

adaptive/learner/learner2D.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,7 @@ def to_dataframe(
440440
data = sorted((x, y, z) for (x, y), z in self.data.items())
441441
df = pandas.DataFrame(data, columns=[x_name, y_name, z_name])
442442
df.attrs["inputs"] = [x_name, y_name]
443+
df.attrs["output"] = z_name
443444
if with_default_function_args:
444445
assign_defaults(self.function, df, function_prefix)
445446
return df

adaptive/learner/learnerND.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,7 @@ def to_dataframe(
445445
data = list((*x, y) for x, y in self.data.items())
446446
df = pandas.DataFrame(data, columns=[*point_names, value_name])
447447
df.attrs["inputs"] = list(point_names)
448+
df.attrs["output"] = value_name
448449
if with_default_function_args:
449450
assign_defaults(self.function, df, function_prefix)
450451
return df

adaptive/tests/test_learners.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -729,6 +729,8 @@ def test_to_dataframe(learner_type, f, learner_kwargs):
729729
# Test empty dataframe
730730
df = learner.to_dataframe(**kw)
731731
assert len(df) == 0
732+
assert "inputs" in df.attrs
733+
assert "output" in df.attrs
732734

733735
# Run the learner
734736
simple_run(learner, 100)

adaptive/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def assign_defaults(
119119
defaults = _default_parameters(function, function_prefix, start_index)
120120
for k, v in defaults.items():
121121
df[k] = len(df) * [v]
122+
df[k] = df[k].astype("category")
122123

123124

124125
def partial_function_from_dataframe(function, df, function_prefix: str = "function."):

0 commit comments

Comments
 (0)