Skip to content

Commit 2b94152

Browse files
authored
Allow storing the full sequence in SequenceLearner.to_dataframe (#425)
* Allow storing the full sequence in SequenceLearner.to_dataframe
* Test dataframes
* Skip if minimal deps
1 parent 0dd5d98 commit 2b94152

File tree

2 files changed

+80
-7
lines changed

2 files changed

+80
-7
lines changed

adaptive/learner/sequence_learner.py

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import sys
44
from copy import copy
5-
from typing import Any
5+
from typing import TYPE_CHECKING, Any
66

77
import cloudpickle
88
from sortedcontainers import SortedDict, SortedSet
@@ -15,6 +15,10 @@
1515
partial_function_from_dataframe,
1616
)
1717

18+
if TYPE_CHECKING:
19+
from collections.abc import Sequence
20+
from typing import Callable
21+
1822
try:
1923
import pandas
2024

@@ -82,12 +86,17 @@ class SequenceLearner(BaseLearner):
8286
the added benefit of having results in the local kernel already.
8387
"""
8488

85-
def __init__(self, function, sequence):
89+
def __init__(
90+
self,
91+
function: Callable[[Any], Any],
92+
sequence: Sequence[Any],
93+
):
8694
self._original_function = function
8795
self.function = _IgnoreFirstArgument(function)
8896
# prefer range(len(...)) over enumerate to avoid slowdowns
8997
# when passing lazy sequences
90-
self._to_do_indices = SortedSet(range(len(sequence)))
98+
indices = range(len(sequence))
99+
self._to_do_indices = SortedSet(indices)
91100
self._ntotal = len(sequence)
92101
self.sequence = copy(sequence)
93102
self.data = SortedDict()
@@ -161,6 +170,8 @@ def to_dataframe( # type: ignore[override]
161170
index_name: str = "i",
162171
x_name: str = "x",
163172
y_name: str = "y",
173+
*,
174+
full_sequence: bool = False,
164175
) -> pandas.DataFrame:
165176
"""Return the data as a `pandas.DataFrame`.
166177
@@ -178,6 +189,9 @@ def to_dataframe( # type: ignore[override]
178189
Name of the input value, by default "x"
179190
y_name : str, optional
180191
Name of the output value, by default "y"
192+
full_sequence : bool, optional
193+
If True, the returned dataframe will have the full sequence
194+
where the y_name values are pd.NA if not evaluated yet.
181195
182196
Returns
183197
-------
@@ -190,8 +204,16 @@ def to_dataframe( # type: ignore[override]
190204
"""
191205
if not with_pandas:
192206
raise ImportError("pandas is not installed.")
193-
indices, ys = zip(*self.data.items()) if self.data else ([], [])
194-
sequence = [self.sequence[i] for i in indices]
207+
import pandas as pd
208+
209+
if full_sequence:
210+
indices = list(range(len(self.sequence)))
211+
sequence = list(self.sequence)
212+
ys = [self.data.get(i, pd.NA) for i in indices]
213+
else:
214+
indices, ys = zip(*self.data.items()) if self.data else ([], []) # type: ignore[assignment]
215+
sequence = [self.sequence[i] for i in indices]
216+
195217
df = pandas.DataFrame(indices, columns=[index_name])
196218
df[x_name] = sequence
197219
df[y_name] = ys
@@ -209,6 +231,8 @@ def load_dataframe( # type: ignore[override]
209231
index_name: str = "i",
210232
x_name: str = "x",
211233
y_name: str = "y",
234+
*,
235+
full_sequence: bool = False,
212236
):
213237
"""Load data from a `pandas.DataFrame`.
214238
@@ -231,10 +255,25 @@ def load_dataframe( # type: ignore[override]
231255
The ``x_name`` used in ``to_dataframe``, by default "x"
232256
y_name : str, optional
233257
The ``y_name`` used in ``to_dataframe``, by default "y"
258+
full_sequence : bool, optional
259+
The ``full_sequence`` used in ``to_dataframe``, by default False
234260
"""
261+
if not with_pandas:
262+
raise ImportError("pandas is not installed.")
263+
import pandas as pd
264+
235265
indices = df[index_name].values
236266
xs = df[x_name].values
237-
self.tell_many(zip(indices, xs), df[y_name].values)
267+
ys = df[y_name].values
268+
269+
if full_sequence:
270+
evaluated_indices = [i for i, y in enumerate(ys) if y is not pd.NA]
271+
xs = xs[evaluated_indices]
272+
ys = ys[evaluated_indices]
273+
indices = indices[evaluated_indices]
274+
275+
self.tell_many(zip(indices, xs), ys)
276+
238277
if with_default_function_args:
239278
self.function = partial_function_from_dataframe(
240279
self._original_function, df, function_prefix

adaptive/tests/test_sequence_learner.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,17 @@
11
import asyncio
22

3+
import pytest
4+
35
from adaptive import Runner, SequenceLearner
4-
from adaptive.runner import SequentialExecutor
6+
from adaptive.learner.learner1D import with_pandas
7+
from adaptive.runner import SequentialExecutor, simple
8+
9+
offset = 0.0123
10+
11+
12+
def peak(x, offset=offset, wait=True):
    """Toy test function: ``x`` plus a narrow bump centred at ``offset``.

    The bump has the form ``a**2 / (a**2 + (x - offset)**2)`` with
    ``a = 0.01``.  The result is returned wrapped in a single-key dict
    ``{"x": value}``.  ``wait`` is accepted but not used in the body.
    """
    width = 0.01
    width_sq = width**2
    distance_sq = (x - offset) ** 2
    bump = width_sq / (width_sq + distance_sq)
    return {"x": x + bump}
515

616

717
class FailOnce:
@@ -22,3 +32,27 @@ def test_fail_with_sequence_of_unhashable():
2232
runner = Runner(learner, retries=1, executor=SequentialExecutor())
2333
asyncio.get_event_loop().run_until_complete(runner.task)
2434
assert runner.status() == "finished"
35+
36+
37+
@pytest.mark.skipif(not with_pandas, reason="pandas is not installed")
def test_save_load_dataframe():
    """Round-trip a SequenceLearner through ``to_dataframe``/``load_dataframe``.

    Runs the learner on 10 of the 20 points in ``range(10, 30)`` and checks:

    * the default export holds only the evaluated points,
    * the ``full_sequence=True`` export holds every point of the sequence,
    * a fresh learner loaded from either export reproduces that export.
    """
    learner = SequenceLearner(peak, sequence=range(10, 30, 1))
    simple(learner, npoints_goal=10)

    # Default export: one row per evaluated point.
    df = learner.to_dataframe()
    assert len(df) == 10
    assert df["x"].iloc[0] == 10

    # Full export: one row per sequence element, evaluated or not.
    df_full = learner.to_dataframe(full_sequence=True)
    assert len(df_full) == 20
    assert df_full["x"].iloc[0] == 10
    assert df_full["x"].iloc[-1] == 29

    learner2 = learner.new()
    assert learner2.data == {}
    learner2.load_dataframe(df)
    assert len(learner2.data) == 10
    # Compare the *reloaded* learner against the frame it was loaded from;
    # asserting on `learner` here would be trivially true and would not
    # exercise the load path at all.
    assert learner2.to_dataframe().equals(df)

    learner3 = learner.new()
    learner3.load_dataframe(df_full, full_sequence=True)
    assert len(learner3.data) == 10
    assert learner3.to_dataframe(full_sequence=True).equals(df_full)

0 commit comments

Comments
 (0)