@@ -8,13 +8,12 @@ use std::task::{Context, Poll};
 use std::time::Duration;
 
 use arrow::array::RecordBatch;
-use arrow::compute::concat_batches;
 use arrow::datatypes::SchemaRef;
 use arrow::error::ArrowError;
 use arrow::ipc::convert::fb_to_schema;
 use arrow::ipc::reader::StreamReader;
 use arrow::ipc::writer::{IpcWriteOptions, StreamWriter};
-use arrow::ipc::{root_as_message, MetadataVersion};
+use arrow::ipc::{MetadataVersion, root_as_message};
 use arrow::pyarrow::*;
 use arrow::util::pretty;
 use arrow_flight::{FlightClient, FlightData, Ticket};
@@ -30,16 +29,16 @@ use datafusion::error::DataFusionError;
 use datafusion::execution::object_store::ObjectStoreUrl;
 use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, SessionStateBuilder};
 use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
-use datafusion::physical_plan::{displayable, ExecutionPlan, ExecutionPlanProperties};
-use datafusion::prelude::{SessionConfig, SessionContext};
+use datafusion::physical_plan::{ExecutionPlan, ExecutionPlanProperties, displayable};
+use datafusion::prelude::{ParquetReadOptions, SessionConfig, SessionContext};
 use datafusion_proto::physical_plan::AsExecutionPlan;
 use datafusion_python::utils::wait_for_future;
 use futures::{Stream, StreamExt};
 use log::debug;
+use object_store::ObjectStore;
 use object_store::aws::AmazonS3Builder;
 use object_store::gcp::GoogleCloudStorageBuilder;
 use object_store::http::HttpBuilder;
-use object_store::ObjectStore;
 use parking_lot::Mutex;
 use pyo3::prelude::*;
 use pyo3::types::{PyBytes, PyList};
@@ -411,62 +410,77 @@ fn print_node(plan: &Arc<dyn ExecutionPlan>, indent: usize, output: &mut String)
     }
 }
 
-async fn exec_sql(
-    query: String,
-    tables: Vec<(String, String)>,
-) -> Result<RecordBatch, DataFusionError> {
-    let ctx = SessionContext::new();
-    for (name, path) in tables {
-        let opt =
-            ListingOptions::new(Arc::new(ParquetFormat::new())).with_file_extension(".parquet");
-        debug!("exec_sql: registering table {} at {}", name, path);
+#[pyclass]
+pub struct LocalValidator {
+    ctx: SessionContext,
+}
+
+#[pymethods]
+impl LocalValidator {
+    #[new]
+    fn new() -> Self {
+        let ctx = SessionContext::new();
+        Self { ctx }
+    }
+
+    pub fn register_parquet(&self, py: Python, name: String, path: String) -> PyResult<()> {
+        let options = ParquetReadOptions::default();
 
-        let url = ListingTableUrl::parse(&path)?;
+        let url = ListingTableUrl::parse(&path).to_py_err()?;
 
-        maybe_register_object_store(&ctx, url.as_ref())?;
+        maybe_register_object_store(&self.ctx, url.as_ref()).to_py_err()?;
+        debug!("register_parquet: registering table {} at {}", name, path);
 
-        ctx.register_listing_table(&name, &path, opt, None, None)
-            .await?;
+        wait_for_future(py, self.ctx.register_parquet(&name, &path, options.clone()))?;
+        Ok(())
     }
-    let df = ctx.sql(&query).await?;
-    let schema = df.schema().inner().clone();
-    let batches = df.collect().await?;
-    concat_batches(&schema, batches.iter()).map_err(|e| DataFusionError::ArrowError(e, None))
-}
 
-/// Executes a query on the specified tables using DataFusion without Ray.
-///
-/// Returns the query results as a RecordBatch that can be used to verify the
-/// correctness of DataFusion-Ray execution of the same query.
-///
-/// # Arguments
-///
-/// * `py`: the Python token
-/// * `query`: the SQL query string to execute
-/// * `tables`: a list of `(name, url)` tuples specifying the tables to query;
-///   the `url` identifies the parquet files for each listing table and see
-///   [`datafusion::datasource::listing::ListingTableUrl::parse`] for details
-///   of supported URL formats
-/// * `listing`: boolean indicating whether this is a listing table path or not
-#[pyfunction]
-#[pyo3(signature = (query, tables, listing=false))]
-pub fn exec_sql_on_tables(
-    py: Python,
-    query: String,
-    tables: Bound<'_, PyList>,
-    listing: bool,
-) -> PyResult<PyObject> {
-    let table_vec = {
-        let mut v = Vec::with_capacity(tables.len());
-        for entry in tables.iter() {
-            let (name, path) = entry.extract::<(String, String)>()?;
-            let path = if listing { format!("{path}/") } else { path };
-            v.push((name, path));
-        }
-        v
-    };
-    let batch = wait_for_future(py, exec_sql(query, table_vec))?;
-    batch.to_pyarrow(py)
+    #[pyo3(signature = (name, path, file_extension=".parquet"))]
+    pub fn register_listing_table(
+        &mut self,
+        py: Python,
+        name: &str,
+        path: &str,
+        file_extension: &str,
+    ) -> PyResult<()> {
+        let options =
+            ListingOptions::new(Arc::new(ParquetFormat::new())).with_file_extension(file_extension);
+
+        let path = format!("{path}/");
+        let url = ListingTableUrl::parse(&path).to_py_err()?;
+
+        maybe_register_object_store(&self.ctx, url.as_ref()).to_py_err()?;
+
+        debug!(
+            "register_listing_table: registering table {} at {}",
+            name, path
+        );
+        wait_for_future(
+            py,
+            self.ctx
+                .register_listing_table(name, path, options, None, None),
+        )
+        .to_py_err()
+    }
+
+    #[pyo3(signature = (query))]
+    fn collect_sql(&self, py: Python, query: String) -> PyResult<PyObject> {
+        let fut = async || {
+            let df = self.ctx.sql(&query).await?;
+            let batches = df.collect().await?;
+
+            Ok::<_, DataFusionError>(batches)
+        };
+
+        let batches = wait_for_future(py, fut())
+            .to_py_err()?
+            .iter()
+            .map(|batch| batch.to_pyarrow(py))
+            .collect::<PyResult<Vec<_>>>()?;
+
+        let pylist = PyList::new(py, batches)?;
+        Ok(pylist.into())
+    }
 }
 
 pub(crate) fn register_object_store_for_paths_in_plan(
@@ -570,62 +584,14 @@ mod test {
     use std::{sync::Arc, vec};
 
     use arrow::{
-        array::{Int32Array, StringArray},
+        array::Int32Array,
         datatypes::{DataType, Field, Schema},
     };
-    use datafusion::{
-        parquet::file::properties::WriterProperties, test_util::parquet::TestParquetFile,
-    };
+
     use futures::stream;
 
     use super::*;
 
-    #[tokio::test]
-    async fn test_exec_sql() {
-        let dir = tempfile::tempdir().unwrap();
-        let path = dir.path().join("people.parquet");
-
-        let batch = RecordBatch::try_new(
-            Arc::new(Schema::new(vec![
-                Field::new("age", DataType::Int32, false),
-                Field::new("name", DataType::Utf8, false),
-            ])),
-            vec![
-                Arc::new(Int32Array::from(vec![11, 12, 13])),
-                Arc::new(StringArray::from(vec!["alice", "bob", "cindy"])),
-            ],
-        )
-        .unwrap();
-        let props = WriterProperties::builder().build();
-        let file = TestParquetFile::try_new(path.clone(), props, Some(batch.clone())).unwrap();
-
-        // test with file
-        let tables = vec![(
-            "people".to_string(),
-            format!("file://{}", file.path().to_str().unwrap()),
-        )];
-        let query = "SELECT * FROM people ORDER BY age".to_string();
-        let res = exec_sql(query.clone(), tables).await.unwrap();
-        assert_eq!(
-            format!(
-                "{}",
-                pretty::pretty_format_batches(&[batch.clone()]).unwrap()
-            ),
-            format!("{}", pretty::pretty_format_batches(&[res]).unwrap()),
-        );
-
-        // test with dir
-        let tables = vec![(
-            "people".to_string(),
-            format!("file://{}/", dir.path().to_str().unwrap()),
-        )];
-        let res = exec_sql(query, tables).await.unwrap();
-        assert_eq!(
-            format!("{}", pretty::pretty_format_batches(&[batch]).unwrap()),
-            format!("{}", pretty::pretty_format_batches(&[res]).unwrap()),
-        );
-    }
-
     #[test]
     fn test_ipc_roundtrip() {
         let batch = RecordBatch::try_new(
@@ -641,10 +607,9 @@ mod test {
     #[tokio::test]
     async fn test_max_rows_stream() {
         let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
-        let batch = RecordBatch::try_new(
-            schema.clone(),
-            vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]))],
-        )
+        let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(Int32Array::from(vec![
+            1, 2, 3, 4, 5, 6, 7, 8,
+        ]))])
         .unwrap();
 
         // 24 total rows
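
For reference, a minimal usage sketch of the new `LocalValidator` pyclass from Python. The method names and the `file_extension` default come from the diff above; the import path (`datafusion_ray`) and the table names and URLs are assumptions for illustration only. Note that `collect_sql` returns the collected batches as a list of pyarrow RecordBatches, rather than a single concatenated batch as the removed `exec_sql_on_tables` did.

# Hedged sketch: assumes LocalValidator is exposed from the crate's Python
# module as `datafusion_ray.LocalValidator` (import path not confirmed here).
from datafusion_ray import LocalValidator

validator = LocalValidator()

# Register a single parquet file under a table name.
validator.register_parquet("people", "file:///tmp/people.parquet")

# Or register a directory of parquet files as a listing table
# (file_extension defaults to ".parquet").
validator.register_listing_table("events", "s3://bucket/events")

# collect_sql returns a list of pyarrow.RecordBatch objects.
batches = validator.collect_sql("SELECT * FROM people ORDER BY age")
for batch in batches:
    print(batch.num_rows)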