-
Notifications
You must be signed in to change notification settings - Fork 128
Add to_batches() and interpolate() methods to DataFrame #1241
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -40,9 +40,11 @@ | |
from datafusion._internal import DataFrame as DataFrameInternal | ||
from datafusion._internal import ParquetColumnOptions as ParquetColumnOptionsInternal | ||
from datafusion._internal import ParquetWriterOptions as ParquetWriterOptionsInternal | ||
from datafusion.expr import Expr, SortExpr, sort_or_default | ||
from datafusion.expr import Expr, SortExpr, sort_or_default, Window | ||
from datafusion.plan import ExecutionPlan, LogicalPlan | ||
from datafusion.record_batch import RecordBatchStream | ||
from datafusion.functions import col, nvl, last_value | ||
from datafusion.common import NullTreatment | ||
|
||
if TYPE_CHECKING: | ||
import pathlib | ||
|
@@ -360,6 +362,9 @@ def describe(self) -> DataFrame: | |
""" | ||
return DataFrame(self.df.describe()) | ||
|
||
@deprecated( | ||
"schema() is deprecated. Use :py:meth:`~DataFrame.get_schema` instead" | ||
) | ||
def schema(self) -> pa.Schema: | ||
"""Return the :py:class:`pyarrow.Schema` of this DataFrame. | ||
|
||
|
@@ -370,6 +375,39 @@ def schema(self) -> pa.Schema: | |
Describing schema of the DataFrame | ||
""" | ||
return self.df.schema() | ||
|
||
def to_batches(self) -> list[pa.RecordBatch]: | ||
"""Convert DataFrame to list of RecordBatches.""" | ||
return self.collect() # delegate to existing method | ||
Comment on lines
+379
to
+381
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My opinion (see #1227) is to limit the surface area where we explicitly depend on pyarrow. Especially in this case it's just an alias? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 100% agree, this should return arro3 recordbatches There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not saying it should even return arro3 RecordBatches... because that's still an external dependency that datafusion is requiring on users. Datafusion could return a minimal batch object that just holds the RecordBatch pointer and then the user transfers it to their library of choice. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
That's possible but I would argue it's more convenient if it's in some way already usable instead of just a pointer. Arro3 is that small that it's negligible in size overall There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Perhaps we should get @timsaucer 's thoughts in #1227: is having a required dependency on pyarrow a problem? What should we do about it? Do we want to depend on arro3-core instead? Or have functions that rely on pyarrow but error if pyarrow isn't installed? @ion-elgreco feel free to write down your thoughts there too There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I understand that Now if there is a way we can remove this dependency and it doesn't break existing workflows, that would be even better. I haven't made the time to sit down and play with it, though. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I think the straightforward way to do that is to remove pyarrow as a required dependency and error if it's not installed. But a separate question is whether we should be adding new methods that explicitly depend on pyarrow. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I could do a try-except to throw an ImportError if it is not installed. But it might make more sense to drop I'll remove it and focus on fixing the |
||
|
||
def interpolate(self, method: str = "forward_fill", **kwargs) -> DataFrame: | ||
"""Interpolate missing values per column. | ||
|
||
Args: | ||
method: Interpolation method ('linear', 'forward_fill', 'backward_fill') | ||
Comment on lines
+383
to
+387
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. At the outset this doesn't look quite right to me.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Two quick questions:
Thank you for your time! |
||
|
||
Returns: | ||
DataFrame with interpolated values | ||
|
||
Raises: | ||
NotImplementedError: Linear interpolation not yet supported | ||
""" | ||
if method == "forward_fill": | ||
exprs = [] | ||
for field in self.schema(): | ||
window = Window(order_by=col(field.name)) | ||
expr = nvl(col(field.name),last_value(col(field.name)).over(window)).alias(field.name) | ||
exprs.append(expr) | ||
return self.select(*exprs) | ||
|
||
elif method == "backward_fill": | ||
raise NotImplementedError("backward_fill not yet implemented") | ||
|
||
elif method == "linear": | ||
raise NotImplementedError("Linear interpolation requires complex window function logic") | ||
|
||
else: | ||
raise ValueError(f"Unknown interpolation method: {method}") | ||
|
||
@deprecated( | ||
"select_columns() is deprecated. Use :py:meth:`~DataFrame.select` instead" | ||
|
@@ -592,6 +630,9 @@ def tail(self, n: int = 5) -> DataFrame: | |
""" | ||
return DataFrame(self.df.limit(n, max(0, self.count() - n))) | ||
|
||
@deprecated( | ||
"collect() returning RecordBatch list is deprecated. Use to_batches() for RecordBatch list or collect() will return DataFrame in future versions" | ||
) | ||
H0TB0X420 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def collect(self) -> list[pa.RecordBatch]: | ||
"""Execute this :py:class:`DataFrame` and collect results into memory. | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -185,6 +185,47 @@ def get_header_style(self) -> str: | |
"padding: 10px; border: 1px solid #3367d6;" | ||
) | ||
|
||
def test_to_batches(df): | ||
"""Test to_batches method returns list of RecordBatches.""" | ||
batches = df.to_batches() | ||
assert isinstance(batches, list) | ||
assert len(batches) > 0 | ||
assert all(isinstance(batch, pa.RecordBatch) for batch in batches) | ||
|
||
|
||
collect_batches = df.collect() | ||
assert len(batches) == len(collect_batches) | ||
for i, batch in enumerate(batches): | ||
assert batch.equals(collect_batches[i]) | ||
|
||
|
||
def test_interpolate_forward_fill(ctx): | ||
"""Test interpolate method with forward_fill.""" | ||
|
||
batch = pa.RecordBatch.from_arrays( | ||
[pa.array([1, None, 3, None]), pa.array([4.0, None, 6.0, None])], | ||
names=["int_col", "float_col"], | ||
) | ||
df = ctx.create_dataframe([[batch]]) | ||
|
||
result = df.interpolate("forward_fill") | ||
|
||
assert isinstance(result, DataFrame) | ||
Comment on lines
+211
to
+213
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We would probably want to collect the results and verify they fill as expected. |
||
|
||
|
||
def test_interpolate_unsupported_method(ctx): | ||
"""Test interpolate with unsupported method raises error.""" | ||
batch = pa.RecordBatch.from_arrays( | ||
[pa.array([1, 2, 3])], names=["a"] | ||
) | ||
df = ctx.create_dataframe([[batch]]) | ||
|
||
with pytest.raises(NotImplementedError, match="requires complex window"): | ||
df.interpolate("linear") | ||
|
||
with pytest.raises(ValueError, match="Unknown interpolation method"): | ||
df.interpolate("unknown") | ||
|
||
|
||
def count_table_rows(html_content: str) -> int: | ||
"""Count the number of table rows in HTML content. | ||
|
Uh oh!
There was an error while loading. Please reload this page.