@@ -24,6 +24,8 @@ use arrow::{
2424use std:: mem:: size_of;
2525use std:: { cmp:: Ordering , collections:: BinaryHeap , sync:: Arc } ;
2626
27+ use super :: metrics:: { BaselineMetrics , Count , ExecutionPlanMetricsSet , MetricBuilder } ;
28+ use crate :: spill:: get_record_batch_memory_size;
2729use crate :: { stream:: RecordBatchStreamAdapter , SendableRecordBatchStream } ;
2830use arrow_array:: { Array , ArrayRef , RecordBatch } ;
2931use arrow_schema:: SchemaRef ;
@@ -36,8 +38,6 @@ use datafusion_execution::{
3638use datafusion_physical_expr:: PhysicalSortExpr ;
3739use datafusion_physical_expr_common:: sort_expr:: LexOrdering ;
3840
39- use super :: metrics:: { BaselineMetrics , Count , ExecutionPlanMetricsSet , MetricBuilder } ;
40-
4141/// Global TopK
4242///
4343/// # Background
@@ -575,7 +575,7 @@ impl RecordBatchStore {
575575 pub fn insert ( & mut self , entry : RecordBatchEntry ) {
576576 // uses of 0 means that none of the rows in the batch were stored in the topk
577577 if entry. uses > 0 {
578- self . batches_size += entry. batch . get_array_memory_size ( ) ;
578+ self . batches_size += get_record_batch_memory_size ( & entry. batch ) ;
579579 self . batches . insert ( entry. id , entry) ;
580580 }
581581 }
@@ -630,7 +630,7 @@ impl RecordBatchStore {
630630 let old_entry = self . batches . remove ( & id) . unwrap ( ) ;
631631 self . batches_size = self
632632 . batches_size
633- . checked_sub ( old_entry. batch . get_array_memory_size ( ) )
633+ . checked_sub ( get_record_batch_memory_size ( & old_entry. batch ) )
634634 . unwrap ( ) ;
635635 }
636636 }
@@ -643,3 +643,44 @@ impl RecordBatchStore {
643643 + self . batches_size
644644 }
645645}
646+
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Int32Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use arrow_array::Float64Array;

    /// Verifies `RecordBatchStore` memory accounting for a multi-column
    /// batch: inserting a used entry adds the batch's data size, and
    /// releasing its last use removes it again.
    #[test]
    fn test_record_batch_store_size() {
        // Schema with two columns: nullable Int32 and non-null Float64.
        let schema = Arc::new(Schema::new(vec![
            Field::new("ints", DataType::Int32, true),
            Field::new("float64", DataType::Float64, false),
        ]));
        let mut store = RecordBatchStore::new(Arc::clone(&schema));

        // 5 * 4 bytes = 20 bytes of Int32 data.
        let ints = Int32Array::from(vec![Some(1), Some(2), Some(3), Some(4), Some(5)]);
        // 5 * 8 bytes = 40 bytes of Float64 data.
        let floats = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let batch =
            RecordBatch::try_new(schema, vec![Arc::new(ints), Arc::new(floats)]).unwrap();

        // Insert: the tracked size must be exactly the data size (20 + 40).
        store.insert(RecordBatchEntry {
            id: 0,
            batch,
            uses: 1,
        });
        assert_eq!(store.batches_size, 60);

        // Releasing the only use must drop the batch and its accounting.
        store.unuse(0);
        assert_eq!(store.batches_size, 0);
    }
}
0 commit comments