@@ -116,21 +116,37 @@ impl PyDataFrame {
     }
 
     fn __repr__(&self, py: Python) -> PyDataFusionResult<String> {
+        let (batches, has_more) = wait_for_future(
+            py,
+            collect_record_batches_to_display(self.df.as_ref().clone(), 10, 10),
+        )?;
+        if batches.is_empty() {
+            // This should not be reached, but do it for safety since we index into the vector below
+            return Ok("No data to display".to_string());
+        }
+
         let df = self.df.as_ref().clone().limit(0, Some(10))?;
         let batches = wait_for_future(py, df.collect())?;
-        let batches_as_string = pretty::pretty_format_batches(&batches);
-        match batches_as_string {
-            Ok(batch) => Ok(format!("DataFrame()\n{batch}")),
-            Err(err) => Ok(format!("Error: {:?}", err.to_string())),
-        }
+        let batches_as_displ =
+            pretty::pretty_format_batches(&batches).map_err(py_datafusion_err)?;
+
+        let additional_str = match has_more {
+            true => "\nData truncated.",
+            false => "",
+        };
+
+        Ok(format!("DataFrame()\n{batches_as_displ}{additional_str}"))
     }
 
     fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> {
-        let (batches, mut has_more) =
-            wait_for_future(py, get_first_few_record_batches(self.df.as_ref().clone()))?;
-        let Some(batches) = batches else {
-            return Ok("No data to display".to_string());
-        };
+        let (batches, mut has_more) = wait_for_future(
+            py,
+            collect_record_batches_to_display(
+                self.df.as_ref().clone(),
+                MIN_TABLE_ROWS_TO_DISPLAY,
+                usize::MAX,
+            ),
+        )?;
         if batches.is_empty() {
             // This should not be reached, but do it for safety since we index into the vector below
             return Ok("No data to display".to_string());
@@ -200,10 +216,6 @@ impl PyDataFrame {
         let rows_per_batch = batches.iter().map(|batch| batch.num_rows());
         let total_rows = rows_per_batch.clone().sum();
 
-        // let (total_memory, total_rows) = batches.iter().fold((0, 0), |acc, batch| {
-        //     (acc.0 + batch.get_array_memory_size(), acc.1 + batch.num_rows())
-        // });
-
         let num_rows_to_display = match total_memory > MAX_TABLE_BYTES_TO_DISPLAY {
             true => {
                 let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / total_memory as f32;
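To make the scaling factor concrete, here is a tiny worked sketch. The budget value is an assumption for illustration only; the real MAX_TABLE_BYTES_TO_DISPLAY constant is defined elsewhere in this file and may differ.

// Sketch of the ratio computed above, with an assumed 2 MiB display budget.
const ASSUMED_MAX_TABLE_BYTES: usize = 2 * 1024 * 1024; // stand-in for MAX_TABLE_BYTES_TO_DISPLAY

fn main() {
    let total_memory: usize = 8 * 1024 * 1024; // estimated size of the collected batches
    let ratio = ASSUMED_MAX_TABLE_BYTES as f32 / total_memory as f32;
    assert_eq!(ratio, 0.25); // exact here: only about a quarter of the rows fit under the byte budget
}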
@@ -887,37 +899,78 @@ fn record_batch_into_schema(
 /// This is a helper function to return the first non-empty record batch from executing a DataFrame.
 /// It additionally returns a bool, which indicates if there are more record batches available.
 /// We do this so we can determine if we should indicate to the user that the data has been
-/// truncated.
-async fn get_first_few_record_batches(
+/// truncated. This collects until we have achieved both of these conditions:
+///
+/// - We have collected our minimum number of rows
+/// - We have reached our limit, either data size or maximum number of rows
+///
+/// Otherwise it will return when the stream is exhausted. If you want a specific number of
+/// rows, set min_rows == max_rows.
+async fn collect_record_batches_to_display(
     df: DataFrame,
-) -> Result<(Option<Vec<RecordBatch>>, bool), DataFusionError> {
+    min_rows: usize,
+    max_rows: usize,
+) -> Result<(Vec<RecordBatch>, bool), DataFusionError> {
     let mut stream = df.execute_stream().await?;
     let mut size_estimate_so_far = 0;
+    let mut rows_so_far = 0;
     let mut record_batches = Vec::default();
-    while size_estimate_so_far < MAX_TABLE_BYTES_TO_DISPLAY {
-        let rb = match stream.next().await {
+    let mut has_more = false;
+
+    while (size_estimate_so_far < MAX_TABLE_BYTES_TO_DISPLAY && rows_so_far < max_rows)
+        || rows_so_far < min_rows
+    {
+        let mut rb = match stream.next().await {
             None => {
                 break;
             }
             Some(Ok(r)) => r,
             Some(Err(e)) => return Err(e),
         };
 
-        if rb.num_rows() > 0 {
+        let mut rows_in_rb = rb.num_rows();
+        if rows_in_rb > 0 {
             size_estimate_so_far += rb.get_array_memory_size();
+
+            if size_estimate_so_far > MAX_TABLE_BYTES_TO_DISPLAY {
+                let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / size_estimate_so_far as f32;
+                let total_rows = rows_in_rb + rows_so_far;
+
+                let mut reduced_row_num = (total_rows as f32 * ratio).round() as usize;
+                if reduced_row_num < min_rows {
+                    reduced_row_num = min_rows.min(total_rows);
+                }
+
+                let limited_rows_this_rb = reduced_row_num - rows_so_far;
+                if limited_rows_this_rb < rows_in_rb {
+                    rows_in_rb = limited_rows_this_rb;
+                    rb = rb.slice(0, limited_rows_this_rb);
+                    has_more = true;
+                }
+            }
+
+            if rows_in_rb + rows_so_far > max_rows {
+                rb = rb.slice(0, max_rows - rows_so_far);
+                has_more = true;
+            }
+
+            rows_so_far += rb.num_rows();
             record_batches.push(rb);
         }
     }
 
     if record_batches.is_empty() {
-        return Ok((None, false));
+        return Ok((Vec::default(), false));
     }
 
-    let has_more = match stream.try_next().await {
-        Ok(None) => false, // reached end
-        Ok(Some(_)) => true,
-        Err(_) => false, // Stream disconnected
-    };
+    if !has_more {
+        // Data was not already truncated, so check to see if more record batches remain
+        has_more = match stream.try_next().await {
+            Ok(None) => false, // reached end
+            Ok(Some(_)) => true,
+            Err(_) => false, // Stream disconnected
+        };
+    }
 
-    Ok((Some(record_batches), has_more))
+    Ok((record_batches, has_more))
 }
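For reference, a minimal standalone sketch (not part of the commit) of the row-capping behaviour the doc comment describes: min_rows == max_rows yields an exact row count as in __repr__, while max_rows == usize::MAX leaves only the byte budget in play as in _repr_html_. The byte-size check is omitted and the batch row counts are made up for illustration.

// Sketch only: trims a sequence of batch row counts against min_rows/max_rows,
// mirroring the slicing and the has_more flag above (byte budget ignored).
fn cap_rows(batch_sizes: &[usize], min_rows: usize, max_rows: usize) -> (Vec<usize>, bool) {
    let mut kept = Vec::new();
    let mut rows_so_far = 0;
    let mut has_more = false;

    for &rows_in_rb in batch_sizes {
        // Stop once the minimum is satisfied and the row cap is reached,
        // the analogue of the while condition in the function above.
        if rows_so_far >= min_rows && rows_so_far >= max_rows {
            has_more = true;
            break;
        }
        let mut take = rows_in_rb;
        if rows_so_far + take > max_rows {
            take = max_rows - rows_so_far; // rb.slice(0, max_rows - rows_so_far)
            has_more = true;
        }
        rows_so_far += take;
        kept.push(take);
    }
    (kept, has_more)
}

fn main() {
    // min_rows == max_rows == 10 behaves like __repr__: exactly 10 rows, truncation flagged.
    assert_eq!(cap_rows(&[6, 6, 6], 10, 10), (vec![6, 4], true));
    // max_rows == usize::MAX behaves like _repr_html_: all rows pass (byte budget aside).
    assert_eq!(cap_rows(&[6, 6, 6], 20, usize::MAX), (vec![6, 6, 6], false));
}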