Commit 51fdc12

Fix bug by moving the take(10) limit onto the rows of the underlying batches rather than onto the stream of batches, as before. Also refactor _repr_html_ to use the same helper.
1 parent b218692 commit 51fdc12
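
For context on the bug this commit fixes: `StreamExt::take(10)` caps the number of items a stream yields, and when the items are `RecordBatch`es that means ten batches, not ten rows, so the `__repr__` preview could pull far more rows than intended. Below is a minimal, self-contained sketch, not part of the commit, that uses plain `usize` values in place of batches (each value standing in for a batch's row count) and assumes only the `futures` crate; it contrasts a batch-count cap with the row-count cap that the new `get_batches` helper implements via `scan`.

use futures::{executor::block_on, future, stream, StreamExt};

fn main() {
    // Each value stands in for a record batch's row count.
    let batch_row_counts = vec![4usize, 4, 4, 4];
    let max_rows = 10usize;

    // `take(2)` keeps the first two *batches* (8 rows here), regardless of any row target.
    let by_batches: Vec<usize> =
        block_on(stream::iter(batch_row_counts.clone()).take(2).collect::<Vec<_>>());
    assert_eq!(by_batches.iter().sum::<usize>(), 8);

    // Row-based capping with `scan`, mirroring the commit's approach: stop once
    // `max_rows` rows have been taken, truncating the final batch if necessary.
    let by_rows: Vec<usize> = block_on(
        stream::iter(batch_row_counts)
            .scan(0usize, |seen, rows| {
                if *seen >= max_rows {
                    return future::ready(None);
                }
                let kept = rows.min(max_rows - *seen);
                *seen += kept;
                future::ready(Some(kept))
            })
            .collect::<Vec<_>>(),
    );
    assert_eq!(by_rows.iter().sum::<usize>(), max_rows);
}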

File tree

1 file changed (+39, -12 lines)

src/dataframe.rs

Lines changed: 39 additions & 12 deletions
@@ -33,7 +33,7 @@ use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
 use datafusion::execution::SendableRecordBatchStream;
 use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
 use datafusion::prelude::*;
-use futures::StreamExt;
+use futures::{future, StreamExt};
 use pyo3::exceptions::PyValueError;
 use pyo3::prelude::*;
 use pyo3::pybacked::PyBackedStr;
@@ -92,15 +92,7 @@ impl PyDataFrame {
 
     fn __repr__(&self, py: Python) -> PyDataFusionResult<String> {
         let df = self.df.as_ref().clone();
-
-        let stream = wait_for_future(py, df.execute_stream()).map_err(py_datafusion_err)?;
-
-        let batches: Vec<RecordBatch> = wait_for_future(
-            py,
-            stream.take(10).collect::<Vec<_>>())
-        .into_iter()
-        .collect::<Result<Vec<_>,_>>()?;
-
+        let batches: Vec<RecordBatch> = get_batches(py, df, 10)?;
         let batches_as_string = pretty::pretty_format_batches(&batches);
         match batches_as_string {
             Ok(batch) => Ok(format!("DataFrame()\n{batch}")),
@@ -111,8 +103,8 @@ impl PyDataFrame {
     fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> {
         let mut html_str = "<table border='1'>\n".to_string();
 
-        let df = self.df.as_ref().clone().limit(0, Some(10))?;
-        let batches = wait_for_future(py, df.collect())?;
+        let df = self.df.as_ref().clone();
+        let batches: Vec<RecordBatch> = get_batches(py, df, 10)?;
 
         if batches.is_empty() {
             html_str.push_str("</table>\n");
@@ -742,3 +734,38 @@ fn record_batch_into_schema(
 
     RecordBatch::try_new(schema, data_arrays)
 }
+
+fn get_batches(
+    py: Python,
+    df: DataFrame,
+    max_rows: usize,
+) -> Result<Vec<RecordBatch>, PyDataFusionError> {
+    let partitioned_stream = wait_for_future(py, df.execute_stream_partitioned()).map_err(py_datafusion_err)?;
+    let stream = futures::stream::iter(partitioned_stream).flatten();
+    wait_for_future(
+        py,
+        stream
+            .scan(0, |state, x| {
+                let total = *state;
+                if total >= max_rows {
+                    future::ready(None)
+                } else {
+                    match x {
+                        Ok(batch) => {
+                            if total + batch.num_rows() <= max_rows {
+                                *state = total + batch.num_rows();
+                                future::ready(Some(Ok(batch)))
+                            } else {
+                                *state = max_rows;
+                                future::ready(Some(Ok(batch.slice(0, max_rows - total))))
+                            }
+                        }
+                        Err(err) => future::ready(Some(Err(PyDataFusionError::from(err)))),
+                    }
+                }
+            })
+            .collect::<Vec<_>>(),
+    )
+    .into_iter()
+    .collect::<Result<Vec<_>, _>>()
+}
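
A small design note on the `get_batches` helper added above, illustrated by a sketch that is not part of the commit and assumes only the `futures` crate: `futures::stream::iter(partitioned_stream).flatten()` drains each partition's stream in order, one after the other, so the preview rows are taken from the earliest partitions first.

use futures::{executor::block_on, stream, StreamExt};

fn main() {
    // Two "partitions", each already a stream of items (standing in for record batches).
    let partitions = vec![
        stream::iter(vec!["p0-b0", "p0-b1"]),
        stream::iter(vec!["p1-b0"]),
    ];
    // `flatten` exhausts the first inner stream before polling the second.
    let flattened: Vec<&str> = block_on(stream::iter(partitions).flatten().collect::<Vec<_>>());
    assert_eq!(flattened, vec!["p0-b0", "p0-b1", "p1-b0"]);
}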
