2
2
3
3
import sys
4
4
from copy import copy
5
- from typing import Any
5
+ from typing import TYPE_CHECKING , Any
6
6
7
7
import cloudpickle
8
8
from sortedcontainers import SortedDict , SortedSet
15
15
partial_function_from_dataframe ,
16
16
)
17
17
18
+ if TYPE_CHECKING :
19
+ from collections .abc import Sequence
20
+ from typing import Callable
21
+
18
22
try :
19
23
import pandas
20
24
@@ -82,12 +86,17 @@ class SequenceLearner(BaseLearner):
82
86
the added benefit of having results in the local kernel already.
83
87
"""
84
88
85
- def __init__ (self , function , sequence ):
89
+ def __init__ (
90
+ self ,
91
+ function : Callable [[Any ], Any ],
92
+ sequence : Sequence [Any ],
93
+ ):
86
94
self ._original_function = function
87
95
self .function = _IgnoreFirstArgument (function )
88
96
# prefer range(len(...)) over enumerate to avoid slowdowns
89
97
# when passing lazy sequences
90
- self ._to_do_indices = SortedSet (range (len (sequence )))
98
+ indices = range (len (sequence ))
99
+ self ._to_do_indices = SortedSet (indices )
91
100
self ._ntotal = len (sequence )
92
101
self .sequence = copy (sequence )
93
102
self .data = SortedDict ()
@@ -161,6 +170,8 @@ def to_dataframe( # type: ignore[override]
161
170
index_name : str = "i" ,
162
171
x_name : str = "x" ,
163
172
y_name : str = "y" ,
173
+ * ,
174
+ full_sequence : bool = False ,
164
175
) -> pandas .DataFrame :
165
176
"""Return the data as a `pandas.DataFrame`.
166
177
@@ -178,6 +189,9 @@ def to_dataframe( # type: ignore[override]
178
189
Name of the input value, by default "x"
179
190
y_name : str, optional
180
191
Name of the output value, by default "y"
192
+ full_sequence : bool, optional
193
+ If True, the returned dataframe will have the full sequence
194
+ where the y_name values are pd.NA if not evaluated yet.
181
195
182
196
Returns
183
197
-------
@@ -190,8 +204,16 @@ def to_dataframe( # type: ignore[override]
190
204
"""
191
205
if not with_pandas :
192
206
raise ImportError ("pandas is not installed." )
193
- indices , ys = zip (* self .data .items ()) if self .data else ([], [])
194
- sequence = [self .sequence [i ] for i in indices ]
207
+ import pandas as pd
208
+
209
+ if full_sequence :
210
+ indices = list (range (len (self .sequence )))
211
+ sequence = list (self .sequence )
212
+ ys = [self .data .get (i , pd .NA ) for i in indices ]
213
+ else :
214
+ indices , ys = zip (* self .data .items ()) if self .data else ([], []) # type: ignore[assignment]
215
+ sequence = [self .sequence [i ] for i in indices ]
216
+
195
217
df = pandas .DataFrame (indices , columns = [index_name ])
196
218
df [x_name ] = sequence
197
219
df [y_name ] = ys
@@ -209,6 +231,8 @@ def load_dataframe( # type: ignore[override]
209
231
index_name : str = "i" ,
210
232
x_name : str = "x" ,
211
233
y_name : str = "y" ,
234
+ * ,
235
+ full_sequence : bool = False ,
212
236
):
213
237
"""Load data from a `pandas.DataFrame`.
214
238
@@ -231,10 +255,25 @@ def load_dataframe( # type: ignore[override]
231
255
The ``x_name`` used in ``to_dataframe``, by default "x"
232
256
y_name : str, optional
233
257
The ``y_name`` used in ``to_dataframe``, by default "y"
258
+ full_sequence : bool, optional
259
+ The ``full_sequence`` used in ``to_dataframe``, by default False
234
260
"""
261
+ if not with_pandas :
262
+ raise ImportError ("pandas is not installed." )
263
+ import pandas as pd
264
+
235
265
indices = df [index_name ].values
236
266
xs = df [x_name ].values
237
- self .tell_many (zip (indices , xs ), df [y_name ].values )
267
+ ys = df [y_name ].values
268
+
269
+ if full_sequence :
270
+ evaluated_indices = [i for i , y in enumerate (ys ) if y is not pd .NA ]
271
+ xs = xs [evaluated_indices ]
272
+ ys = ys [evaluated_indices ]
273
+ indices = indices [evaluated_indices ]
274
+
275
+ self .tell_many (zip (indices , xs ), ys )
276
+
238
277
if with_default_function_args :
239
278
self .function = partial_function_from_dataframe (
240
279
self ._original_function , df , function_prefix
0 commit comments