Skip to content

Commit c285a09

Browse files
committed
`__repr__` and `_repr_html_` show '... and additional rows' message for truncated outputs
1 parent d635d56 commit c285a09

File tree

1 file changed

+77
-18
lines changed

1 file changed

+77
-18
lines changed

src/dataframe.rs

Lines changed: 77 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -90,59 +90,108 @@ impl PyDataFrame {
9090
}
9191

9292
/// Build the plain-text representation of this DataFrame.
///
/// Collects at most 11 rows, pretty-prints the first 10 of them under a
/// `DataFrame()` heading, and appends `... and additional rows` when more
/// than 10 rows were fetched. A failure while formatting the batches is
/// returned as an `Error: ...` string instead of being raised.
fn __repr__(&self, py: Python) -> PyDataFusionResult<String> {
    const MAX_ROWS: usize = 10;

    // Ask for one row beyond the display limit so truncation can be detected.
    let limited = self.df.as_ref().clone().limit(0, Some(MAX_ROWS + 1))?;
    let batches = wait_for_future(py, limited.collect())?;
    let fetched: usize = batches.iter().map(|b| b.num_rows()).sum();

    // Keep slices of the leading batches until MAX_ROWS rows are gathered.
    let mut remaining = MAX_ROWS;
    let mut shown = Vec::new();
    for batch in &batches {
        if remaining == 0 {
            break;
        }
        let take = batch.num_rows().min(remaining);
        if take > 0 {
            shown.push(batch.slice(0, take));
            remaining -= take;
        }
    }

    match pretty::pretty_format_batches(&shown) {
        Ok(table) if fetched > MAX_ROWS => Ok(format!(
            "DataFrame()\n{table}\n... and additional rows"
        )),
        Ok(table) => Ok(format!("DataFrame()\n{table}")),
        Err(err) => Ok(format!("Error: {:?}", err.to_string())),
    }
}
133+
134+
101135

102136
fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> {
103137
let mut html_str = "<table border='1'>\n".to_string();
104-
105-
let df = self.df.as_ref().clone().limit(0, Some(10))?;
138+
139+
// Limit to the first 11 rows
140+
let df = self.df.as_ref().clone().limit(0, Some(11))?;
106141
let batches = wait_for_future(py, df.collect())?;
107-
142+
143+
// If there are no rows, close the table and return
108144
if batches.is_empty() {
109145
html_str.push_str("</table>\n");
110146
return Ok(html_str);
111147
}
112-
148+
149+
// Get schema for headers
113150
let schema = batches[0].schema();
114-
151+
115152
let mut header = Vec::new();
116153
for field in schema.fields() {
117-
header.push(format!("<th>{}</td>", field.name()));
154+
header.push(format!("<th>{}</th>", field.name()));
118155
}
119156
let header_str = header.join("");
120157
html_str.push_str(&format!("<tr>{}</tr>\n", header_str));
121-
122-
for batch in batches {
158+
159+
// Flatten rows and format them as HTML
160+
let mut total_rows = 0;
161+
for batch in &batches {
162+
total_rows += batch.num_rows();
123163
let formatters = batch
124164
.columns()
125165
.iter()
126166
.map(|c| ArrayFormatter::try_new(c.as_ref(), &FormatOptions::default()))
127-
.map(|c| {
128-
c.map_err(|e| PyValueError::new_err(format!("Error: {:?}", e.to_string())))
129-
})
167+
.map(|c| c.map_err(|e| PyValueError::new_err(format!("Error: {:?}", e.to_string()))))
130168
.collect::<Result<Vec<_>, _>>()?;
131-
132-
for row in 0..batch.num_rows() {
169+
170+
let num_rows_to_render = if total_rows > 10 { 10 } else { batch.num_rows() };
171+
172+
for row in 0..num_rows_to_render {
133173
let mut cells = Vec::new();
134174
for formatter in &formatters {
135175
cells.push(format!("<td>{}</td>", formatter.value(row)));
136176
}
137177
let row_str = cells.join("");
138178
html_str.push_str(&format!("<tr>{}</tr>\n", row_str));
139179
}
140-
}
141180

181+
if total_rows >= 10 {
182+
break;
183+
}
184+
}
185+
186+
if total_rows > 10 {
187+
html_str.push_str("<tr><td colspan=\"100%\">... and additional rows</td></tr>\n");
188+
}
189+
142190
html_str.push_str("</table>\n");
143-
191+
144192
Ok(html_str)
145193
}
194+
146195

147196
/// Calculate summary statistics for a DataFrame
148197
fn describe(&self, py: Python) -> PyDataFusionResult<Self> {
@@ -436,6 +485,16 @@ impl PyDataFrame {
436485
Ok(Self::new(df))
437486
}
438487

488+
// Add column name handling that removes "?table?" prefix
489+
fn format_column_name(&self, name: &str) -> String {
490+
// Strip ?table? prefix if present
491+
if name.starts_with("?table?.") {
492+
name.trim_start_matches("?table?.").to_string()
493+
} else {
494+
name.to_string()
495+
}
496+
}
497+
439498
/// Calculate the intersection of two `DataFrame`s. The two `DataFrame`s must have exactly the same schema
440499
fn intersect(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
441500
let new_df = self

0 commit comments

Comments
 (0)