Skip to content

Commit 68055fe

Browse files
committed
Add BalancingLearner.load_dataframe
1 parent 82a3cb5 commit 68055fe

File tree

2 files changed

+41
-8
lines changed

2 files changed

+41
-8
lines changed

adaptive/learner/balancing_learner.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -389,11 +389,14 @@ def from_product(cls, f, learner_type, learner_kwargs, combos):
389389
learners.append(learner)
390390
return cls(learners, cdims=arguments)
391391

392-
def to_dataframe(self, **kwargs):
392+
def to_dataframe(self, index_name: str = "learner_index", **kwargs):
393393
"""Return the data as a concatenated `pandas.DataFrame` from child learners.
394394
395395
Parameters
396396
----------
397+
index_name : str, optional
398+
The name of the index column indicating the learner index,
399+
by default "learner_index".
397400
**kwargs : dict
398401
Keyword arguments passed to each ``child_learner.to_dataframe(**kwargs)``.
399402
@@ -408,10 +411,33 @@ def to_dataframe(self, **kwargs):
408411
"""
409412
if not with_pandas:
410413
raise ImportError("pandas is not installed.")
411-
dfs = [learner.to_dataframe(**kwargs) for learner in self.learners]
414+
dfs = []
415+
for i, learner in enumerate(self.learners):
416+
df = learner.to_dataframe(**kwargs)
417+
cols = list(df.columns)
418+
df[index_name] = i
419+
df = df[[index_name] + cols]
420+
dfs.append(df)
412421
df = pandas.concat(dfs, axis=0, ignore_index=True)
413422
return df
414423

424+
def load_dataframe(
425+
self, df: pandas.DataFrame, index_name: str = "learner_index", **kwargs
426+
):
427+
"""Load the data from a `pandas.DataFrame` into the child learners.
428+
429+
Parameters
430+
----------
431+
df : pandas.DataFrame
432+
DataFrame with the data to load.
433+
index_name : str, optional
434+
The ``index_name`` used in `to_dataframe`, by default "learner_index".
435+
**kwargs : dict
436+
Keyword arguments passed to each ``child_learner.load_dataframe(**kwargs)``.
437+
"""
438+
for i, gr in df.groupby(index_name):
439+
self.learners[i].load_dataframe(gr, **kwargs)
440+
415441
def save(self, fname, compress=True):
416442
"""Save the data of the child learners into pickle files
417443
in a directory.

adaptive/tests/test_learners.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -736,12 +736,19 @@ def test_to_dataframe(learner_type, f, learner_kwargs):
736736
learner_type(generate_random_parametrization(f), **learner_kwargs)
737737
for _ in range(2)
738738
]
739-
learner = BalancingLearner(learners)
740-
simple_run(learner, 100)
741-
df = learner.to_dataframe(**kw)
742-
assert isinstance(df, pandas.DataFrame)
739+
bal_learner = BalancingLearner(learners)
740+
simple_run(bal_learner, 100)
741+
df_bal = bal_learner.to_dataframe(**kw)
742+
assert isinstance(df_bal, pandas.DataFrame)
743743

744744
if learner_type is not AverageLearner1D:
745-
assert len(df) == learner.npoints
745+
assert len(df_bal) == bal_learner.npoints
746746

747-
# TODO: Test this for a learner in a DataSaver
747+
# Test loading from a DataFrame into the BalancingLearner
748+
learners2 = [
749+
learner_type(generate_random_parametrization(f), **learner_kwargs)
750+
for _ in range(2)
751+
]
752+
bal_learner2 = BalancingLearner(learners2)
753+
bal_learner2.load_dataframe(df_bal, **kw)
754+
assert bal_learner2.npoints == bal_learner.npoints

0 commit comments

Comments
 (0)