
Commit e322521

feat: enhance DataFrame streaming and improve robustness, tests, and docs

- Preserve partition order in DataFrame streaming and update related tests
- Add tests for record batch ordering and DataFrame batch iteration
- Improve `drop_stream` to correctly handle PyArrow ownership transfer and null pointers
- Replace `assert` with `debug_assert` for safer ArrowArrayStream validation
- Add documentation for `poll_next_batch` in PyRecordBatchStream
- Refactor tests to use `fail_collect` fixture for DataFrame collect
- Refactor `range_table` return type to `DataFrame` for clearer type hints
- Minor cleanup in SessionContext (remove extra blank line)
1 parent f78e90b commit e322521
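
As an illustration of the behaviour described above, here is a minimal sketch (illustrative only; it uses the `SessionContext.create_dataframe` API exercised by the tests in this commit) of iterating a DataFrame lazily while batches keep their partition order:

    import pyarrow as pa
    from datafusion import SessionContext

    ctx = SessionContext()
    batch1 = pa.record_batch([pa.array([1])], names=["a"])
    batch2 = pa.record_batch([pa.array([2])], names=["a"])
    # Two partitions with one batch each; iteration streams batches lazily
    # and yields them in partition order without ever calling collect().
    df = ctx.create_dataframe([[batch1], [batch2]])
    for batch in df:
        print(batch.num_rows)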

File tree

docs/source/user-guide/dataframe/index.rst
python/datafusion/_testing.py
python/datafusion/context.py
python/datafusion/dataframe.py
python/tests/conftest.py
python/tests/test_dataframe.py
python/tests/test_io.py
src/dataframe.rs
src/record_batch.rs

9 files changed: +135 -63 lines changed

docs/source/user-guide/dataframe/index.rst

Lines changed: 8 additions & 0 deletions

@@ -168,6 +168,14 @@ out-of-memory errors.
     for batch in reader:
         ... # process each batch as it is produced
 
+DataFrames are also iterable, yielding :class:`pyarrow.RecordBatch` objects
+lazily so you can loop over results directly:
+
+.. code-block:: python
+
+    for batch in df:
+        ... # process each batch as it is produced
+
 See :doc:`../io/arrow` for additional details on the Arrow interface.
 
 HTML Rendering

python/datafusion/_testing.py

Lines changed: 7 additions & 3 deletions

@@ -4,20 +4,24 @@
 exposed as part of the public API. Keep the implementation minimal and
 documented so reviewers can easily see it's test-only.
 """
+
 from __future__ import annotations
 
-from typing import Any
+from typing import TYPE_CHECKING
 
 from .context import SessionContext
 
+if TYPE_CHECKING:
+    from datafusion import DataFrame
+
 
 def range_table(
     ctx: SessionContext,
     start: int,
     stop: int | None = None,
     step: int = 1,
     partitions: int | None = None,
-) -> Any:
+) -> DataFrame:
     """Create a DataFrame containing a sequence of numbers using SQL RANGE.
 
     This mirrors the previous ``SessionContext.range`` convenience method but
@@ -38,5 +42,5 @@ def range_table(
         start, stop = 0, start
 
     parts = f", {int(partitions)}" if partitions is not None else ""
-    sql = f"SELECT * FROM range({int(start)}, {int(stop)}, {int(step)}{parts})"
+    sql = f"SELECT * FROM range({int(start)}, {int(stop)}, {int(step)}{parts})"  # noqa: S608
     return ctx.sql(sql)
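
For context, a quick usage sketch of the helper above (a sketch only, assuming the test-only `datafusion._testing` module is importable from a development checkout):

    from datafusion import SessionContext
    from datafusion._testing import range_table

    ctx = SessionContext()
    # Builds "SELECT * FROM range(0, 10, 2)" and returns a DataFrame,
    # mirroring the former SessionContext.range convenience method.
    df = range_table(ctx, 0, 10, 2)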

python/datafusion/context.py

Lines changed: 0 additions & 1 deletion

@@ -731,7 +731,6 @@ def from_polars(self, data: pl.DataFrame, name: str | None = None) -> DataFrame:
         """
         return DataFrame(self.ctx.from_polars(data, name))
 
-
     # https://github.com/apache/datafusion-python/pull/1016#discussion_r1983239116
     # is the discussion on how we arrived at adding register_view
     def register_view(self, name: str, df: DataFrame) -> None:

python/datafusion/dataframe.py

Lines changed: 7 additions & 2 deletions

@@ -290,6 +290,9 @@ def __init__(
 class DataFrame:
     """Two dimensional table representation of data.
 
+    DataFrame objects are iterable; iterating over a DataFrame yields
+    :class:`pyarrow.RecordBatch` instances lazily.
+
     See :ref:`user_guide_concepts` in the online documentation for more information.
     """
 
@@ -1114,7 +1117,8 @@ def __arrow_c_stream__(self, requested_schema: object | None = None) -> object:
             Arrow PyCapsule object representing an ``ArrowArrayStream``.
         """
         # ``DataFrame.__arrow_c_stream__`` in the Rust extension leverages
-        # ``execute_stream`` under the hood to stream batches one at a time.
+        # ``execute_stream_partitioned`` under the hood to stream batches while
+        # preserving the original partition order.
         return self.df.__arrow_c_stream__(requested_schema)
 
     def __iter__(self) -> Iterator[pa.RecordBatch]:
@@ -1123,7 +1127,8 @@ def __iter__(self) -> Iterator[pa.RecordBatch]:
         This implementation streams record batches via the Arrow C Stream
         interface, allowing callers such as :func:`pyarrow.Table.from_batches` to
         consume results lazily. The DataFrame is executed using DataFusion's
-        streaming APIs so ``collect`` is never invoked.
+        partitioned streaming APIs so ``collect`` is never invoked and batch
+        order across partitions is preserved.
         """
         import pyarrow as pa
 
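
As a usage note on the docstrings above: any Arrow C Stream consumer drives the same partition-ordered stream, so plain iteration, `pyarrow.Table.from_batches`, and reader-based consumption all avoid `collect`. A minimal sketch, assuming a recent PyArrow that exposes `RecordBatchReader.from_stream`:

    import pyarrow as pa
    from datafusion import SessionContext

    ctx = SessionContext()
    df = ctx.sql("SELECT 1 AS a")

    # Both paths go through __arrow_c_stream__, so collect() is never called
    # and record batches keep their partition order.
    table = pa.Table.from_batches(df)
    reader = pa.RecordBatchReader.from_stream(df)
    for batch in reader:
        ...  # process each batch lazily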

python/tests/conftest.py

Lines changed: 10 additions & 1 deletion

@@ -17,7 +17,7 @@
 
 import pyarrow as pa
 import pytest
-from datafusion import SessionContext
+from datafusion import DataFrame, SessionContext
 from pyarrow.csv import write_csv
 
 
@@ -49,3 +49,12 @@ def database(ctx, tmp_path):
         delimiter=",",
         schema_infer_max_records=10,
     )
+
+
+@pytest.fixture
+def fail_collect(monkeypatch):
+    def _fail_collect(self, *args, **kwargs):  # pragma: no cover - failure path
+        msg = "collect should not be called"
+        raise AssertionError(msg)
+
+    monkeypatch.setattr(DataFrame, "collect", _fail_collect)
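
A brief usage sketch of the fixture (the test name here is hypothetical; the real tests updated by this commit are shown below): a test opts in simply by listing `fail_collect` as a parameter, and any accidental `DataFrame.collect` call then raises AssertionError.

    import pyarrow as pa
    from datafusion import SessionContext

    def test_streaming_never_collects(fail_collect):  # hypothetical example
        ctx = SessionContext()
        df = ctx.sql("SELECT 1 AS a")
        # Streaming consumption must not fall back to DataFrame.collect().
        table = pa.Table.from_batches(df)
        assert table.num_rows == 1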

python/tests/test_dataframe.py

Lines changed: 32 additions & 8 deletions

@@ -1582,27 +1582,51 @@ def test_empty_to_arrow_table(df):
     assert set(pyarrow_table.column_names) == {"a", "b", "c"}
 
 
-def test_arrow_c_stream_to_table(monkeypatch):
+def test_iter_batches_dataframe(fail_collect):
+    ctx = SessionContext()
+
+    batch1 = pa.record_batch([pa.array([1])], names=["a"])
+    batch2 = pa.record_batch([pa.array([2])], names=["a"])
+    df = ctx.create_dataframe([[batch1], [batch2]])
+
+    expected = [batch1, batch2]
+    for got, exp in zip(df, expected):
+        assert got.equals(exp)
+
+
+def test_arrow_c_stream_to_table(fail_collect):
     ctx = SessionContext()
 
     # Create a DataFrame with two separate record batches
     batch1 = pa.record_batch([pa.array([1])], names=["a"])
     batch2 = pa.record_batch([pa.array([2])], names=["a"])
     df = ctx.create_dataframe([[batch1], [batch2]])
 
-    # Fail if the DataFrame is pre-collected
-    def fail_collect(self):  # pragma: no cover - failure path
-        msg = "collect should not be called"
-        raise AssertionError(msg)
+    table = pa.Table.from_batches(df)
+    batches = table.to_batches()
+
+    assert len(batches) == 2
+    assert batches[0].equals(batch1)
+    assert batches[1].equals(batch2)
+    assert table.schema == df.schema()
+    assert table.column("a").num_chunks == 2
+
+
+def test_arrow_c_stream_order():
+    ctx = SessionContext()
 
-    monkeypatch.setattr(DataFrame, "collect", fail_collect)
+    batch1 = pa.record_batch([pa.array([1])], names=["a"])
+    batch2 = pa.record_batch([pa.array([2])], names=["a"])
+
+    df = ctx.create_dataframe([[batch1, batch2]])
 
     table = pa.Table.from_batches(df)
     expected = pa.Table.from_batches([batch1, batch2])
 
     assert table.equals(expected)
-    assert table.schema == df.schema()
-    assert table.column("a").num_chunks == 2
+    col = table.column("a")
+    assert col.chunk(0)[0].as_py() == 1
+    assert col.chunk(1)[0].as_py() == 2
 
 
 def test_arrow_c_stream_reader(df):

python/tests/test_io.py

Lines changed: 2 additions & 8 deletions

@@ -18,7 +18,7 @@
 
 import pyarrow as pa
 import pytest
-from datafusion import DataFrame, column
+from datafusion import column
 from datafusion._testing import range_table
 from datafusion.io import read_avro, read_csv, read_json, read_parquet
 
@@ -123,15 +123,9 @@ def test_arrow_c_stream_large_dataset(ctx):
     assert current_rss - start_rss < 50 * 1024 * 1024
 
 
-def test_table_from_batches_stream(ctx, monkeypatch):
+def test_table_from_batches_stream(ctx, fail_collect):
     df = range_table(ctx, 0, 10)
 
-    def fail_collect(self):  # pragma: no cover - failure path
-        msg = "collect should not be called"
-        raise AssertionError(msg)
-
-    monkeypatch.setattr(DataFrame, "collect", fail_collect)
-
     table = pa.Table.from_batches(df)
     assert table.shape == (10, 1)
     assert table.column_names == ["value"]

src/dataframe.rs

Lines changed: 68 additions & 40 deletions

@@ -67,11 +67,25 @@ unsafe extern "C" fn drop_stream(capsule: *mut ffi::PyObject) {
     if capsule.is_null() {
         return;
     }
-    let stream_ptr =
-        ffi::PyCapsule_GetPointer(capsule, ARROW_STREAM_NAME.as_ptr()) as *mut FFI_ArrowArrayStream;
-    if !stream_ptr.is_null() {
-        drop(Box::from_raw(stream_ptr));
+
+    // When PyArrow imports this capsule it steals the raw stream pointer and
+    // sets the capsule's internal pointer to NULL. In that case
+    // `PyCapsule_IsValid` returns 0 and this destructor must not drop the
+    // stream as ownership has been transferred to PyArrow. If the capsule was
+    // never imported, the pointer remains valid and we are responsible for
+    // freeing the stream here.
+    if ffi::PyCapsule_IsValid(capsule, ARROW_STREAM_NAME.as_ptr()) == 1 {
+        let stream_ptr = ffi::PyCapsule_GetPointer(capsule, ARROW_STREAM_NAME.as_ptr())
+            as *mut FFI_ArrowArrayStream;
+        if !stream_ptr.is_null() {
+            drop(Box::from_raw(stream_ptr));
+        }
     }
+
+    // `PyCapsule_GetPointer` sets a Python error on failure. Clear it only
+    // after the stream has been released (or determined to be owned
+    // elsewhere).
+    ffi::PyErr_Clear();
 }
 
 // https://github.com/apache/datafusion-python/pull/1016#discussion_r1983239116
@@ -369,50 +383,59 @@ impl PyDataFrame {
         Ok(html_str)
     }
 }
-/// Synchronous wrapper around a [`SendableRecordBatchStream`] used for
-/// the `__arrow_c_stream__` implementation.
+
+/// Synchronous wrapper around partitioned [`SendableRecordBatchStream`]s used
+/// for the `__arrow_c_stream__` implementation.
 ///
-/// It uses `runtime.block_on` to consume the underlying async stream,
-/// providing synchronous iteration. When a `projection` is set, each
-/// batch is converted via `record_batch_into_schema` to apply schema
-/// changes per batch.
-struct DataFrameStreamReader {
-    stream: SendableRecordBatchStream,
+/// It drains each partition's stream sequentially, yielding record batches in
+/// their original partition order. When a `projection` is set, each batch is
+/// converted via `record_batch_into_schema` to apply schema changes per batch.
+struct PartitionedDataFrameStreamReader {
+    streams: Vec<SendableRecordBatchStream>,
     schema: SchemaRef,
     projection: Option<SchemaRef>,
+    current: usize,
 }
 
-impl Iterator for DataFrameStreamReader {
+impl Iterator for PartitionedDataFrameStreamReader {
     type Item = Result<RecordBatch, ArrowError>;
 
     fn next(&mut self) -> Option<Self::Item> {
-        // Use wait_for_future to poll the underlying async stream while
-        // respecting Python signal handling (e.g. ``KeyboardInterrupt``).
-        // This mirrors the behaviour of other synchronous wrappers and
-        // prevents blocking indefinitely when a Python interrupt is raised.
-        let fut = poll_next_batch(&mut self.stream);
-        let result = Python::with_gil(|py| wait_for_future(py, fut));
-
-        match result {
-            Ok(Ok(Some(batch))) => {
-                let batch = if let Some(ref schema) = self.projection {
-                    match record_batch_into_schema(batch, schema.as_ref()) {
-                        Ok(b) => b,
-                        Err(e) => return Some(Err(e)),
-                    }
-                } else {
-                    batch
-                };
-                Some(Ok(batch))
+        while self.current < self.streams.len() {
+            let stream = &mut self.streams[self.current];
+            let fut = poll_next_batch(stream);
+            let result = Python::with_gil(|py| wait_for_future(py, fut));
+
+            match result {
+                Ok(Ok(Some(batch))) => {
+                    let batch = if let Some(ref schema) = self.projection {
+                        match record_batch_into_schema(batch, schema.as_ref()) {
+                            Ok(b) => b,
+                            Err(e) => return Some(Err(e)),
+                        }
+                    } else {
+                        batch
+                    };
+                    return Some(Ok(batch));
+                }
+                Ok(Ok(None)) => {
+                    self.current += 1;
+                    continue;
+                }
+                Ok(Err(e)) => {
+                    return Some(Err(ArrowError::ExternalError(Box::new(e))));
+                }
+                Err(e) => {
+                    return Some(Err(ArrowError::ExternalError(Box::new(e))));
+                }
             }
-            Ok(Ok(None)) => None,
-            Ok(Err(e)) => Some(Err(ArrowError::ExternalError(Box::new(e)))),
-            Err(e) => Some(Err(ArrowError::ExternalError(Box::new(e)))),
         }
+
+        None
     }
 }
 
-impl RecordBatchReader for DataFrameStreamReader {
+impl RecordBatchReader for PartitionedDataFrameStreamReader {
     fn schema(&self) -> SchemaRef {
         self.schema.clone()
     }
@@ -944,7 +967,7 @@ impl PyDataFrame {
         requested_schema: Option<Bound<'py, PyCapsule>>,
     ) -> PyDataFusionResult<Bound<'py, PyCapsule>> {
         let df = self.df.as_ref().clone();
-        let stream = spawn_stream(py, async move { df.execute_stream().await })?;
+        let streams = spawn_streams(py, async move { df.execute_stream_partitioned().await })?;
 
         let mut schema: Schema = self.df.schema().to_owned().into();
         let mut projection: Option<SchemaRef> = None;
@@ -961,19 +984,24 @@
 
         let schema_ref = Arc::new(schema.clone());
 
-        let reader = DataFrameStreamReader {
-            stream,
+        let reader = PartitionedDataFrameStreamReader {
+            streams,
             schema: schema_ref,
             projection,
+            current: 0,
         };
         let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);
 
         let stream = Box::new(FFI_ArrowArrayStream::new(reader));
         let stream_ptr = Box::into_raw(stream);
-        assert!(
+        debug_assert!(
             !stream_ptr.is_null(),
-            "ArrowArrayStream pointer should never be null"
+            "ArrowArrayStream pointer should never be null",
        );
+        // The returned capsule allows zero-copy hand-off to PyArrow. When
+        // PyArrow imports the capsule it assumes ownership of the stream and
+        // nulls out the capsule's internal pointer so `drop_stream` knows not to
+        // free it.
         let capsule = unsafe {
             ffi::PyCapsule_New(
                 stream_ptr as *mut c_void,
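
To make the ownership contract in `drop_stream` concrete from the Python side, here is a hedged sketch of the two cases the destructor handles (the `RecordBatchReader.from_stream` call assumes a recent PyArrow; the bare-capsule case is purely illustrative):

    import pyarrow as pa
    from datafusion import SessionContext

    ctx = SessionContext()
    df = ctx.sql("SELECT 1 AS a")

    # Consumed capsule: PyArrow steals the stream pointer when it imports the
    # capsule, so drop_stream later sees an invalid capsule and must not free
    # the stream.
    reader = pa.RecordBatchReader.from_stream(df)

    # Unconsumed capsule: created but never imported, so the pointer stays
    # valid and drop_stream frees the stream when the capsule is destroyed,
    # avoiding a leak.
    capsule = df.__arrow_c_stream__()
    del capsule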

src/record_batch.rs

Lines changed: 1 addition & 0 deletions

@@ -84,6 +84,7 @@ impl PyRecordBatchStream {
     }
 }
 
+/// Polls the next batch from a `SendableRecordBatchStream`, converting the `Option<Result<_>>` form.
 pub(crate) async fn poll_next_batch(
     stream: &mut SendableRecordBatchStream,
 ) -> datafusion::error::Result<Option<RecordBatch>> {
