@@ -615,6 +615,21 @@ impl DFRayProxyService {
 mod tests {
     use super::*;
     use std::collections::HashMap;
+
+    // Test-specific imports
+    use arrow_flight::{
+        FlightDescriptor,
+        Ticket,
+        decode::FlightRecordBatchStream,
+        sql::{CommandStatementQuery, TicketStatementQuery},
+    };
+    use futures::StreamExt;
+    use prost::Message;
+    use tonic::Request;
+
+    ////////////////////////////////////////////////////////////////
+    // Test helper functions
+    ////////////////////////////////////////////////////////////////
 
     /// Create a test handler for testing - bypasses worker discovery initialization
     fn create_test_handler() -> DfRayProxyHandler {
@@ -648,6 +663,70 @@ mod tests {
         std::env::set_var("DFRAY_WORKER_ADDRESSES", &mock_env_value);
     }
 
+    /// Create TicketStatementData for EXPLAIN testing from QueryPlan
+    fn create_explain_ticket_statement_data(plans: QueryPlan) -> TicketStatementData {
+        let explain_data = plans.explain_data.map(|data| {
+            DistributedExplainExecNode {
+                schema: data.schema().as_ref().try_into().ok(),
+                logical_plan: data.logical_plan().to_string(),
+                physical_plan: data.physical_plan().to_string(),
+                distributed_plan: data.distributed_plan().to_string(),
+                distributed_stages: data.distributed_stages().to_string(),
+            }
+        });
+
+        TicketStatementData {
+            query_id: plans.query_id,
+            stage_id: plans.final_stage_id,
+            stage_addrs: Some(plans.worker_addresses.into()),
+            schema: Some(plans.schema.as_ref().try_into().unwrap()),
+            explain_data,
+        }
+    }
+
+    /// Consume a DoGetStream and verify it contains expected EXPLAIN results
+    async fn verify_explain_stream_results(stream: crate::flight::DoGetStream) {
+
+        // Convert the stream to a FlightRecordBatchStream and consume it.
+        // Map Status errors to FlightError to match the expected stream type.
+        let mapped_stream = stream.map(|result| {
+            result.map_err(|status| arrow_flight::error::FlightError::from(status))
+        });
+        let mut flight_stream = FlightRecordBatchStream::new_from_flight_data(mapped_stream);
+
+        let mut batches = Vec::new();
+        while let Some(batch_result) = flight_stream.next().await {
+            let batch = batch_result.expect("Failed to get batch from stream");
+            batches.push(batch);
+        }
+
+        // Verify we got exactly one batch with EXPLAIN results
+        assert_eq!(batches.len(), 1);
+        let batch = &batches[0];
+
+        // Verify schema: should have 2 columns (plan_type, plan)
+        assert_eq!(batch.num_columns(), 2);
+        assert_eq!(batch.schema().field(0).name(), "plan_type");
+        assert_eq!(batch.schema().field(1).name(), "plan");
+
+        // Verify we have 4 rows (logical_plan, physical_plan, distributed_plan, distributed_stages)
+        assert_eq!(batch.num_rows(), 4);
+
+        // Verify the plan_type column contains the expected values
+        let plan_type_column = batch.column(0).as_any().downcast_ref::<arrow::array::StringArray>().unwrap();
+        assert_eq!(plan_type_column.value(0), "logical_plan");
+        assert_eq!(plan_type_column.value(1), "physical_plan");
+        assert_eq!(plan_type_column.value(2), "distributed_plan");
+        assert_eq!(plan_type_column.value(3), "distributed_stages");
+
+        // Verify the plan column contains actual plan content
+        let plan_column = batch.column(1).as_any().downcast_ref::<arrow::array::StringArray>().unwrap();
+        assert!(plan_column.value(0).contains("Projection: Int64(1) AS test_col"));
+        assert!(plan_column.value(1).contains("ProjectionExec"));
+        assert!(plan_column.value(2).contains("RayStageExec"));
+        assert!(plan_column.value(3).contains("Stage 0:"));
+    }
+
     ////////////////////////////////////////////////////////////////
     // Unit tests for helper functions
     ////////////////////////////////////////////////////////////////
@@ -846,4 +925,183 @@ mod tests {
             }
         }
     }
+
+    ////////////////////////////////////////////////////////////////
+    // EXPLAIN Flow Integration Tests
+    ////////////////////////////////////////////////////////////////
+
+    #[tokio::test]
+    async fn test_handle_explain_request() {
+        let handler = create_test_handler();
+        let query = "EXPLAIN SELECT 1 as test_col, 'hello' as text_col";
+
+        let result = handler.handle_explain_request(query).await;
+        assert!(result.is_ok());
+
+        let response = result.unwrap();
+        let flight_info = response.into_inner();
+
+        // Verify FlightInfo structure
+        assert!(!flight_info.schema.is_empty());
+        assert_eq!(flight_info.endpoint.len(), 1);
+        assert!(flight_info.endpoint[0].ticket.is_some());
+
+        // Verify that the ticket has content (encoded TicketStatementData)
+        let ticket = flight_info.endpoint[0].ticket.as_ref().unwrap();
+        assert!(!ticket.ticket.is_empty());
+
+        println!("✓ FlightInfo created successfully with {} schema bytes and ticket with {} bytes",
+            flight_info.schema.len(), ticket.ticket.len());
+    }
+
+    #[tokio::test]
+    async fn test_handle_explain_request_invalid_query() {
+        let handler = create_test_handler();
+
+        // Test with EXPLAIN ANALYZE (should fail)
+        let query = "EXPLAIN ANALYZE SELECT 1";
+        let result = handler.handle_explain_request(query).await;
+        assert!(result.is_err());
+
+        let error = result.unwrap_err();
+        assert_eq!(error.code(), tonic::Code::Internal);
+        assert!(error.message().contains("Could not prepare EXPLAIN query"));
+    }
+
+    #[tokio::test]
+    async fn test_handle_explain_statement_execution() {
+        let handler = create_test_handler();
+
+        // First prepare an EXPLAIN query to get the ticket data structure
+        let query = "EXPLAIN SELECT 1 as test_col";
+        let plans = handler.prepare_explain(query).await.unwrap();
+
+        // Create the TicketStatementData that would be sent to do_get_statement
+        let tsd = create_explain_ticket_statement_data(plans);
+
+        // Test the execution
+        let result = handler.handle_explain_statement_execution(tsd, "test_remote").await;
+        assert!(result.is_ok());
+
+        let response = result.unwrap();
+        let stream = response.into_inner();
+
+        // Use the shared verification function
+        verify_explain_stream_results(stream).await;
+    }
+
+    #[tokio::test]
+    async fn test_handle_explain_statement_execution_missing_explain_data() {
+        let handler = create_test_handler();
+
+        // Create TicketStatementData without explain_data (should fail)
+        let tsd = TicketStatementData {
+            query_id: "test_query".to_string(),
+            stage_id: 0,
+            stage_addrs: None,
+            schema: None,
+            explain_data: None,
+        };
+
+        let result = handler.handle_explain_statement_execution(tsd, "test_remote").await;
+        assert!(result.is_err());
+
+        if let Err(error) = result {
+            assert_eq!(error.code(), tonic::Code::Internal);
+            assert!(error.message().contains("No explain_data in TicketStatementData"));
+        }
+    }
+
+    #[tokio::test]
+    async fn test_get_flight_info_statement_explain() {
+
+        let handler = create_test_handler();
+
+        // Test EXPLAIN query
+        let command = CommandStatementQuery {
+            query: "EXPLAIN SELECT 1 as test_col".to_string(),
+            transaction_id: None,
+        };
+
+        let request = Request::new(FlightDescriptor::new_cmd(vec![]));
+        let result = handler.get_flight_info_statement(command, request).await;
+
+        assert!(result.is_ok());
+        let response = result.unwrap();
+        let flight_info = response.into_inner();
+
+        // Verify FlightInfo structure
+        assert!(!flight_info.schema.is_empty());
+        assert_eq!(flight_info.endpoint.len(), 1);
+        assert!(flight_info.endpoint[0].ticket.is_some());
+
+        // Verify that the ticket has content (encoded TicketStatementData)
+        let ticket = flight_info.endpoint[0].ticket.as_ref().unwrap();
+        assert!(!ticket.ticket.is_empty());
+
+        println!("✓ FlightInfo created successfully with {} schema bytes and ticket with {} bytes",
+            flight_info.schema.len(), ticket.ticket.len());
+    }
+
+    #[tokio::test]
+    async fn test_do_get_statement_explain() {
+
+        let handler = create_test_handler();
+
+        // First prepare an EXPLAIN query to get proper ticket data
+        let query = "EXPLAIN SELECT 1 as test_col";
+        let plans = handler.prepare_explain(query).await.unwrap();
+
+        let tsd = create_explain_ticket_statement_data(plans);
+
+        // Create the ticket
+        let ticket_query = TicketStatementQuery {
+            statement_handle: tsd.encode_to_vec().into(),
+        };
+
+        let request = Request::new(Ticket::new(vec![]));
+        let result = handler.do_get_statement(ticket_query, request).await;
+
+        assert!(result.is_ok());
+        let response = result.unwrap();
+        let stream = response.into_inner();
+
+        // Use the shared verification function
+        verify_explain_stream_results(stream).await;
+    }
+
+    #[tokio::test]
+    async fn test_compare_explain_flight_info_responses() {
+        let handler = create_test_handler();
+        let query = "EXPLAIN SELECT 1 as test_col";
+
+        // Get FlightInfo from handle_explain_request
+        let result1 = handler.handle_explain_request(query).await.unwrap();
+        let flight_info1 = result1.into_inner();
+
+        // Get FlightInfo from get_flight_info_statement
+        let command = CommandStatementQuery {
+            query: query.to_string(),
+            transaction_id: None,
+        };
+        let request = Request::new(FlightDescriptor::new_cmd(vec![]));
+        let result2 = handler.get_flight_info_statement(command, request).await.unwrap();
+        let flight_info2 = result2.into_inner();
+
+        // Compare FlightInfo responses (structure should be identical)
+        assert_eq!(flight_info1.schema.len(), flight_info2.schema.len()); // Same schema size
+        assert_eq!(flight_info1.endpoint.len(), flight_info2.endpoint.len()); // Same number of endpoints
+        assert_eq!(flight_info1.endpoint.len(), 1); // Both should have exactly one endpoint
+
+        // Both should have tickets with content
+        let ticket1 = flight_info1.endpoint[0].ticket.as_ref().unwrap();
+        let ticket2 = flight_info2.endpoint[0].ticket.as_ref().unwrap();
+        assert!(!ticket1.ticket.is_empty());
+        assert!(!ticket2.ticket.is_empty());
+
+        println!("✓ Both tests produce FlightInfo with identical structure:");
+        println!("  - Schema bytes: {} vs {}", flight_info1.schema.len(), flight_info2.schema.len());
+        println!("  - Endpoints: {} vs {}", flight_info1.endpoint.len(), flight_info2.endpoint.len());
+        println!("  - Ticket bytes: {} vs {}", ticket1.ticket.len(), ticket2.ticket.len());
+    }
 }