From 8b4d5794b49f02146920f9e48808761a3cef0273 Mon Sep 17 00:00:00 2001 From: H0TB0X420 Date: Tue, 16 Sep 2025 15:16:47 -0400 Subject: [PATCH] Add to_batches() and interpolate() methods to DataFrame - Add to_batches() as alias for collect() returning RecordBatch list - Add interpolate() method with forward_fill support - Add deprecation warning to collect() method - Add comprehensive tests for both methods Addresses items from RFC #875 --- python/datafusion/dataframe.py | 43 +++++++++++++++++++++++++++++++++- python/tests/test_dataframe.py | 41 ++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+), 1 deletion(-) diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index 181c29db4..12e8f7c35 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -40,9 +40,11 @@ from datafusion._internal import DataFrame as DataFrameInternal from datafusion._internal import ParquetColumnOptions as ParquetColumnOptionsInternal from datafusion._internal import ParquetWriterOptions as ParquetWriterOptionsInternal -from datafusion.expr import Expr, SortExpr, sort_or_default +from datafusion.expr import Expr, SortExpr, sort_or_default, Window from datafusion.plan import ExecutionPlan, LogicalPlan from datafusion.record_batch import RecordBatchStream +from datafusion.functions import col, nvl, last_value +from datafusion.common import NullTreatment if TYPE_CHECKING: import pathlib @@ -360,6 +362,9 @@ def describe(self) -> DataFrame: """ return DataFrame(self.df.describe()) + @deprecated( + "schema() is deprecated. Use :py:meth:`~DataFrame.get_schema` instead" + ) def schema(self) -> pa.Schema: """Return the :py:class:`pyarrow.Schema` of this DataFrame. @@ -370,6 +375,39 @@ def schema(self) -> pa.Schema: Describing schema of the DataFrame """ return self.df.schema() + + def to_batches(self) -> list[pa.RecordBatch]: + """Convert DataFrame to list of RecordBatches.""" + return self.collect() # delegate to existing method + + def interpolate(self, method: str = "forward_fill", **kwargs) -> DataFrame: + """Interpolate missing values per column. + + Args: + method: Interpolation method ('linear', 'forward_fill', 'backward_fill') + + Returns: + DataFrame with interpolated values + + Raises: + NotImplementedError: Linear interpolation not yet supported + """ + if method == "forward_fill": + exprs = [] + for field in self.schema(): + window = Window(order_by=col(field.name)) + expr = nvl(col(field.name),last_value(col(field.name)).over(window)).alias(field.name) + exprs.append(expr) + return self.select(*exprs) + + elif method == "backward_fill": + raise NotImplementedError("backward_fill not yet implemented") + + elif method == "linear": + raise NotImplementedError("Linear interpolation requires complex window function logic") + + else: + raise ValueError(f"Unknown interpolation method: {method}") @deprecated( "select_columns() is deprecated. Use :py:meth:`~DataFrame.select` instead" @@ -592,6 +630,9 @@ def tail(self, n: int = 5) -> DataFrame: """ return DataFrame(self.df.limit(n, max(0, self.count() - n))) + @deprecated( + "collect() returning RecordBatch list is deprecated. Use to_batches() for RecordBatch list or collect() will return DataFrame in future versions" + ) def collect(self) -> list[pa.RecordBatch]: """Execute this :py:class:`DataFrame` and collect results into memory. diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index 343d32a92..e921b7055 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -185,6 +185,47 @@ def get_header_style(self) -> str: "padding: 10px; border: 1px solid #3367d6;" ) +def test_to_batches(df): + """Test to_batches method returns list of RecordBatches.""" + batches = df.to_batches() + assert isinstance(batches, list) + assert len(batches) > 0 + assert all(isinstance(batch, pa.RecordBatch) for batch in batches) + + + collect_batches = df.collect() + assert len(batches) == len(collect_batches) + for i, batch in enumerate(batches): + assert batch.equals(collect_batches[i]) + + +def test_interpolate_forward_fill(ctx): + """Test interpolate method with forward_fill.""" + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, None, 3, None]), pa.array([4.0, None, 6.0, None])], + names=["int_col", "float_col"], + ) + df = ctx.create_dataframe([[batch]]) + + result = df.interpolate("forward_fill") + + assert isinstance(result, DataFrame) + + +def test_interpolate_unsupported_method(ctx): + """Test interpolate with unsupported method raises error.""" + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3])], names=["a"] + ) + df = ctx.create_dataframe([[batch]]) + + with pytest.raises(NotImplementedError, match="requires complex window"): + df.interpolate("linear") + + with pytest.raises(ValueError, match="Unknown interpolation method"): + df.interpolate("unknown") + def count_table_rows(html_content: str) -> int: """Count the number of table rows in HTML content.