Skip to content

Commit 374872e

Browse files
committed
datafusion-flight-sql-server: add a few tests
1 parent 87ab8be commit 374872e

File tree

2 files changed

+365
-0
lines changed

2 files changed

+365
-0
lines changed
Lines changed: 328 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,328 @@
1+
use std::sync::Arc;
2+
3+
use arrow::{
4+
array::{Int32Array, RecordBatch, StringArray},
5+
datatypes::{DataType, Field, Schema},
6+
};
7+
use arrow_flight::sql::client::FlightSqlServiceClient;
8+
use datafusion::{
9+
datasource::MemTable,
10+
execution::context::{SessionContext, SessionState},
11+
};
12+
use datafusion_flight_sql_server::service::FlightSqlService;
13+
use futures::TryStreamExt;
14+
use tokio::time::{sleep, Duration};
15+
use tonic::transport::{Channel, Endpoint};
16+
17+
fn create_test_session() -> SessionState {
18+
let ctx = SessionContext::new();
19+
20+
let schema = Arc::new(Schema::new(vec![
21+
Field::new("id", DataType::Int32, false),
22+
Field::new("name", DataType::Utf8, false),
23+
]));
24+
25+
let batch = RecordBatch::try_new(
26+
schema.clone(),
27+
vec![
28+
Arc::new(Int32Array::from(vec![1, 2, 3])),
29+
Arc::new(StringArray::from(vec!["Alice", "Bob", "Charlie"])),
30+
],
31+
)
32+
.unwrap();
33+
34+
let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap();
35+
ctx.register_table("users", Arc::new(table)).unwrap();
36+
37+
let orders_schema = Arc::new(Schema::new(vec![
38+
Field::new("order_id", DataType::Int32, false),
39+
Field::new("user_id", DataType::Int32, false),
40+
Field::new("amount", DataType::Int32, false),
41+
]));
42+
43+
let orders_batch = RecordBatch::try_new(
44+
orders_schema.clone(),
45+
vec![
46+
Arc::new(Int32Array::from(vec![100, 101, 102, 103])),
47+
Arc::new(Int32Array::from(vec![1, 2, 1, 3])),
48+
Arc::new(Int32Array::from(vec![50, 75, 100, 25])),
49+
],
50+
)
51+
.unwrap();
52+
53+
let orders_table = MemTable::try_new(orders_schema, vec![vec![orders_batch]]).unwrap();
54+
ctx.register_table("orders", Arc::new(orders_table))
55+
.unwrap();
56+
57+
ctx.state()
58+
}
59+
60+
async fn start_test_server(addr: String, state: SessionState) {
61+
tokio::spawn(async move {
62+
FlightSqlService::new(state)
63+
.serve(addr)
64+
.await
65+
.expect("Server should start successfully");
66+
});
67+
68+
sleep(Duration::from_millis(500)).await;
69+
}
70+
71+
async fn create_test_client(addr: &str) -> FlightSqlServiceClient<Channel> {
72+
let endpoint = Endpoint::new(addr.to_string()).expect("Valid endpoint");
73+
let channel = endpoint.connect().await.expect("Connection successful");
74+
FlightSqlServiceClient::new(channel)
75+
}
76+
77+
#[tokio::test]
78+
async fn test_basic_query_execution() {
79+
let addr = "0.0.0.0:50061";
80+
let state = create_test_session();
81+
start_test_server(addr.to_string(), state).await;
82+
83+
let mut client = create_test_client(&format!("http://{}", addr)).await;
84+
85+
let flight_info = client
86+
.execute("SELECT * FROM users".to_string(), None)
87+
.await
88+
.expect("Query should succeed");
89+
90+
let ticket = flight_info
91+
.endpoint
92+
.first()
93+
.expect("Should have endpoint")
94+
.ticket
95+
.clone()
96+
.expect("Should have ticket");
97+
98+
let mut stream = client.do_get(ticket).await.expect("do_get should succeed");
99+
100+
let mut batches = Vec::new();
101+
while let Some(batch) = stream.try_next().await.expect("Stream should work") {
102+
batches.push(batch);
103+
}
104+
105+
assert!(!batches.is_empty(), "Should have result batches");
106+
107+
let first_batch = &batches[0];
108+
assert_eq!(first_batch.num_columns(), 2);
109+
assert_eq!(first_batch.schema().field(0).name(), "id");
110+
assert_eq!(first_batch.schema().field(1).name(), "name");
111+
112+
let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
113+
assert_eq!(total_rows, 3);
114+
}
115+
116+
#[tokio::test]
117+
async fn test_query_with_filter() {
118+
let addr = "0.0.0.0:50062";
119+
let state = create_test_session();
120+
start_test_server(addr.to_string(), state).await;
121+
122+
let mut client = create_test_client(&format!("http://{}", addr)).await;
123+
124+
let flight_info = client
125+
.execute("SELECT name FROM users WHERE id > 1".to_string(), None)
126+
.await
127+
.expect("Query should succeed");
128+
129+
let ticket = flight_info
130+
.endpoint
131+
.first()
132+
.expect("Should have endpoint")
133+
.ticket
134+
.clone()
135+
.expect("Should have ticket");
136+
137+
let mut stream = client.do_get(ticket).await.expect("do_get should succeed");
138+
139+
let mut batches = Vec::new();
140+
while let Some(batch) = stream.try_next().await.expect("Stream should work") {
141+
batches.push(batch);
142+
}
143+
144+
let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
145+
assert_eq!(total_rows, 2, "Should have 2 rows after filter");
146+
}
147+
148+
#[tokio::test]
149+
async fn test_prepared_statement_creation() {
150+
let addr = "0.0.0.0:50063";
151+
let state = create_test_session();
152+
start_test_server(addr.to_string(), state).await;
153+
154+
let mut client = create_test_client(&format!("http://{}", addr)).await;
155+
156+
let query = "SELECT * FROM users WHERE id = $1";
157+
let prepared = client
158+
.prepare(query.to_string(), None)
159+
.await
160+
.expect("Prepare should succeed");
161+
162+
let dataset_schema = prepared
163+
.dataset_schema()
164+
.expect("Should have dataset schema");
165+
assert_eq!(dataset_schema.fields().len(), 2);
166+
167+
let parameter_schema = prepared
168+
.parameter_schema()
169+
.expect("Should have parameter schema");
170+
assert_eq!(parameter_schema.fields().len(), 1);
171+
}
172+
173+
#[tokio::test]
174+
async fn test_get_schemas() {
175+
let addr = "0.0.0.0:50064";
176+
let state = create_test_session();
177+
start_test_server(addr.to_string(), state).await;
178+
179+
let mut client = create_test_client(&format!("http://{}", addr)).await;
180+
181+
let flight_info = client
182+
.get_db_schemas(arrow_flight::sql::CommandGetDbSchemas {
183+
catalog: Some("datafusion".to_string()),
184+
db_schema_filter_pattern: None,
185+
})
186+
.await
187+
.expect("GetDbSchemas should succeed");
188+
189+
let ticket = flight_info
190+
.endpoint
191+
.first()
192+
.expect("Should have endpoint")
193+
.ticket
194+
.clone()
195+
.expect("Should have ticket");
196+
197+
let mut stream = client.do_get(ticket).await.expect("do_get should succeed");
198+
199+
let mut batches = Vec::new();
200+
while let Some(batch) = stream.try_next().await.expect("Stream should work") {
201+
batches.push(batch);
202+
}
203+
204+
assert!(!batches.is_empty(), "Should have schema results");
205+
}
206+
207+
#[tokio::test]
208+
async fn test_get_tables() {
209+
let addr = "0.0.0.0:50065";
210+
let state = create_test_session();
211+
start_test_server(addr.to_string(), state).await;
212+
213+
let mut client = create_test_client(&format!("http://{}", addr)).await;
214+
215+
let flight_info = client
216+
.get_tables(arrow_flight::sql::CommandGetTables {
217+
catalog: Some("datafusion".to_string()),
218+
db_schema_filter_pattern: None,
219+
table_name_filter_pattern: None,
220+
table_types: vec![],
221+
include_schema: true,
222+
})
223+
.await
224+
.expect("GetTables should succeed");
225+
226+
let ticket = flight_info
227+
.endpoint
228+
.first()
229+
.expect("Should have endpoint")
230+
.ticket
231+
.clone()
232+
.expect("Should have ticket");
233+
234+
let mut stream = client.do_get(ticket).await.expect("do_get should succeed");
235+
236+
let mut batches = Vec::new();
237+
while let Some(batch) = stream.try_next().await.expect("Stream should work") {
238+
batches.push(batch);
239+
}
240+
241+
assert!(!batches.is_empty(), "Should have table results");
242+
243+
let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
244+
assert!(total_rows > 0, "Should have at least one table");
245+
}
246+
247+
#[tokio::test]
248+
async fn test_invalid_query() {
249+
let addr = "0.0.0.0:50066";
250+
let state = create_test_session();
251+
start_test_server(addr.to_string(), state).await;
252+
253+
let mut client = create_test_client(&format!("http://{}", addr)).await;
254+
255+
let result = client
256+
.execute("SELECT * FROM nonexistent_table".to_string(), None)
257+
.await;
258+
259+
assert!(result.is_err(), "Query should fail for nonexistent table");
260+
}
261+
262+
#[tokio::test]
263+
async fn test_query_with_aggregation() {
264+
let addr = "0.0.0.0:50067";
265+
let state = create_test_session();
266+
start_test_server(addr.to_string(), state).await;
267+
268+
let mut client = create_test_client(&format!("http://{}", addr)).await;
269+
270+
let flight_info = client
271+
.execute("SELECT COUNT(*) as count FROM users".to_string(), None)
272+
.await
273+
.expect("Query should succeed");
274+
275+
let ticket = flight_info
276+
.endpoint
277+
.first()
278+
.expect("Should have endpoint")
279+
.ticket
280+
.clone()
281+
.expect("Should have ticket");
282+
283+
let mut stream = client.do_get(ticket).await.expect("do_get should succeed");
284+
285+
let mut batches = Vec::new();
286+
while let Some(batch) = stream.try_next().await.expect("Stream should work") {
287+
batches.push(batch);
288+
}
289+
290+
assert!(!batches.is_empty(), "Should have result batches");
291+
292+
let first_batch = &batches[0];
293+
assert_eq!(first_batch.num_columns(), 1);
294+
assert_eq!(first_batch.schema().field(0).name(), "count");
295+
}
296+
297+
#[tokio::test]
298+
async fn test_query_with_join() {
299+
let addr = "0.0.0.0:50068";
300+
let state = create_test_session();
301+
start_test_server(addr.to_string(), state).await;
302+
303+
let mut client = create_test_client(&format!("http://{}", addr)).await;
304+
305+
let flight_info = client
306+
.execute(
307+
r#"
308+
SELECT u.id, u.name, o.order_id
309+
FROM users u
310+
JOIN orders o
311+
ON u.id = o.user_id "#
312+
.to_string(),
313+
None,
314+
)
315+
.await
316+
.expect("Join query should succeed");
317+
318+
let ticket = flight_info.endpoint[0].ticket.clone().unwrap();
319+
let mut stream = client.do_get(ticket).await.expect("do_get should succeed");
320+
321+
let mut batches = Vec::new();
322+
while let Some(batch) = stream.try_next().await.expect("Stream should work") {
323+
batches.push(batch);
324+
}
325+
326+
let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
327+
assert_eq!(total_rows, 4, "Should have 4 rows from join");
328+
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
use datafusion_flight_sql_server::state::QueryHandle;
2+
3+
#[test]
4+
fn test_query_handle_with_complex_sql() {
5+
let query = r#"
6+
SELECT
7+
a.id,
8+
a.name,
9+
COUNT(b.order_id) as order_count
10+
FROM customers a
11+
LEFT JOIN orders b ON a.id = b.customer_id
12+
WHERE a.created_at > $1 AND a.status = $2
13+
GROUP BY a.id, a.name
14+
HAVING COUNT(b.order_id) > $3
15+
ORDER BY order_count DESC
16+
LIMIT 100
17+
"#
18+
.to_string();
19+
20+
let handle = QueryHandle::new(query.clone(), None);
21+
22+
let encoded = handle.clone().encode();
23+
let decoded = QueryHandle::try_decode(encoded).expect("Should decode");
24+
25+
assert_eq!(decoded.query(), query);
26+
}
27+
28+
#[test]
29+
fn test_query_handle_empty_query() {
30+
let query = String::new();
31+
let handle = QueryHandle::new(query.clone(), None);
32+
33+
let encoded = handle.clone().encode();
34+
let decoded = QueryHandle::try_decode(encoded).expect("Should decode");
35+
36+
assert_eq!(decoded.query(), "");
37+
}

0 commit comments

Comments
 (0)