@@ -84,13 +84,16 @@ impl TableProvider for RecordBatchReaderProvider {
// Diff hunk for `scan`: the column projection requested by the planner is no
// longer ignored (`_projection`) but forwarded to the execution plan.
8484 async fn scan (
8585 & self ,
8686 _state : & dyn Session ,
87- _projection : Option < & Vec < usize > > ,
87+ projection : Option < & Vec < usize > > ,
// `_filters` is not referenced anywhere in the visible body of this function.
8888 _filters : & [ Expr ] ,
8989 limit : Option < usize > ,
9090 ) -> Result < Arc < dyn ExecutionPlan > > {
// The reader is moved out of the lock-guarded `Option`, so only the first call
// reaches the `Some` branch; later calls fall through to the error below.
9191 let mut reader_guard = self . reader . lock ( ) ;
9292 if let Some ( reader) = reader_guard. take ( ) {
93- Ok ( Arc :: new ( RecordBatchReaderExec :: new ( reader, limit) ) )
// Turn the borrowed `Option<&Vec<usize>>` into an owned `Option<Vec<usize>>`,
// which is what `try_new` takes.
93+ let projection = projection. cloned ( ) ;
// NOTE(review): `reader` has already been taken at this point and is passed by
// value, so if `try_new` returns an error the reader is dropped and a retry of
// `scan` will hit the "more than once" branch - confirm this is acceptable.
94+ Ok ( Arc :: new ( RecordBatchReaderExec :: try_new (
95+ reader, limit, projection,
96+ ) ?) )
9497 } else {
9598 sedona_internal_err ! ( "Can't scan RecordBatchReader provider more than once" )
9699 }
@@ -158,24 +161,39 @@ struct RecordBatchReaderExec {
158161 schema : SchemaRef ,
159162 properties : PlanProperties ,
160163 limit : Option < usize > ,
164+ projection : Option < Vec < usize > > ,
161165}
162166
// Diff hunk: the infallible constructor `new(reader, limit)` is replaced by the
// fallible `try_new(reader, limit, projection)`.
163impl RecordBatchReaderExec {
164- fn new ( reader : Box < dyn RecordBatchReader + Send > , limit : Option < usize > ) -> Self {
165- let schema = reader. schema ( ) ;
// Builds the plan node around `reader`.
// - `limit`: optional row limit, stored as-is on the node.
// - `projection`: optional column indices into the reader's schema; `None`
//   keeps every column.
// Returns an error when the reader's schema cannot be projected with the
// given indices.
168+ fn try_new (
169+ reader : Box < dyn RecordBatchReader + Send > ,
170+ limit : Option < usize > ,
171+ projection : Option < Vec < usize > > ,
172+ ) -> Result < Self > {
173+ let full_schema = reader. schema ( ) ;
// The schema this node reports is the projected schema when a projection is
// present, otherwise the reader's own schema.
174+ let schema: SchemaRef = if let Some ( indices) = projection. as_ref ( ) {
175+ SchemaRef :: new (
176+ full_schema
177+ . project ( indices)
// NOTE(review): presumably `?` alone would perform the same error conversion,
// making this `map_err` redundant - confirm.
178+ . map_err ( DataFusionError :: from) ?,
179+ )
180+ } else {
// NOTE(review): `full_schema` is not used again after this expression in the
// visible code, so it looks like it could be moved here instead of cloned.
181+ full_schema. clone ( )
182+ } ;
// Plan properties are derived from the (possibly projected) schema: a single
// partition of unknown partitioning, incremental emission, bounded input.
166183 let properties = PlanProperties :: new (
167184 EquivalenceProperties :: new ( schema. clone ( ) ) ,
168185 Partitioning :: UnknownPartitioning ( 1 ) ,
169186 EmissionType :: Incremental ,
170187 Boundedness :: Bounded ,
171188 ) ;
172189
173- Self {
190+ Ok ( Self {
// The reader is wrapped in `Mutex<Option<_>>`; presumably so it can be taken
// exactly once when the plan is executed - verify against `execute`.
174191 reader : Mutex :: new ( Some ( reader) ) ,
175192 schema,
176193 properties,
177194 limit,
178- }
// Kept on the node so each produced batch can be projected to the same
// columns as `schema`.
195+ projection,
196+ } )
179197 }
180198}
181199
@@ -186,6 +204,7 @@ impl Debug for RecordBatchReaderExec {
186204 . field ( "schema" , & self . schema )
187205 . field ( "properties" , & self . properties )
188206 . field ( "limit" , & self . limit )
207+ . field ( "projection" , & self . projection )
189208 . finish ( )
190209 }
191210}
@@ -240,17 +259,34 @@ impl ExecutionPlan for RecordBatchReaderExec {
240259 match self . limit {
241260 Some ( limit) => {
242261 // Create a row-limited iterator that properly handles row counting
243- let iter = RowLimitedIterator :: new ( reader, limit) ;
262+ let projection = self . projection . clone ( ) ;
263+ let iter = RowLimitedIterator :: new ( reader, limit) . map ( move |res| match res {
264+ Ok ( batch) => {
265+ if let Some ( indices) = projection. as_ref ( ) {
266+ batch. project ( indices) . map_err ( |e| e. into ( ) )
267+ } else {
268+ Ok ( batch)
269+ }
270+ }
271+ Err ( e) => Err ( e) ,
272+ } ) ;
244273 let stream = Box :: pin ( futures:: stream:: iter ( iter) ) ;
245274 let record_batch_stream =
246275 RecordBatchStreamAdapter :: new ( self . schema . clone ( ) , stream) ;
247276 Ok ( Box :: pin ( record_batch_stream) )
248277 }
249278 None => {
250279 // No limit, just convert the reader directly to a stream
251- let iter = reader. map ( |item| match item {
252- Ok ( batch) => Ok ( batch) ,
253- Err ( e) => Err ( DataFusionError :: from ( e) ) ,
280+ let projection = self . projection . clone ( ) ;
281+ let iter = reader. map ( move |item| match item {
282+ Ok ( batch) => {
283+ if let Some ( indices) = projection. as_ref ( ) {
284+ batch. project ( indices) . map_err ( |e| e. into ( ) )
285+ } else {
286+ Ok ( batch)
287+ }
288+ }
289+ Err ( e) => Err ( e. into ( ) ) ,
254290 } ) ;
255291 let stream = Box :: pin ( futures:: stream:: iter ( iter) ) ;
256292 let record_batch_stream =
@@ -266,7 +302,7 @@ mod test {
266302
267303 use arrow_array:: { RecordBatch , RecordBatchIterator } ;
268304 use arrow_schema:: { DataType , Field , Schema } ;
269- use datafusion:: prelude:: { DataFrame , SessionContext } ;
305+ use datafusion:: prelude:: { col , DataFrame , SessionContext } ;
270306 use rstest:: rstest;
271307 use sedona_schema:: datatypes:: WKB_GEOMETRY ;
272308 use sedona_testing:: create:: create_array_storage;
@@ -383,6 +419,45 @@ mod test {
383419 }
384420 }
385421
// New test: selecting a single column from a two-column
// `RecordBatchReaderProvider` must yield only that column, with its values
// intact.
// NOTE(review): the assertions below inspect the collected output only; they
// do not directly verify that the projection was handed to `scan` - confirm
// whether a plan-level check is also wanted.
422+ #[ tokio:: test]
423+ async fn test_projection_pushdown ( ) {
424+ let ctx = SessionContext :: new ( ) ;
425+
426+ // Create a two-column batch
427+ let schema = Schema :: new ( vec ! [
428+ Field :: new( "a" , DataType :: Int32 , false ) ,
429+ Field :: new( "b" , DataType :: Int32 , false ) ,
430+ ] ) ;
431+ let batch = RecordBatch :: try_new (
432+ Arc :: new ( schema. clone ( ) ) ,
433+ vec ! [
434+ Arc :: new( arrow_array:: Int32Array :: from( vec![ 1 , 2 , 3 ] ) ) ,
435+ Arc :: new( arrow_array:: Int32Array :: from( vec![ 10 , 20 , 30 ] ) ) ,
436+ ] ,
437+ )
438+ . unwrap ( ) ;
439+
440+ // Wrap in a RecordBatchReaderProvider
441+ let reader =
442+ RecordBatchIterator :: new ( vec ! [ batch. clone( ) ] . into_iter ( ) . map ( Ok ) , Arc :: new ( schema) ) ;
443+ let provider = Arc :: new ( RecordBatchReaderProvider :: new ( Box :: new ( reader) ) ) ;
444+
445+ // Read table then select only column b (this should push projection into scan)
446+ let df = ctx. read_table ( provider) . unwrap ( ) ;
447+ let df_b = df. select ( vec ! [ col( "b" ) ] ) . unwrap ( ) ;
448+ let results = df_b. collect ( ) . await . unwrap ( ) ;
// The reader was fed exactly one batch, so exactly one batch is expected back.
449+ assert_eq ! ( results. len( ) , 1 ) ;
450+ let out_batch = & results[ 0 ] ;
// Only column "b" may survive, both in the schema and in the data.
451+ assert_eq ! ( out_batch. num_columns( ) , 1 ) ;
452+ assert_eq ! ( out_batch. schema( ) . field( 0 ) . name( ) , "b" ) ;
453+ let values = out_batch
454+ . column ( 0 )
455+ . as_any ( )
456+ . downcast_ref :: < arrow_array:: Int32Array > ( )
457+ . unwrap ( ) ;
// The values must be those of the second input column, not the first.
458+ assert_eq ! ( values. values( ) , & [ 10 , 20 , 30 ] ) ;
459+ }
460+
386461 fn read_test_table_with_limit (
387462 ctx : & SessionContext ,
388463 batch_sizes : Vec < usize > ,
0 commit comments