
Commit f78e90b

refactor: improve DataFrame streaming, memory management, and error handling
- Refactor record batch streaming to use `poll_next_batch` for clearer error handling
- Improve `spawn_future`/`spawn_stream` functions for better Python exception integration and code reuse
- Update `datafusion` and `datafusion-ffi` dependencies to 49.0.2
- Fix PyArrow `RecordBatchReader` import to use `_import_from_c_capsule` for safer memory handling
- Refactor `ArrowArrayStream` handling to use `PyCapsule` with destructor for improved memory management
- Refactor projection initialization in `PyDataFrame` for clarity
- Move `range` functionality into `_testing.py` helper
- Rename test column in `test_table_from_batches_stream` for accuracy
- Add tests for `RecordBatchReader` and enhance DataFrame stream handling
1 parent 91ccd1e commit f78e90b

File tree: 10 files changed, +157 −63 lines


docs/source/user-guide/dataframe/index.rst

Lines changed: 1 addition & 1 deletion
@@ -164,7 +164,7 @@ out-of-memory errors.
     import pyarrow as pa

     # Create a PyArrow RecordBatchReader without materializing all batches
-    reader = pa.RecordBatchReader._import_from_c(df.__arrow_c_stream__())
+    reader = pa.RecordBatchReader._import_from_c_capsule(df.__arrow_c_stream__())
     for batch in reader:
         ...  # process each batch as it is produced

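For readers of the updated snippet, a minimal end-to-end sketch of the documented pattern; the `SessionContext` setup and the range query are illustrative assumptions, only the capsule import mirrors the doc change:

    import pyarrow as pa
    from datafusion import SessionContext

    ctx = SessionContext()
    df = ctx.sql("SELECT * FROM range(0, 1000000)")  # illustrative query, not from the diff

    # __arrow_c_stream__() hands back a PyCapsule; _import_from_c_capsule takes ownership
    # of it, so batches are pulled lazily instead of being collected up front
    reader = pa.RecordBatchReader._import_from_c_capsule(df.__arrow_c_stream__())
    total_rows = 0
    for batch in reader:
        total_rows += batch.num_rows
    print(total_rows)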
examples/datafusion-ffi-example/Cargo.toml

Lines changed: 2 additions & 2 deletions
@@ -21,8 +21,8 @@ version = "0.2.0"
 edition = "2021"

 [dependencies]
-datafusion = { version = "49.0.1" }
-datafusion-ffi = { version = "49.0.1" }
+datafusion = { version = "49.0.2" }
+datafusion-ffi = { version = "49.0.2" }
 pyo3 = { version = "0.23", features = ["extension-module", "abi3", "abi3-py39"] }
 arrow = { version = "55.0.0" }
 arrow-array = { version = "55.0.0" }

python/datafusion/_testing.py

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+"""Testing-only helpers for datafusion-python.
+
+This module contains utilities used by the test-suite that should not be
+exposed as part of the public API. Keep the implementation minimal and
+documented so reviewers can easily see it's test-only.
+"""
+
+from __future__ import annotations
+
+from typing import Any
+
+from .context import SessionContext
+
+
+def range_table(
+    ctx: SessionContext,
+    start: int,
+    stop: int | None = None,
+    step: int = 1,
+    partitions: int | None = None,
+) -> Any:
+    """Create a DataFrame containing a sequence of numbers using SQL RANGE.
+
+    This mirrors the previous ``SessionContext.range`` convenience method but
+    lives in a testing-only module so it doesn't expand the public surface.
+
+    Args:
+        ctx: SessionContext instance to run the SQL against.
+        start: Starting value for the sequence or exclusive stop when ``stop``
+            is ``None``.
+        stop: Exclusive upper bound of the sequence.
+        step: Increment between successive values.
+        partitions: Optional number of partitions for the generated data.
+
+    Returns:
+        DataFrame produced by the range table function.
+    """
+    if stop is None:
+        start, stop = 0, start
+
+    parts = f", {int(partitions)}" if partitions is not None else ""
+    sql = f"SELECT * FROM range({int(start)}, {int(stop)}, {int(step)}{parts})"
+    return ctx.sql(sql)

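A short usage sketch of the new helper as the tests exercise it; the `value` column name comes from DataFusion's range table function, the rest is assumed setup:

    import pyarrow as pa
    from datafusion import SessionContext
    from datafusion._testing import range_table

    ctx = SessionContext()
    df = range_table(ctx, 5)  # single-argument form behaves like range(0, 5, 1)
    table = pa.Table.from_batches(df.collect())
    assert table.column("value").to_pylist() == [0, 1, 2, 3, 4]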
python/datafusion/context.py

Lines changed: 0 additions & 30 deletions
@@ -731,36 +731,6 @@ def from_polars(self, data: pl.DataFrame, name: str | None = None) -> DataFrame:
         """
         return DataFrame(self.ctx.from_polars(data, name))

-    def range(
-        self,
-        start: int,
-        stop: int | None = None,
-        step: int = 1,
-        partitions: int | None = None,
-    ) -> DataFrame:
-        """Create a DataFrame containing a sequence of numbers.
-
-        This is backed by DataFusion's ``range`` table function, which generates
-        values lazily and therefore does not materialize the full range in
-        memory. When ``stop`` is omitted, ``start`` is treated as the stop value
-        and the sequence begins at zero.
-
-        Args:
-            start: Starting value for the sequence or the exclusive stop if
-                ``stop`` is ``None``.
-            stop: Exclusive upper bound of the sequence.
-            step: Increment between successive values.
-            partitions: Optional number of partitions for the generated data.
-
-        Returns:
-            DataFrame yielding the requested range of values.
-        """
-        if stop is None:
-            start, stop = 0, start
-
-        parts = f", {int(partitions)}" if partitions is not None else ""
-        sql = f"SELECT * FROM range({int(start)}, {int(stop)}, {int(step)}{parts})"  # noqa: S608
-        return self.sql(sql)

     # https://github.com/apache/datafusion-python/pull/1016#discussion_r1983239116
     # is the discussion on how we arrived at adding register_view

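Since the convenience method is removed from the public API, a hedged before/after for downstream code; the replacement is exactly the SQL the removed method generated:

    # before this commit (no longer available):
    # df = ctx.range(0, 100, 2)

    # after: call the range table function directly through SQL
    df = ctx.sql("SELECT * FROM range(0, 100, 2)")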
python/datafusion/dataframe.py

Lines changed: 1 addition & 1 deletion
@@ -1127,7 +1127,7 @@ def __iter__(self) -> Iterator[pa.RecordBatch]:
         """
         import pyarrow as pa

-        reader = pa.RecordBatchReader._import_from_c(self.__arrow_c_stream__())
+        reader = pa.RecordBatchReader._import_from_c_capsule(self.__arrow_c_stream__())
         yield from reader

     def transform(self, func: Callable[..., DataFrame], *args: Any) -> DataFrame:

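Because ``__iter__`` now goes through the same capsule import, iterating a DataFrame streams batches without a ``collect()``; a small sketch, assuming a DataFrame ``df`` already exists:

    total_rows = 0
    for batch in df:  # each item is a pyarrow.RecordBatch pulled from the stream
        total_rows += batch.num_rows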
python/tests/test_dataframe.py

Lines changed: 9 additions & 1 deletion
@@ -1605,6 +1605,14 @@ def fail_collect(self): # pragma: no cover - failure path
     assert table.column("a").num_chunks == 2


+def test_arrow_c_stream_reader(df):
+    reader = pa.RecordBatchReader._import_from_c_capsule(df.__arrow_c_stream__())
+    assert isinstance(reader, pa.RecordBatchReader)
+    table = pa.Table.from_batches(reader)
+    expected = pa.Table.from_batches(df.collect())
+    assert table.equals(expected)
+
+
 def test_to_pylist(df):
     # Convert datafusion dataframe to Python list
     pylist = df.to_pylist()
@@ -2743,7 +2751,7 @@ def test_arrow_c_stream_interrupted():
         """
     )

-    reader = pa.RecordBatchReader._import_from_c(df.__arrow_c_stream__())
+    reader = pa.RecordBatchReader._import_from_c_capsule(df.__arrow_c_stream__())

     interrupted = False
     interrupt_error = None

python/tests/test_io.py

Lines changed: 18 additions & 3 deletions
@@ -18,7 +18,8 @@

 import pyarrow as pa
 import pytest
-from datafusion import column
+from datafusion import DataFrame, column
+from datafusion._testing import range_table
 from datafusion.io import read_avro, read_csv, read_json, read_parquet


@@ -104,9 +105,9 @@ def test_arrow_c_stream_large_dataset(ctx):
     handful of batches should not exhaust process memory.
     """
     # Create a very large DataFrame using range; this would be terabytes if collected
-    df = ctx.range(0, 1 << 40)
+    df = range_table(ctx, 0, 1 << 40)

-    reader = pa.RecordBatchReader._import_from_c(df.__arrow_c_stream__())
+    reader = pa.RecordBatchReader._import_from_c_capsule(df.__arrow_c_stream__())

     # Track RSS before consuming batches
     psutil = pytest.importorskip("psutil")
@@ -120,3 +121,17 @@ def test_arrow_c_stream_large_dataset(ctx):
     current_rss = process.memory_info().rss
     # Ensure memory usage hasn't grown substantially (>50MB)
     assert current_rss - start_rss < 50 * 1024 * 1024
+
+
+def test_table_from_batches_stream(ctx, monkeypatch):
+    df = range_table(ctx, 0, 10)
+
+    def fail_collect(self):  # pragma: no cover - failure path
+        msg = "collect should not be called"
+        raise AssertionError(msg)
+
+    monkeypatch.setattr(DataFrame, "collect", fail_collect)
+
+    table = pa.Table.from_batches(df)
+    assert table.shape == (10, 1)
+    assert table.column_names == ["value"]

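The new test relies on ``pa.Table.from_batches`` pulling batches straight from the DataFrame iterator, while the large-dataset test relies on consuming only part of the stream. A sketch of that partial-consumption pattern (one batch only, then releasing the stream; not something the tests themselves assert):

    reader = pa.RecordBatchReader._import_from_c_capsule(df.__arrow_c_stream__())
    first = reader.read_next_batch()  # materializes a single batch
    reader.close()                    # drop the stream without draining the remaining rows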
src/dataframe.rs

Lines changed: 43 additions & 10 deletions
@@ -16,7 +16,7 @@
 // under the License.

 use std::collections::HashMap;
-use std::ffi::CString;
+use std::ffi::{c_void, CStr, CString};
 use std::sync::Arc;

 use arrow::array::{new_null_array, RecordBatch, RecordBatchReader};
@@ -39,6 +39,7 @@ use datafusion::prelude::*;
 use datafusion_ffi::table_provider::FFI_TableProvider;
 use futures::{StreamExt, TryStreamExt};
 use pyo3::exceptions::PyValueError;
+use pyo3::ffi;
 use pyo3::prelude::*;
 use pyo3::pybacked::PyBackedStr;
 use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
@@ -47,7 +48,7 @@ use crate::catalog::PyTable;
 use crate::errors::{py_datafusion_err, PyDataFusionError};
 use crate::expr::sort_expr::to_sort_expressions;
 use crate::physical_plan::PyExecutionPlan;
-use crate::record_batch::PyRecordBatchStream;
+use crate::record_batch::{poll_next_batch, PyRecordBatchStream};
 use crate::sql::logical::PyLogicalPlan;
 use crate::utils::{
     get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, spawn_stream, spawn_streams,
@@ -58,6 +59,21 @@ use crate::{
     expr::{sort_expr::PySortExpr, PyExpr},
 };

+#[allow(clippy::manual_c_str_literals)]
+static ARROW_STREAM_NAME: &CStr =
+    unsafe { CStr::from_bytes_with_nul_unchecked(b"arrow_array_stream\0") };
+
+unsafe extern "C" fn drop_stream(capsule: *mut ffi::PyObject) {
+    if capsule.is_null() {
+        return;
+    }
+    let stream_ptr =
+        ffi::PyCapsule_GetPointer(capsule, ARROW_STREAM_NAME.as_ptr()) as *mut FFI_ArrowArrayStream;
+    if !stream_ptr.is_null() {
+        drop(Box::from_raw(stream_ptr));
+    }
+}
+
 // https://github.com/apache/datafusion-python/pull/1016#discussion_r1983239116
 // - we have not decided on the table_provider approach yet
 // this is an interim implementation
@@ -374,11 +390,11 @@ impl Iterator for DataFrameStreamReader {
         // respecting Python signal handling (e.g. ``KeyboardInterrupt``).
         // This mirrors the behaviour of other synchronous wrappers and
         // prevents blocking indefinitely when a Python interrupt is raised.
-        let fut = self.stream.next();
+        let fut = poll_next_batch(&mut self.stream);
         let result = Python::with_gil(|py| wait_for_future(py, fut));

         match result {
-            Ok(Some(Ok(batch))) => {
+            Ok(Ok(Some(batch))) => {
                 let batch = if let Some(ref schema) = self.projection {
                     match record_batch_into_schema(batch, schema.as_ref()) {
                         Ok(b) => b,
@@ -389,8 +405,8 @@ impl Iterator for DataFrameStreamReader {
                 };
                 Some(Ok(batch))
             }
-            Ok(Some(Err(e))) => Some(Err(ArrowError::ExternalError(Box::new(e)))),
-            Ok(None) => None,
+            Ok(Ok(None)) => None,
+            Ok(Err(e)) => Some(Err(ArrowError::ExternalError(Box::new(e)))),
             Err(e) => Some(Err(ArrowError::ExternalError(Box::new(e)))),
         }
     }
@@ -943,7 +959,7 @@ impl PyDataFrame {
             projection = Some(Arc::new(schema.clone()));
         }

-        let schema_ref = projection.clone().unwrap_or_else(|| Arc::new(schema));
+        let schema_ref = Arc::new(schema.clone());

         let reader = DataFrameStreamReader {
             stream,
@@ -952,9 +968,26 @@ impl PyDataFrame {
         };
         let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);

-        let ffi_stream = FFI_ArrowArrayStream::new(reader);
-        let stream_capsule_name = CString::new("arrow_array_stream").unwrap();
-        PyCapsule::new(py, ffi_stream, Some(stream_capsule_name)).map_err(PyDataFusionError::from)
+        let stream = Box::new(FFI_ArrowArrayStream::new(reader));
+        let stream_ptr = Box::into_raw(stream);
+        assert!(
+            !stream_ptr.is_null(),
+            "ArrowArrayStream pointer should never be null"
+        );
+        let capsule = unsafe {
+            ffi::PyCapsule_New(
+                stream_ptr as *mut c_void,
+                ARROW_STREAM_NAME.as_ptr(),
+                Some(drop_stream),
+            )
+        };
+        if capsule.is_null() {
+            unsafe { drop(Box::from_raw(stream_ptr)) };
+            Err(PyErr::fetch(py).into())
+        } else {
+            let any = unsafe { Bound::from_owned_ptr(py, capsule) };
+            Ok(any.downcast_into::<PyCapsule>().unwrap())
+        }
     }

     fn execute_stream(&self, py: Python) -> PyDataFusionResult<PyRecordBatchStream> {

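One consequence of attaching ``drop_stream`` as the capsule destructor: a stream handed to Python but never imported is still freed when the capsule is garbage collected. A Python-side illustration of that case (purely illustrative, not a test from this commit):

    capsule = df.__arrow_c_stream__()  # PyCapsule named "arrow_array_stream"
    # never passed to pa.RecordBatchReader._import_from_c_capsule(...)
    del capsule                        # destructor runs and drops the FFI_ArrowArrayStream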
src/record_batch.rs

Lines changed: 10 additions & 4 deletions
@@ -84,15 +84,20 @@ impl PyRecordBatchStream {
     }
 }

+pub(crate) async fn poll_next_batch(
+    stream: &mut SendableRecordBatchStream,
+) -> datafusion::error::Result<Option<RecordBatch>> {
+    stream.next().await.transpose()
+}
+
 async fn next_stream(
     stream: Arc<Mutex<SendableRecordBatchStream>>,
     sync: bool,
 ) -> PyResult<PyRecordBatch> {
     let mut stream = stream.lock().await;
-    match stream.next().await {
-        Some(Ok(batch)) => Ok(batch.into()),
-        Some(Err(e)) => Err(PyDataFusionError::from(e))?,
-        None => {
+    match poll_next_batch(&mut stream).await {
+        Ok(Some(batch)) => Ok(batch.into()),
+        Ok(None) => {
             // Depending on whether the iteration is sync or not, we raise either a
             // StopIteration or a StopAsyncIteration
             if sync {
@@ -101,5 +106,6 @@ async fn next_stream(
                 Err(PyStopAsyncIteration::new_err("stream exhausted"))
             }
         }
+        Err(e) => Err(PyDataFusionError::from(e))?,
     }
 }

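``next_stream`` is what ultimately raises ``StopIteration``/``StopAsyncIteration`` for the stream returned by ``execute_stream()``; a usage sketch of the synchronous side (the ``to_pyarrow()`` call is assumed from the existing ``RecordBatch`` wrapper API, not shown in this diff):

    rows = 0
    stream = df.execute_stream()
    for rb in stream:                  # StopIteration once poll_next_batch returns Ok(None)
        rows += rb.to_pyarrow().num_rows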
src/utils.rs

Lines changed: 31 additions & 11 deletions
@@ -85,17 +85,42 @@
     })
 }

+/// Spawn a [`Future`] on the Tokio runtime and wait for completion
+/// while respecting Python signal handling.
+pub(crate) fn spawn_future<F, T>(py: Python, fut: F) -> PyDataFusionResult<T>
+where
+    F: Future<Output = datafusion::common::Result<T>> + Send + 'static,
+    T: Send + 'static,
+{
+    let rt = &get_tokio_runtime().0;
+    let handle: JoinHandle<datafusion::common::Result<T>> = rt.spawn(fut);
+    // Wait for the join handle while respecting Python signal handling.
+    // We handle errors in two steps so `?` maps the error types correctly:
+    // 1) convert any Python-related error from `wait_for_future` into `PyDataFusionError`
+    // 2) convert any DataFusion error (inner result) into `PyDataFusionError`
+    let inner_result = wait_for_future(py, async {
+        // handle.await yields `Result<datafusion::common::Result<T>, JoinError>`
+        // map JoinError into a DataFusion error so the async block returns
+        // `datafusion::common::Result<T>` (i.e. Result<T, DataFusionError>)
+        match handle.await {
+            Ok(inner) => inner,
+            Err(join_err) => Err(to_datafusion_err(join_err)),
+        }
+    })?; // converts PyErr -> PyDataFusionError
+
+    // `inner_result` is `datafusion::common::Result<T>`; use `?` to convert
+    // the inner DataFusion error into `PyDataFusionError` via `From` and
+    // return the inner `T` on success.
+    Ok(inner_result?)
+}
+
 /// Spawn a [`SendableRecordBatchStream`] on the Tokio runtime and wait for completion
 /// while respecting Python signal handling.
 pub(crate) fn spawn_stream<F>(py: Python, fut: F) -> PyDataFusionResult<SendableRecordBatchStream>
 where
     F: Future<Output = datafusion::common::Result<SendableRecordBatchStream>> + Send + 'static,
 {
-    let rt = &get_tokio_runtime().0;
-    let handle: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> = rt.spawn(fut);
-    Ok(wait_for_future(py, async {
-        handle.await.map_err(to_datafusion_err)
-    })???)
+    spawn_future(py, fut)
 }

 /// Spawn a partitioned [`SendableRecordBatchStream`] on the Tokio runtime and wait for completion
@@ -107,12 +132,7 @@ pub(crate) fn spawn_streams<F>(
 where
     F: Future<Output = datafusion::common::Result<Vec<SendableRecordBatchStream>>> + Send + 'static,
 {
-    let rt = &get_tokio_runtime().0;
-    let handle: JoinHandle<datafusion::common::Result<Vec<SendableRecordBatchStream>>> =
-        rt.spawn(fut);
-    Ok(wait_for_future(py, async {
-        handle.await.map_err(to_datafusion_err)
-    })???)
+    spawn_future(py, fut)
 }

 pub(crate) fn parse_volatility(value: &str) -> PyDataFusionResult<Volatility> {

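The net effect of routing both helpers through ``spawn_future`` is that blocking calls keep honouring Python signals; a hedged illustration of the user-visible behaviour (assuming a long-running query, not a test from this commit):

    try:
        batches = df.collect()     # runs on the Tokio runtime under the hood
    except KeyboardInterrupt:
        print("query cancelled")   # wait_for_future surfaces the interrupt instead of blocking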