Skip to content

Commit 91ccd1e

Browse files
committed
feat: add streaming utilities, range support, and improve async handling in DataFrame
- Add `range` method to SessionContext and iterator support to DataFrame
- Introduce `spawn_stream` utility and refactor async execution for better signal handling
- Add tests for `KeyboardInterrupt` in `__arrow_c_stream__` and incremental DataFrame streaming
- Improve memory usage tracking in tests with psutil
- Update DataFrame docs with PyArrow streaming section and enhance `__arrow_c_stream__` documentation
- Replace Tokio runtime creation with `spawn_stream` in PySessionContext
- Bump datafusion packages to 49.0.1 and update dependencies
- Remove unused imports and restore main Cargo.toml
1 parent 61f981b commit 91ccd1e

File tree

9 files changed

+337
-48
lines changed

9 files changed

+337
-48
lines changed

docs/source/user-guide/dataframe/index.rst

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,10 +145,31 @@ To materialize the results of your DataFrame operations:
145145
146146
# Display results
147147
df.show() # Print tabular format to console
148-
148+
149149
# Count rows
150150
count = df.count()
151151
152+
PyArrow Streaming
153+
-----------------
154+
155+
DataFusion DataFrames implement the ``__arrow_c_stream__`` protocol, enabling
156+
zero-copy streaming into libraries like `PyArrow <https://arrow.apache.org/>`_.
157+
Earlier versions eagerly converted the entire DataFrame when exporting to
158+
PyArrow, which could exhaust memory on large datasets. With streaming, batches
159+
are produced lazily so you can process arbitrarily large results without
160+
out-of-memory errors.
161+
162+
.. code-block:: python
163+
164+
import pyarrow as pa
165+
166+
# Create a PyArrow RecordBatchReader without materializing all batches
167+
reader = pa.RecordBatchReader._import_from_c(df.__arrow_c_stream__())
168+
for batch in reader:
169+
... # process each batch as it is produced
170+
171+
See :doc:`../io/arrow` for additional details on the Arrow interface.
172+
152173
HTML Rendering
153174
--------------
154175

examples/datafusion-ffi-example/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ version = "0.2.0"
2121
edition = "2021"
2222

2323
[dependencies]
24-
datafusion = { version = "49.0.2" }
25-
datafusion-ffi = { version = "49.0.2" }
24+
datafusion = { version = "49.0.1" }
25+
datafusion-ffi = { version = "49.0.1" }
2626
pyo3 = { version = "0.23", features = ["extension-module", "abi3", "abi3-py39"] }
2727
arrow = { version = "55.0.0" }
2828
arrow-array = { version = "55.0.0" }

python/datafusion/context.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -731,6 +731,37 @@ def from_polars(self, data: pl.DataFrame, name: str | None = None) -> DataFrame:
731731
"""
732732
return DataFrame(self.ctx.from_polars(data, name))
733733

734+
def range(
735+
self,
736+
start: int,
737+
stop: int | None = None,
738+
step: int = 1,
739+
partitions: int | None = None,
740+
) -> DataFrame:
741+
"""Create a DataFrame containing a sequence of numbers.
742+
743+
This is backed by DataFusion's ``range`` table function, which generates
744+
values lazily and therefore does not materialize the full range in
745+
memory. When ``stop`` is omitted, ``start`` is treated as the stop value
746+
and the sequence begins at zero.
747+
748+
Args:
749+
start: Starting value for the sequence or the exclusive stop if
750+
``stop`` is ``None``.
751+
stop: Exclusive upper bound of the sequence.
752+
step: Increment between successive values.
753+
partitions: Optional number of partitions for the generated data.
754+
755+
Returns:
756+
DataFrame yielding the requested range of values.
757+
"""
758+
if stop is None:
759+
start, stop = 0, start
760+
761+
parts = f", {int(partitions)}" if partitions is not None else ""
762+
sql = f"SELECT * FROM range({int(start)}, {int(stop)}, {int(step)}{parts})" # noqa: S608
763+
return self.sql(sql)
764+
734765
# https://github.com/apache/datafusion-python/pull/1016#discussion_r1983239116
735766
# is the discussion on how we arrived at adding register_view
736767
def register_view(self, name: str, df: DataFrame) -> None:

python/datafusion/dataframe.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
TYPE_CHECKING,
2727
Any,
2828
Iterable,
29+
Iterator,
2930
Literal,
3031
Optional,
3132
Union,
@@ -1098,21 +1099,37 @@ def unnest_columns(self, *columns: str, preserve_nulls: bool = True) -> DataFram
10981099
return DataFrame(self.df.unnest_columns(columns, preserve_nulls=preserve_nulls))
10991100

11001101
def __arrow_c_stream__(self, requested_schema: object | None = None) -> object:
1101-
"""Export an Arrow PyCapsule Stream.
1102+
"""Export the DataFrame as an Arrow C Stream.
11021103
1103-
This will execute and collect the DataFrame. We will attempt to respect the
1104-
requested schema, but only trivial transformations will be applied such as only
1105-
returning the fields listed in the requested schema if their data types match
1106-
those in the DataFrame.
1104+
The DataFrame is executed using DataFusion's streaming APIs and exposed via
1105+
Arrow's C Stream interface. Record batches are produced incrementally, so the
1106+
full result set is never materialized in memory. When ``requested_schema`` is
1107+
provided, only straightforward projections such as column selection or
1108+
reordering are applied.
11071109
11081110
Args:
11091111
requested_schema: Attempt to provide the DataFrame using this schema.
11101112
11111113
Returns:
1112-
Arrow PyCapsule object.
1114+
Arrow PyCapsule object representing an ``ArrowArrayStream``.
11131115
"""
1116+
# ``DataFrame.__arrow_c_stream__`` in the Rust extension leverages
1117+
# ``execute_stream`` under the hood to stream batches one at a time.
11141118
return self.df.__arrow_c_stream__(requested_schema)
11151119

1120+
def __iter__(self) -> Iterator[pa.RecordBatch]:
1121+
"""Yield record batches from the DataFrame without materializing results.
1122+
1123+
This implementation streams record batches via the Arrow C Stream
1124+
interface, allowing callers such as :func:`pyarrow.Table.from_batches` to
1125+
consume results lazily. The DataFrame is executed using DataFusion's
1126+
streaming APIs so ``collect`` is never invoked.
1127+
"""
1128+
import pyarrow as pa
1129+
1130+
reader = pa.RecordBatchReader._import_from_c(self.__arrow_c_stream__())
1131+
yield from reader
1132+
11161133
def transform(self, func: Callable[..., DataFrame], *args: Any) -> DataFrame:
11171134
"""Apply a function to the current DataFrame which returns another DataFrame.
11181135

python/tests/test_dataframe.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1582,6 +1582,29 @@ def test_empty_to_arrow_table(df):
15821582
assert set(pyarrow_table.column_names) == {"a", "b", "c"}
15831583

15841584

1585+
def test_arrow_c_stream_to_table(monkeypatch):
1586+
ctx = SessionContext()
1587+
1588+
# Create a DataFrame with two separate record batches
1589+
batch1 = pa.record_batch([pa.array([1])], names=["a"])
1590+
batch2 = pa.record_batch([pa.array([2])], names=["a"])
1591+
df = ctx.create_dataframe([[batch1], [batch2]])
1592+
1593+
# Fail if the DataFrame is pre-collected
1594+
def fail_collect(self): # pragma: no cover - failure path
1595+
msg = "collect should not be called"
1596+
raise AssertionError(msg)
1597+
1598+
monkeypatch.setattr(DataFrame, "collect", fail_collect)
1599+
1600+
table = pa.Table.from_batches(df)
1601+
expected = pa.Table.from_batches([batch1, batch2])
1602+
1603+
assert table.equals(expected)
1604+
assert table.schema == df.schema()
1605+
assert table.column("a").num_chunks == 2
1606+
1607+
15851608
def test_to_pylist(df):
15861609
# Convert datafusion dataframe to Python list
15871610
pylist = df.to_pylist()
@@ -2666,6 +2689,110 @@ def trigger_interrupt():
26662689
interrupt_thread.join(timeout=1.0)
26672690

26682691

2692+
def test_arrow_c_stream_interrupted():
2693+
"""__arrow_c_stream__ responds to ``KeyboardInterrupt`` signals.
2694+
2695+
Similar to ``test_collect_interrupted`` this test issues a long running
2696+
query, but consumes the results via ``__arrow_c_stream__``. It then raises
2697+
``KeyboardInterrupt`` in the main thread and verifies that the stream
2698+
iteration stops promptly with the appropriate exception.
2699+
"""
2700+
2701+
ctx = SessionContext()
2702+
2703+
batches = []
2704+
for i in range(10):
2705+
batch = pa.RecordBatch.from_arrays(
2706+
[
2707+
pa.array(list(range(i * 1000, (i + 1) * 1000))),
2708+
pa.array([f"value_{j}" for j in range(i * 1000, (i + 1) * 1000)]),
2709+
],
2710+
names=["a", "b"],
2711+
)
2712+
batches.append(batch)
2713+
2714+
ctx.register_record_batches("t1", [batches])
2715+
ctx.register_record_batches("t2", [batches])
2716+
2717+
df = ctx.sql(
2718+
"""
2719+
WITH t1_expanded AS (
2720+
SELECT
2721+
a,
2722+
b,
2723+
CAST(a AS DOUBLE) / 1.5 AS c,
2724+
CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS d
2725+
FROM t1
2726+
CROSS JOIN (SELECT 1 AS dummy FROM t1 LIMIT 5)
2727+
),
2728+
t2_expanded AS (
2729+
SELECT
2730+
a,
2731+
b,
2732+
CAST(a AS DOUBLE) * 2.5 AS e,
2733+
CAST(a AS DOUBLE) * CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS f
2734+
FROM t2
2735+
CROSS JOIN (SELECT 1 AS dummy FROM t2 LIMIT 5)
2736+
)
2737+
SELECT
2738+
t1.a, t1.b, t1.c, t1.d,
2739+
t2.a AS a2, t2.b AS b2, t2.e, t2.f
2740+
FROM t1_expanded t1
2741+
JOIN t2_expanded t2 ON t1.a % 100 = t2.a % 100
2742+
WHERE t1.a > 100 AND t2.a > 100
2743+
"""
2744+
)
2745+
2746+
reader = pa.RecordBatchReader._import_from_c(df.__arrow_c_stream__())
2747+
2748+
interrupted = False
2749+
interrupt_error = None
2750+
query_started = threading.Event()
2751+
max_wait_time = 5.0
2752+
2753+
def trigger_interrupt():
2754+
start_time = time.time()
2755+
while not query_started.is_set():
2756+
time.sleep(0.1)
2757+
if time.time() - start_time > max_wait_time:
2758+
msg = f"Query did not start within {max_wait_time} seconds"
2759+
raise RuntimeError(msg)
2760+
2761+
thread_id = threading.main_thread().ident
2762+
if thread_id is None:
2763+
msg = "Cannot get main thread ID"
2764+
raise RuntimeError(msg)
2765+
2766+
exception = ctypes.py_object(KeyboardInterrupt)
2767+
res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
2768+
ctypes.c_long(thread_id), exception
2769+
)
2770+
if res != 1:
2771+
ctypes.pythonapi.PyThreadState_SetAsyncExc(
2772+
ctypes.c_long(thread_id), ctypes.py_object(0)
2773+
)
2774+
msg = "Failed to raise KeyboardInterrupt in main thread"
2775+
raise RuntimeError(msg)
2776+
2777+
interrupt_thread = threading.Thread(target=trigger_interrupt)
2778+
interrupt_thread.daemon = True
2779+
interrupt_thread.start()
2780+
2781+
try:
2782+
query_started.set()
2783+
# consume the reader which should block and be interrupted
2784+
reader.read_all()
2785+
except KeyboardInterrupt:
2786+
interrupted = True
2787+
except Exception as e: # pragma: no cover - unexpected errors
2788+
interrupt_error = e
2789+
2790+
if not interrupted:
2791+
pytest.fail(f"Stream was not interrupted; got error: {interrupt_error}")
2792+
2793+
interrupt_thread.join(timeout=1.0)
2794+
2795+
26692796
def test_show_select_where_no_rows(capsys) -> None:
26702797
ctx = SessionContext()
26712798
df = ctx.sql("SELECT 1 WHERE 1=0")

python/tests/test_io.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from pathlib import Path
1818

1919
import pyarrow as pa
20+
import pytest
2021
from datafusion import column
2122
from datafusion.io import read_avro, read_csv, read_json, read_parquet
2223

@@ -92,3 +93,30 @@ def test_read_avro():
9293
path = Path.cwd() / "testing/data/avro/alltypes_plain.avro"
9394
avro_df = read_avro(path=path)
9495
assert avro_df is not None
96+
97+
98+
def test_arrow_c_stream_large_dataset(ctx):
99+
"""DataFrame.__arrow_c_stream__ yields batches incrementally.
100+
101+
This test constructs a DataFrame that would be far larger than available
102+
memory if materialized. The ``__arrow_c_stream__`` method should expose a
103+
stream of record batches without collecting the full dataset, so reading a
104+
handful of batches should not exhaust process memory.
105+
"""
106+
# Create a very large DataFrame using range; this would be terabytes if collected
107+
df = ctx.range(0, 1 << 40)
108+
109+
reader = pa.RecordBatchReader._import_from_c(df.__arrow_c_stream__())
110+
111+
# Track RSS before consuming batches
112+
psutil = pytest.importorskip("psutil")
113+
process = psutil.Process()
114+
start_rss = process.memory_info().rss
115+
116+
for _ in range(5):
117+
batch = reader.read_next_batch()
118+
assert batch is not None
119+
assert len(batch) > 0
120+
current_rss = process.memory_info().rss
121+
# Ensure memory usage hasn't grown substantially (>50MB)
122+
assert current_rss - start_rss < 50 * 1024 * 1024

src/context.rs

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ use pyo3::prelude::*;
3434
use crate::catalog::{PyCatalog, PyTable, RustWrappedPyCatalogProvider};
3535
use crate::dataframe::PyDataFrame;
3636
use crate::dataset::Dataset;
37-
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
37+
use crate::errors::{py_datafusion_err, PyDataFusionResult};
3838
use crate::expr::sort_expr::PySortExpr;
3939
use crate::physical_plan::PyExecutionPlan;
4040
use crate::record_batch::PyRecordBatchStream;
@@ -45,7 +45,7 @@ use crate::udaf::PyAggregateUDF;
4545
use crate::udf::PyScalarUDF;
4646
use crate::udtf::PyTableFunction;
4747
use crate::udwf::PyWindowUDF;
48-
use crate::utils::{get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_for_future};
48+
use crate::utils::{get_global_ctx, spawn_stream, validate_pycapsule, wait_for_future};
4949
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
5050
use datafusion::arrow::pyarrow::PyArrowType;
5151
use datafusion::arrow::record_batch::RecordBatch;
@@ -66,15 +66,13 @@ use datafusion::execution::disk_manager::DiskManagerMode;
6666
use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, UnboundedMemoryPool};
6767
use datafusion::execution::options::ReadOptions;
6868
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
69-
use datafusion::physical_plan::SendableRecordBatchStream;
7069
use datafusion::prelude::{
7170
AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
7271
};
7372
use datafusion_ffi::catalog_provider::{FFI_CatalogProvider, ForeignCatalogProvider};
7473
use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
7574
use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType};
7675
use pyo3::IntoPyObjectExt;
77-
use tokio::task::JoinHandle;
7876

7977
/// Configuration options for a SessionContext
8078
#[pyclass(name = "SessionConfig", module = "datafusion", subclass)]
@@ -1132,12 +1130,8 @@ impl PySessionContext {
11321130
py: Python,
11331131
) -> PyDataFusionResult<PyRecordBatchStream> {
11341132
let ctx: TaskContext = TaskContext::from(&self.ctx.state());
1135-
// create a Tokio runtime to run the async code
1136-
let rt = &get_tokio_runtime().0;
11371133
let plan = plan.plan.clone();
1138-
let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
1139-
rt.spawn(async move { plan.execute(part, Arc::new(ctx)) });
1140-
let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })???;
1134+
let stream = spawn_stream(py, async move { plan.execute(part, Arc::new(ctx)) })?;
11411135
Ok(PyRecordBatchStream::new(stream))
11421136
}
11431137
}

0 commit comments

Comments
 (0)