
Commit 1209188

committed
Make the number of workers flexible
1 parent 2099424 commit 1209188

File tree: 3 files changed (+155, -107 lines)

README.md

Lines changed: 2 additions & 2 deletions
@@ -157,14 +157,14 @@ cargo test -- --nocapture
 Run comprehensive TPC-H validation tests that compare distributed DataFusion against regular DataFusion. No prerequisites needed - the tests handle everything automatically!
 
 ```bash
-# Run all TPC-H validation tests (manual - excluded from cargo test for speed)
+# Run all TPC-H validation tests
 cargo test --test tpch_validation test_tpch_validation_all_queries -- --ignored --nocapture
 
 # Run single query test for debugging
 cargo test --test tpch_validation test_tpch_validation_single_query -- --ignored --nocapture
 ```
 
-**Note:** TPC-H validation tests are marked with `#[ignore]` to keep `cargo test` fast for development. Run them manually when needed for validation.
+**Note:** TPC-H validation tests are annotated with `#[ignore]` to avoid slowing down `cargo test` during development. They're included in the CI pipeline and can be run manually when needed.
 
 ## Usage
 
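For context on the note above (not part of this commit), a test gated behind `#[ignore]` typically looks like the following minimal sketch; the test name here is hypothetical. A plain `cargo test` skips it, while `cargo test -- --ignored` runs it, matching the commands shown in the README diff.

```rust
// Minimal sketch, hypothetical test name (not from this repository).
#[test]
#[ignore] // skipped by `cargo test`; run with `cargo test -- --ignored`
fn my_slow_validation_test() {
    // expensive TPC-H comparison work would go here
}
```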
tests/common/mod.rs

Lines changed: 147 additions & 102 deletions
@@ -19,8 +19,8 @@ use datafusion::prelude::*;
 
 /// Test configuration constants
 pub const PROXY_PORT: u16 = 40400;
-pub const WORKER_PORT_1: u16 = 40401;
-pub const WORKER_PORT_2: u16 = 40402;
+pub const FIRST_WORKER_PORT: u16 = 40401;
+pub const NUM_WORKERS: usize = 2;
 pub const TPCH_DATA_PATH: &str = "/tmp/tpch_s1";
 pub const QUERY_PATH: &str = "./tpch/queries";
 pub const FLOATING_POINT_TOLERANCE: f64 = 1e-5;
@@ -44,6 +44,16 @@ pub fn should_be_verbose(_query_name: &str) -> bool {
     VERBOSE_COMPARISON
 }
 
+/// Get worker port for a given worker index (0-based)
+pub fn get_worker_port(worker_index: usize) -> u16 {
+    FIRST_WORKER_PORT + worker_index as u16
+}
+
+/// Get all worker ports
+pub fn get_all_worker_ports() -> Vec<u16> {
+    (0..NUM_WORKERS).map(get_worker_port).collect()
+}
+
 /// Validation results
 #[derive(Debug)]
 pub struct ValidationResults {
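As a usage note (not part of the diff): with the defaults above, `NUM_WORKERS = 2` and `FIRST_WORKER_PORT = 40401`, the helpers yield consecutive ports. A small sanity-check sketch, not included in this commit, could sit next to them:

```rust
// Sketch of a sanity check (not in the commit) for the port helpers above.
#[test]
fn worker_ports_are_consecutive() {
    assert_eq!(get_worker_port(0), FIRST_WORKER_PORT);        // 40401
    assert_eq!(get_worker_port(1), FIRST_WORKER_PORT + 1);    // 40402
    assert_eq!(get_all_worker_ports(), vec![40401u16, 40402u16]);
}
```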
@@ -70,15 +80,19 @@ pub struct ComparisonResult {
 pub struct ClusterManager {
     pub proxy_process: Option<Child>,
     pub worker_processes: Vec<Child>,
+    pub worker_ports: Vec<u16>,
     pub is_running: Arc<AtomicBool>,
+    pub reusable_python_script: Option<String>,
 }
 
 impl ClusterManager {
     pub fn new() -> Self {
         Self {
             proxy_process: None,
             worker_processes: Vec::new(),
+            worker_ports: get_all_worker_ports(),
             is_running: Arc::new(AtomicBool::new(false)),
+            reusable_python_script: None,
         }
     }
 
@@ -163,6 +177,67 @@ impl ClusterManager {
         }
     }
 
+    /// Generate the reusable Python script for executing distributed queries
+    pub fn generate_reusable_python_script(&mut self) -> Result<(), Box<dyn std::error::Error>> {
+        let python_script = format!(
+            r#"
+import adbc_driver_flightsql.dbapi as dbapi
+import duckdb
+import sys
+import time
+
+if len(sys.argv) != 2:
+    print("Usage: python script.py <sql_file_path>", file=sys.stderr)
+    sys.exit(1)
+
+sql_file_path = sys.argv[1]
+
+try:
+    # Connect to the distributed cluster
+    conn = dbapi.connect("grpc://localhost:{}")
+    cur = conn.cursor()
+
+    # Read and execute the SQL query
+    with open(sql_file_path, 'r') as f:
+        sql = f.read()
+
+    start_time = time.time()
+    cur.execute(sql)
+    reader = cur.fetch_record_batch()
+
+    # Convert results to string using DuckDB for consistent formatting
+    results = duckdb.sql("select * from reader")
+
+    # Fetch all results before closing connections
+    all_rows = results.fetchall()
+
+    # Close cursor and connection properly
+    cur.close()
+    conn.close()
+
+    # Print results in a format that can be parsed
+    print("--- RESULTS START ---")
+    for row in all_rows:
+        print("|" + "|".join([str(cell) if cell is not None else "NULL" for cell in row]) + "|")
+    print("--- RESULTS END ---")
+
+except Exception as e:
+    print(f"Error executing distributed query: {{str(e)}}", file=sys.stderr)
+    sys.exit(1)
+"#,
+            self.get_proxy_address().split(':').last().unwrap()
+        );
+
+        let script_path = Self::write_temp_file(
+            "reusable_distributed_query.py",
+            &python_script,
+            "reusable Python script for distributed queries",
+        )?;
+
+        self.reusable_python_script = Some(script_path);
+        Ok(())
+    }
+
     /// Setup everything needed for tests
     pub async fn setup() -> Result<Self, Box<dyn std::error::Error>> {
         let mut cluster = ClusterManager::new();
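For orientation (not part of the commit): the generated script frames its output between `--- RESULTS START ---` and `--- RESULTS END ---` markers, one pipe-delimited row per line. A minimal Rust sketch of extracting the rows from that framing, with a hypothetical helper name, would look like this:

```rust
// Hypothetical helper, not from this commit: collects the pipe-delimited rows
// printed between the RESULTS START/END markers by the Python script above.
fn parse_marked_rows(output: &str) -> Vec<String> {
    output
        .lines()
        .skip_while(|l| l.trim() != "--- RESULTS START ---")
        .skip(1)
        .take_while(|l| l.trim() != "--- RESULTS END ---")
        .map(|l| l.trim().to_string())
        .collect()
}
```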
@@ -185,14 +260,18 @@ impl ClusterManager {
         // Step 6: Wait for cluster to be ready
         cluster.wait_for_cluster_ready().await?;
 
+        // Step 7: Generate reusable Python script
+        cluster.generate_reusable_python_script()?;
+
         Ok(cluster)
     }
 
     /// Kill any existing processes on our specific test ports only
     pub fn kill_existing_processes(&self) -> Result<(), Box<dyn std::error::Error>> {
         println!("🧹 Cleaning up existing processes on test ports...");
 
-        let ports = [PROXY_PORT, WORKER_PORT_1, WORKER_PORT_2];
+        let mut ports = vec![PROXY_PORT];
+        ports.extend(&self.worker_ports);
 
         for port in &ports {
             // Find and kill processes using lsof
@@ -370,34 +449,29 @@ impl ClusterManager {
         let tpch_views = "CREATE VIEW revenue0 (supplier_no, total_revenue) AS SELECT l_suppkey, sum(l_extendedprice * (1 - l_discount)) FROM lineitem WHERE l_shipdate >= date '1996-08-01' AND l_shipdate < date '1996-08-01' + interval '3' month GROUP BY l_suppkey";
 
         // Start workers first
-        println!(" Starting worker 1 on port {}...", WORKER_PORT_1);
-        let worker1 = Self::spawn_process(
-            binary_path_str,
-            &["--mode", "worker", "--port", &WORKER_PORT_1.to_string()],
-            &[("DFRAY_TABLES", &tpch_tables), ("DFRAY_VIEWS", &tpch_views)],
-            "start worker 1",
-        )?;
-
-        println!(" Starting worker 2 on port {}...", WORKER_PORT_2);
-        let worker2 = Self::spawn_process(
-            binary_path_str,
-            &["--mode", "worker", "--port", &WORKER_PORT_2.to_string()],
-            &[("DFRAY_TABLES", &tpch_tables), ("DFRAY_VIEWS", &tpch_views)],
-            "start worker 2",
-        )?;
-
-        self.worker_processes.push(worker1);
-        self.worker_processes.push(worker2);
+        for (i, &port) in self.worker_ports.iter().enumerate() {
+            println!(" Starting worker {} on port {}...", i + 1, port);
+            let worker = Self::spawn_process(
+                binary_path_str,
+                &["--mode", "worker", "--port", &port.to_string()],
+                &[("DFRAY_TABLES", &tpch_tables), ("DFRAY_VIEWS", &tpch_views)],
+                &format!("start worker {}", i + 1),
+            )?;
+            self.worker_processes.push(worker);
+        }
 
         // Give workers time to start
         thread::sleep(Duration::from_secs(WORKER_STARTUP_WAIT_SECONDS));
 
         // Start proxy
         println!(" Starting proxy on port {}...", PROXY_PORT);
-        let worker_addresses = format!(
-            "worker1/127.0.0.1:{},worker2/127.0.0.1:{}",
-            WORKER_PORT_1, WORKER_PORT_2
-        );
+        let worker_addresses = self
+            .worker_ports
+            .iter()
+            .enumerate()
+            .map(|(i, &port)| format!("worker{}/127.0.0.1:{}", i + 1, port))
+            .collect::<Vec<_>>()
+            .join(",");
         let proxy = Self::spawn_process(
             binary_path_str,
             &["--mode", "proxy", "--port", &PROXY_PORT.to_string()],
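Worth noting (not stated in the diff itself): with the default two workers, the iterator chain above produces the same address string the removed `format!` call hard-coded. A standalone sketch, not part of the commit, that reproduces it for ports 40401 and 40402:

```rust
// Standalone sketch (not in the commit): builds the proxy's worker address
// string for the default two-worker configuration and checks its shape.
fn main() {
    let worker_ports: Vec<u16> = vec![40401, 40402];
    let worker_addresses = worker_ports
        .iter()
        .enumerate()
        .map(|(i, &port)| format!("worker{}/127.0.0.1:{}", i + 1, port))
        .collect::<Vec<_>>()
        .join(",");
    assert_eq!(
        worker_addresses,
        "worker1/127.0.0.1:40401,worker2/127.0.0.1:40402"
    );
}
```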
@@ -533,7 +607,7 @@ impl ClusterManager {
 
     /// Check if we can connect to workers
     pub fn can_connect_to_workers(&self) -> bool {
-        for port in &[WORKER_PORT_1, WORKER_PORT_2] {
+        for &port in &self.worker_ports {
            if let Ok(stream) = std::net::TcpStream::connect(format!("127.0.0.1:{}", port)) {
                drop(stream);
                return true;
@@ -568,6 +642,11 @@ impl Drop for ClusterManager {
             self.is_running.store(false, Ordering::Relaxed);
             println!("✅ Cluster cleanup complete");
         }
+
+        // Clean up reusable Python script
+        if let Some(script_path) = &self.reusable_python_script {
+            Self::cleanup_temp_files(&[script_path]);
+        }
     }
 }
 
@@ -672,10 +751,12 @@ pub async fn execute_query_datafusion(
 ) -> Result<(Vec<RecordBatch>, Duration), Box<dyn std::error::Error>> {
     let start_time = Instant::now();
 
-    let df = ctx
-        .sql(sql)
-        .await
-        .map_err(|e| format!("DataFusion SQL parsing failed for {}: {}", query_name, e))?;
+    let df = ctx.sql(sql).await.map_err(|e| {
+        format!(
+            "DataFusion SQL parsing/planning failed for {}: {}",
+            query_name, e
+        )
+    })?;
 
     let batches = df.collect().await.map_err(|e| {
         format!(
@@ -703,65 +784,20 @@ pub async fn execute_query_distributed(
         &format!("SQL for {}", query_name),
     )?;
 
-    // Create a temporary Python script to execute the query
-    let python_script = format!(
-        r#"
-import adbc_driver_flightsql.dbapi as dbapi
-import duckdb
-import sys
-import time
-
-try:
-    # Connect to the distributed cluster
-    conn = dbapi.connect("grpc://localhost:{}")
-    cur = conn.cursor()
-
-    # Read and execute the SQL query
-    with open('{}', 'r') as f:
-        sql = f.read()
-
-    start_time = time.time()
-    cur.execute(sql)
-    reader = cur.fetch_record_batch()
-
-    # Convert results to string using DuckDB for consistent formatting
-    results = duckdb.sql("select * from reader")
-
-    # Fetch all results before closing connections
-    all_rows = results.fetchall()
-
-    # Close cursor and connection properly
-    cur.close()
-    conn.close()
-
-    # Print results in a format that can be parsed
-    print("--- RESULTS START ---")
-    for row in all_rows:
-        print("|" + "|".join([str(cell) if cell is not None else "NULL" for cell in row]) + "|")
-    print("--- RESULTS END ---")
-
-except Exception as e:
-    print(f"Error executing distributed query: {{str(e)}}", file=sys.stderr)
-    sys.exit(1)
-"#,
-        cluster.get_proxy_address().split(':').last().unwrap(),
-        temp_sql_file
-    );
-
-    let temp_python_script = ClusterManager::write_temp_file(
-        &format!("{}_query.py", query_name),
-        &python_script,
-        &format!("Python script for {}", query_name),
-    )?;
+    // Get the reusable Python script
+    let python_script_path = cluster
+        .reusable_python_script
+        .as_ref()
+        .ok_or("Reusable Python script not available")?;
 
-    // Execute using Python Flight SQL client
+    // Execute using Python Flight SQL client with SQL file as argument
     let output = ClusterManager::run_python_command(
-        &[&temp_python_script],
+        &[python_script_path, &temp_sql_file],
        &format!("execute distributed query for {}", query_name),
    )?;
 
-    // Clean up temp files
-    ClusterManager::cleanup_temp_files(&[&temp_sql_file, &temp_python_script]);
+    // Clean up only the SQL temp file (keep the reusable Python script)
+    ClusterManager::cleanup_temp_files(&[&temp_sql_file]);
 
     let result = String::from_utf8_lossy(&output.stdout).to_string();
     let execution_time = start_time.elapsed();
@@ -1145,31 +1181,40 @@ pub async fn execute_single_query_validation(
         }
     };
 
-    // Execute with DataFusion
-    let (datafusion_batches, datafusion_time) =
-        match execute_query_datafusion(ctx, &sql, query_name).await {
+    // Execute with distributed system
+    let (distributed_output, distributed_time) =
+        match execute_query_distributed(cluster, &sql, query_name).await {
            Ok(result) => result,
            Err(e) => {
-                return create_error_comparison_result(
-                    query_name,
-                    format!("DataFusion execution failed: {}", e),
-                    0,
-                    Duration::default(),
-                );
+                return ComparisonResult {
+                    query_name: query_name.to_string(),
+                    matches: false,
+                    row_count_datafusion: 0,
+                    row_count_distributed: 0,
+                    error_message: Some(format!("Distributed execution failed: {}", e)),
+                    execution_time_datafusion: Duration::default(),
+                    execution_time_distributed: Duration::default(),
+                };
            }
        };
 
-    // Execute with distributed system
-    let (distributed_output, distributed_time) =
-        match execute_query_distributed(cluster, &sql, query_name).await {
+    // Execute with DataFusion
+    let (datafusion_batches, datafusion_time) =
+        match execute_query_datafusion(ctx, &sql, query_name).await {
            Ok(result) => result,
            Err(e) => {
-                return create_error_comparison_result(
-                    query_name,
-                    format!("Distributed execution failed: {}", e),
-                    datafusion_batches.iter().map(|b| b.num_rows()).sum(),
-                    datafusion_time,
-                );
+                // Get estimated row count from distributed output
+                let distributed_row_count =
+                    distributed_output_to_sorted_strings(&distributed_output).len();
+                return ComparisonResult {
+                    query_name: query_name.to_string(),
+                    matches: false,
+                    row_count_datafusion: 0,
+                    row_count_distributed: distributed_row_count,
+                    error_message: Some(format!("DataFusion execution failed: {}", e)),
+                    execution_time_datafusion: Duration::default(),
+                    execution_time_distributed: distributed_time,
+                };
            }
        };
 