Skip to content

Commit 250d09b

Browse files
committed
Add more tests
1 parent 053bb71 commit 250d09b

File tree

1 file changed

+258
-0
lines changed

1 file changed

+258
-0
lines changed

src/proxy_service.rs

Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,21 @@ impl DFRayProxyService {
615615
mod tests {
616616
use super::*;
617617
use std::collections::HashMap;
618+
619+
// Test-specific imports
620+
use arrow_flight::{
621+
FlightDescriptor,
622+
Ticket,
623+
decode::FlightRecordBatchStream,
624+
sql::{CommandStatementQuery, TicketStatementQuery},
625+
};
626+
use futures::StreamExt;
627+
use prost::Message;
628+
use tonic::Request;
629+
630+
// //////////////////////////////////////////////////////////////
631+
// Test helper functions
632+
// //////////////////////////////////////////////////////////////
618633

619634
/// Create a test handler for testing - bypasses worker discovery initialization
620635
fn create_test_handler() -> DfRayProxyHandler {
@@ -648,6 +663,70 @@ mod tests {
648663
std::env::set_var("DFRAY_WORKER_ADDRESSES", &mock_env_value);
649664
}
650665

666+
/// Create TicketStatementData for EXPLAIN testing from QueryPlan
667+
fn create_explain_ticket_statement_data(plans: QueryPlan) -> TicketStatementData {
668+
let explain_data = plans.explain_data.map(|data| {
669+
DistributedExplainExecNode {
670+
schema: data.schema().as_ref().try_into().ok(),
671+
logical_plan: data.logical_plan().to_string(),
672+
physical_plan: data.physical_plan().to_string(),
673+
distributed_plan: data.distributed_plan().to_string(),
674+
distributed_stages: data.distributed_stages().to_string(),
675+
}
676+
});
677+
678+
TicketStatementData {
679+
query_id: plans.query_id,
680+
stage_id: plans.final_stage_id,
681+
stage_addrs: Some(plans.worker_addresses.into()),
682+
schema: Some(plans.schema.as_ref().try_into().unwrap()),
683+
explain_data,
684+
}
685+
}
686+
687+
/// Consume a DoGetStream and verify it contains expected EXPLAIN results
688+
async fn verify_explain_stream_results(stream: crate::flight::DoGetStream) {
689+
690+
// Convert the stream to a FlightRecordBatchStream and consume it
691+
// Map Status errors to FlightError to match the expected stream type
692+
let mapped_stream = stream.map(|result| {
693+
result.map_err(|status| arrow_flight::error::FlightError::from(status))
694+
});
695+
let mut flight_stream = FlightRecordBatchStream::new_from_flight_data(mapped_stream);
696+
697+
let mut batches = Vec::new();
698+
while let Some(batch_result) = flight_stream.next().await {
699+
let batch = batch_result.expect("Failed to get batch from stream");
700+
batches.push(batch);
701+
}
702+
703+
// Verify we got exactly one batch with EXPLAIN results
704+
assert_eq!(batches.len(), 1);
705+
let batch = &batches[0];
706+
707+
// Verify schema: should have 2 columns (plan_type, plan)
708+
assert_eq!(batch.num_columns(), 2);
709+
assert_eq!(batch.schema().field(0).name(), "plan_type");
710+
assert_eq!(batch.schema().field(1).name(), "plan");
711+
712+
// Verify we have 4 rows (logical_plan, physical_plan, distributed_plan, distributed_stages)
713+
assert_eq!(batch.num_rows(), 4);
714+
715+
// Verify the plan_type column contains the expected values
716+
let plan_type_column = batch.column(0).as_any().downcast_ref::<arrow::array::StringArray>().unwrap();
717+
assert_eq!(plan_type_column.value(0), "logical_plan");
718+
assert_eq!(plan_type_column.value(1), "physical_plan");
719+
assert_eq!(plan_type_column.value(2), "distributed_plan");
720+
assert_eq!(plan_type_column.value(3), "distributed_stages");
721+
722+
// Verify the plan column contains actual plan content
723+
let plan_column = batch.column(1).as_any().downcast_ref::<arrow::array::StringArray>().unwrap();
724+
assert!(plan_column.value(0).contains("Projection: Int64(1) AS test_col"));
725+
assert!(plan_column.value(1).contains("ProjectionExec"));
726+
assert!(plan_column.value(2).contains("RayStageExec"));
727+
assert!(plan_column.value(3).contains("Stage 0:"));
728+
}
729+
651730
// //////////////////////////////////////////////////////////////
652731
// Unit tests for helper functions
653732
// //////////////////////////////////////////////////////////////
@@ -846,4 +925,183 @@ mod tests {
846925
}
847926
}
848927
}
928+
929+
// //////////////////////////////////////////////////////////////
930+
// EXPLAIN Flow Integration Tests
931+
// //////////////////////////////////////////////////////////////
932+
933+
#[tokio::test]
934+
async fn test_handle_explain_request() {
935+
let handler = create_test_handler();
936+
let query = "EXPLAIN SELECT 1 as test_col, 'hello' as text_col";
937+
938+
let result = handler.handle_explain_request(query).await;
939+
assert!(result.is_ok());
940+
941+
let response = result.unwrap();
942+
let flight_info = response.into_inner();
943+
944+
// Verify FlightInfo structure
945+
assert!(!flight_info.schema.is_empty());
946+
assert_eq!(flight_info.endpoint.len(), 1);
947+
assert!(flight_info.endpoint[0].ticket.is_some());
948+
949+
// Verify that ticket has content (encoded TicketStatementData)
950+
let ticket = flight_info.endpoint[0].ticket.as_ref().unwrap();
951+
assert!(!ticket.ticket.is_empty());
952+
953+
println!("✓ FlightInfo created successfully with {} schema bytes and ticket with {} bytes",
954+
flight_info.schema.len(), ticket.ticket.len());
955+
}
956+
957+
#[tokio::test]
958+
async fn test_handle_explain_request_invalid_query() {
959+
let handler = create_test_handler();
960+
961+
// Test with EXPLAIN ANALYZE (should fail)
962+
let query = "EXPLAIN ANALYZE SELECT 1";
963+
let result = handler.handle_explain_request(query).await;
964+
assert!(result.is_err());
965+
966+
let error = result.unwrap_err();
967+
assert_eq!(error.code(), tonic::Code::Internal);
968+
assert!(error.message().contains("Could not prepare EXPLAIN query"));
969+
}
970+
971+
#[tokio::test]
972+
async fn test_handle_explain_statement_execution() {
973+
let handler = create_test_handler();
974+
975+
// First prepare an EXPLAIN query to get the ticket data structure
976+
let query = "EXPLAIN SELECT 1 as test_col";
977+
let plans = handler.prepare_explain(query).await.unwrap();
978+
979+
// Create the TicketStatementData that would be sent to do_get_statement
980+
let tsd = create_explain_ticket_statement_data(plans);
981+
982+
// Test the execution
983+
let result = handler.handle_explain_statement_execution(tsd, "test_remote").await;
984+
assert!(result.is_ok());
985+
986+
let response = result.unwrap();
987+
let stream = response.into_inner();
988+
989+
// Use shared verification function
990+
verify_explain_stream_results(stream).await;
991+
}
992+
993+
#[tokio::test]
994+
async fn test_handle_explain_statement_execution_missing_explain_data() {
995+
let handler = create_test_handler();
996+
997+
// Create TicketStatementData without explain_data (should fail)
998+
let tsd = TicketStatementData {
999+
query_id: "test_query".to_string(),
1000+
stage_id: 0,
1001+
stage_addrs: None,
1002+
schema: None,
1003+
explain_data: None,
1004+
};
1005+
1006+
let result = handler.handle_explain_statement_execution(tsd, "test_remote").await;
1007+
assert!(result.is_err());
1008+
1009+
if let Err(error) = result {
1010+
assert_eq!(error.code(), tonic::Code::Internal);
1011+
assert!(error.message().contains("No explain_data in TicketStatementData"));
1012+
}
1013+
}
1014+
1015+
#[tokio::test]
1016+
async fn test_get_flight_info_statement_explain() {
1017+
1018+
let handler = create_test_handler();
1019+
1020+
// Test EXPLAIN query
1021+
let command = CommandStatementQuery {
1022+
query: "EXPLAIN SELECT 1 as test_col".to_string(),
1023+
transaction_id: None,
1024+
};
1025+
1026+
let request = Request::new(FlightDescriptor::new_cmd(vec![]));
1027+
let result = handler.get_flight_info_statement(command, request).await;
1028+
1029+
assert!(result.is_ok());
1030+
let response = result.unwrap();
1031+
let flight_info = response.into_inner();
1032+
1033+
// Verify FlightInfo structure
1034+
assert!(!flight_info.schema.is_empty());
1035+
assert_eq!(flight_info.endpoint.len(), 1);
1036+
assert!(flight_info.endpoint[0].ticket.is_some());
1037+
1038+
// Verify that ticket has content (encoded TicketStatementData)
1039+
let ticket = flight_info.endpoint[0].ticket.as_ref().unwrap();
1040+
assert!(!ticket.ticket.is_empty());
1041+
1042+
println!("✓ FlightInfo created successfully with {} schema bytes and ticket with {} bytes",
1043+
flight_info.schema.len(), ticket.ticket.len());
1044+
}
1045+
1046+
#[tokio::test]
1047+
async fn test_do_get_statement_explain() {
1048+
1049+
let handler = create_test_handler();
1050+
1051+
// First prepare an EXPLAIN query to get proper ticket data
1052+
let query = "EXPLAIN SELECT 1 as test_col";
1053+
let plans = handler.prepare_explain(query).await.unwrap();
1054+
1055+
let tsd = create_explain_ticket_statement_data(plans);
1056+
1057+
// Create the ticket
1058+
let ticket_query = TicketStatementQuery {
1059+
statement_handle: tsd.encode_to_vec().into(),
1060+
};
1061+
1062+
let request = Request::new(Ticket::new(vec![]));
1063+
let result = handler.do_get_statement(ticket_query, request).await;
1064+
1065+
assert!(result.is_ok());
1066+
let response = result.unwrap();
1067+
let stream = response.into_inner();
1068+
1069+
// Use shared verification function
1070+
verify_explain_stream_results(stream).await;
1071+
}
1072+
1073+
#[tokio::test]
1074+
async fn test_compare_explain_flight_info_responses() {
1075+
let handler = create_test_handler();
1076+
let query = "EXPLAIN SELECT 1 as test_col";
1077+
1078+
// Get FlightInfo from handle_explain_request
1079+
let result1 = handler.handle_explain_request(query).await.unwrap();
1080+
let flight_info1 = result1.into_inner();
1081+
1082+
// Get FlightInfo from get_flight_info_statement
1083+
let command = CommandStatementQuery {
1084+
query: query.to_string(),
1085+
transaction_id: None,
1086+
};
1087+
let request = Request::new(FlightDescriptor::new_cmd(vec![]));
1088+
let result2 = handler.get_flight_info_statement(command, request).await.unwrap();
1089+
let flight_info2 = result2.into_inner();
1090+
1091+
// Compare FlightInfo responses (structure should be identical)
1092+
assert_eq!(flight_info1.schema.len(), flight_info2.schema.len()); // Same schema size
1093+
assert_eq!(flight_info1.endpoint.len(), flight_info2.endpoint.len()); // Same number of endpoints
1094+
assert_eq!(flight_info1.endpoint.len(), 1); // Both should have exactly one endpoint
1095+
1096+
// Both should have tickets with content
1097+
let ticket1 = flight_info1.endpoint[0].ticket.as_ref().unwrap();
1098+
let ticket2 = flight_info2.endpoint[0].ticket.as_ref().unwrap();
1099+
assert!(!ticket1.ticket.is_empty());
1100+
assert!(!ticket2.ticket.is_empty());
1101+
1102+
println!("✓ Both tests produce FlightInfo with identical structure:");
1103+
println!(" - Schema bytes: {} vs {}", flight_info1.schema.len(), flight_info2.schema.len());
1104+
println!(" - Endpoints: {} vs {}", flight_info1.endpoint.len(), flight_info2.endpoint.len());
1105+
println!(" - Ticket bytes: {} vs {}", ticket1.ticket.len(), ticket2.ticket.len());
1106+
}
8491107
}

0 commit comments

Comments
 (0)