Skip to content

Commit 75dcc45

Browse files
committed
Improve the table readout of a DataFrame in Jupyter notebooks by making the table scrollable and displaying the first record batch, up to 2 MB
1 parent 3584bec commit 75dcc45

File tree

1 file changed

+70
-29
lines changed

1 file changed

+70
-29
lines changed

src/dataframe.rs

Lines changed: 70 additions & 29 deletions
Original file line number · Diff line number · Diff line change
@@ -30,9 +30,11 @@ use datafusion::arrow::util::pretty;
3030
use datafusion::common::UnnestOptions;
3131
use datafusion::config::{CsvOptions, TableParquetOptions};
3232
use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
33+
use datafusion::error::DataFusionError;
3334
use datafusion::execution::SendableRecordBatchStream;
3435
use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
3536
use datafusion::prelude::*;
37+
use futures::{StreamExt, TryStreamExt};
3638
use pyo3::exceptions::PyValueError;
3739
use pyo3::prelude::*;
3840
use pyo3::pybacked::PyBackedStr;
@@ -50,6 +52,8 @@ use crate::{
5052
expr::{sort_expr::PySortExpr, PyExpr},
5153
};
5254

55+
/// Byte budget for the HTML table rendered by `_repr_html_`: when the first
/// record batch is larger than this, only a proportional prefix of its rows is
/// displayed and a truncation notice is appended.
const MAX_TABLE_BYTES_TO_DISPLAY: usize = 2 * 1024 * 1024; // 2 MB
56+
5357
/// A PyDataFrame is a representation of a logical plan and an API to compose statements.
5458
/// Use it to build a plan and `.collect()` to execute the plan and collect the result.
5559
/// The actual execution of a plan runs natively on Rust and Arrow on a multi-threaded environment.
@@ -100,46 +104,57 @@ impl PyDataFrame {
100104
}
101105

102106
fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> {
103-
let mut html_str = "<table border='1'>\n".to_string();
107+
let (batch, mut has_more) =
108+
wait_for_future(py, get_first_record_batch(self.df.as_ref().clone()))?;
109+
let Some(batch) = batch else {
110+
return Ok("No data to display".to_string());
111+
};
104112

105-
let df = self.df.as_ref().clone().limit(0, Some(10))?;
106-
let batches = wait_for_future(py, df.collect())?;
113+
let mut html_str = "
114+
<div style=\"width: 100%; max-width: 1000px; max-height: 300px; overflow: auto; border: 1px solid #ccc;\">
115+
<table style=\"border-collapse: collapse; min-width: 100%\">
116+
<thead>\n".to_string();
107117

108-
if batches.is_empty() {
109-
html_str.push_str("</table>\n");
110-
return Ok(html_str);
111-
}
112-
113-
let schema = batches[0].schema();
118+
let schema = batch.schema();
114119

115120
let mut header = Vec::new();
116121
for field in schema.fields() {
117-
header.push(format!("<th>{}</td>", field.name()));
122+
header.push(format!("<th style='border: 1px solid black; padding: 8px; text-align: left; background-color: #f2f2f2; white-space: nowrap; min-width: fit-content; max-width: fit-content;'>{}</th>", field.name()));
118123
}
119124
let header_str = header.join("");
120-
html_str.push_str(&format!("<tr>{}</tr>\n", header_str));
125+
html_str.push_str(&format!("<tr>{}</tr></thead><tbody>\n", header_str));
126+
127+
let formatters = batch
128+
.columns()
129+
.iter()
130+
.map(|c| ArrayFormatter::try_new(c.as_ref(), &FormatOptions::default()))
131+
.map(|c| c.map_err(|e| PyValueError::new_err(format!("Error: {:?}", e.to_string()))))
132+
.collect::<Result<Vec<_>, _>>()?;
133+
134+
let batch_size = batch.get_array_memory_size();
135+
let num_rows_to_display = match batch_size > MAX_TABLE_BYTES_TO_DISPLAY {
136+
true => {
137+
has_more = true;
138+
let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / batch_size as f32;
139+
(batch.num_rows() as f32 * ratio).round() as usize
140+
}
141+
false => batch.num_rows(),
142+
};
121143

122-
for batch in batches {
123-
let formatters = batch
124-
.columns()
125-
.iter()
126-
.map(|c| ArrayFormatter::try_new(c.as_ref(), &FormatOptions::default()))
127-
.map(|c| {
128-
c.map_err(|e| PyValueError::new_err(format!("Error: {:?}", e.to_string())))
129-
})
130-
.collect::<Result<Vec<_>, _>>()?;
131-
132-
for row in 0..batch.num_rows() {
133-
let mut cells = Vec::new();
134-
for formatter in &formatters {
135-
cells.push(format!("<td>{}</td>", formatter.value(row)));
136-
}
137-
let row_str = cells.join("");
138-
html_str.push_str(&format!("<tr>{}</tr>\n", row_str));
144+
for row in 0..num_rows_to_display {
145+
let mut cells = Vec::new();
146+
for formatter in &formatters {
147+
cells.push(format!("<td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'>{}</td>", formatter.value(row)));
139148
}
149+
let row_str = cells.join("");
150+
html_str.push_str(&format!("<tr>{}</tr>\n", row_str));
140151
}
141152

142-
html_str.push_str("</table>\n");
153+
html_str.push_str("</tbody></table></div>\n");
154+
155+
if has_more {
156+
html_str.push_str("Data truncated due to size.");
157+
}
143158

144159
Ok(html_str)
145160
}
@@ -732,3 +747,29 @@ fn record_batch_into_schema(
732747

733748
RecordBatch::try_new(schema, data_arrays)
734749
}
750+
751+
/// This is a helper function to return the first non-empty record batch from executing a DataFrame.
752+
/// It additionally returns a bool, which indicates if there are more record batches available.
753+
/// We do this so we can determine if we should indicate to the user that the data has been
754+
/// truncated.
755+
async fn get_first_record_batch(
756+
df: DataFrame,
757+
) -> Result<(Option<RecordBatch>, bool), DataFusionError> {
758+
let mut stream = df.execute_stream().await?;
759+
loop {
760+
let rb = match stream.next().await {
761+
None => return Ok((None, false)),
762+
Some(Ok(r)) => r,
763+
Some(Err(e)) => return Err(e),
764+
};
765+
766+
if rb.num_rows() > 0 {
767+
let has_more = match stream.try_next().await {
768+
Ok(None) => false, // reached end
769+
Ok(Some(_)) => true,
770+
Err(_) => false, // Stream disconnected
771+
};
772+
return Ok((Some(rb), has_more));
773+
}
774+
}
775+
}

0 commit comments

Comments (0)