|
| 1 | +use clap::{Parser, ValueEnum}; |
| 2 | +use datafusion::error::Result; |
| 3 | +use datafusion_distributed::test_utils::{ |
| 4 | + fuzz::{FuzzDB, SessionConfig}, |
| 5 | + tpcds::{ensure_tpcds_data, get_queries_dir, register_available_tpcds_tables}, |
| 6 | +}; |
| 7 | +use std::fs; |
| 8 | +use std::process; |
| 9 | +use std::time::Instant; |
| 10 | + |
// Command-line interface for the fuzz binary, parsed by clap's derive API.
//
// NOTE: the `///` doc comments on the fields below double as clap-generated
// `--help` text, i.e. they are user-visible runtime strings, not just
// documentation — do not edit them casually.
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
#[command(name = "fuzz")]
#[command(about = "Fuzz test distributed DataFusion with various workloads")]
struct Cli {
    /// Number of workers to use (must be >= 1)
    // Enforced by `validate_args` (a value of 0 is rejected there).
    #[arg(short, long, default_value_t = 3)]
    workers: usize,

    /// Workload to run
    // Parsed directly into the `Workload` enum via clap's ValueEnum.
    #[arg(long, value_enum, default_value_t = Workload::Tpcds)]
    workload: Workload,

    /// Scale factor for TPCDS data generation
    // Kept as a String (not f64) so it can be forwarded verbatim to
    // `ensure_tpcds_data`; numeric validity is checked in `validate_args`.
    #[arg(short, long, default_value = "0.01")]
    scale_factor: String,

    /// Generate data even if it already exists
    #[arg(long, default_value_t = false)]
    force_regenerate: bool,

    /// Run only specific queries (comma-separated list, e.g., "q1,q5,q10")
    // When None, every discovered query runs; names are matched
    // case-insensitively in `discover_tpcds_queries`.
    #[arg(long)]
    queries: Option<String>,

    /// Verbose output
    #[arg(short, long, default_value_t = false)]
    verbose: bool,
}
| 40 | + |
// Workloads the fuzzer can run. Deriving `ValueEnum` lets clap accept the
// variant name on the command line (e.g. `--workload tpcds`); the ordering
// derives are required by clap's value-enum machinery.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum, Debug)]
enum Workload {
    /// TPC-DS benchmark queries
    Tpcds,
}
| 46 | + |
| 47 | +#[tokio::main] |
| 48 | +async fn main() -> Result<()> { |
| 49 | + let cli = Cli::parse(); |
| 50 | + |
| 51 | + // Validate arguments |
| 52 | + if let Err(e) = validate_args(&cli) { |
| 53 | + eprintln!("Error: {}", e); |
| 54 | + process::exit(1); |
| 55 | + } |
| 56 | + |
| 57 | + println!("🚀 Starting DataFusion distributed fuzz testing"); |
| 58 | + println!(" Workers: {}", cli.workers); |
| 59 | + println!(" Workload: {:?}", cli.workload); |
| 60 | + println!(" Scale factor: {}", cli.scale_factor); |
| 61 | + println!(); |
| 62 | + |
| 63 | + match cli.workload { |
| 64 | + Workload::Tpcds => { |
| 65 | + if let Err(e) = run_tpcds_workload(cli).await { |
| 66 | + eprintln!("❌ Fuzz testing failed: {}", e); |
| 67 | + process::exit(1); |
| 68 | + } |
| 69 | + } |
| 70 | + } |
| 71 | + |
| 72 | + println!("✅ All fuzz tests passed!"); |
| 73 | + Ok(()) |
| 74 | +} |
| 75 | + |
| 76 | +/// Validate command line arguments |
| 77 | +fn validate_args(cli: &Cli) -> std::result::Result<(), String> { |
| 78 | + if cli.workers < 1 { |
| 79 | + return Err("Number of workers must be >= 1".to_string()); |
| 80 | + } |
| 81 | + |
| 82 | + // Validate scale factor is numeric (decimal allowed) |
| 83 | + if cli.scale_factor.parse::<f64>().is_err() { |
| 84 | + return Err(format!("Scale factor '{}' is not a valid number", cli.scale_factor)); |
| 85 | + } |
| 86 | + |
| 87 | + Ok(()) |
| 88 | +} |
| 89 | + |
/// Run the TPC-DS workload end-to-end:
/// 1. Ensure the TPC-DS data files exist (best-effort; a failure here only
///    warns, since queries against missing tables will fail individually).
/// 2. Build a distributed `FuzzDB` session and register the available tables.
/// 3. Discover the query files, run each one, and print a summary.
///
/// Returns an `Execution` error if the session cannot be created, no queries
/// are found, or any query fails.
async fn run_tpcds_workload(cli: Cli) -> Result<()> {
    println!("📊 Running TPC-DS workload...");

    // Ensure TPCDS data is available. Deliberately non-fatal: we continue
    // and let individual queries fail if their tables are missing.
    println!("🔧 Ensuring TPC-DS data is available (scale factor: {})...", cli.scale_factor);
    if let Err(e) = ensure_tpcds_data(Some(&cli.scale_factor), cli.force_regenerate) {
        eprintln!("⚠️ Warning: Failed to ensure TPCDS data: {}", e);
        println!(" Continuing anyway - will skip queries that require missing tables");
    }

    // Create FuzzDB with a session sized to the requested worker count.
    // The task/partition numbers are heuristics, not tuned values.
    println!("⚙️ Setting up distributed session with {} workers...", cli.workers);
    let config = SessionConfig {
        num_workers: cli.workers,
        tasks_per_file: 4, // Use reasonable defaults
        cardinality_task_count_factor: 4,
        target_partitions: cli.workers * 2,
    };

    // The registration callback appears to be synchronous, so the async
    // table registration is bridged with block_in_place + block_on.
    // NOTE(review): `block_in_place` panics on a current-thread Tokio
    // runtime — this relies on the default multi-threaded runtime; confirm
    // if the runtime configuration ever changes.
    let fuzz_db = match FuzzDB::new_with_config(config, |ctx| {
        let result = tokio::task::block_in_place(|| {
            tokio::runtime::Handle::current().block_on(async {
                let (registered_tables, missing_tables) = register_available_tpcds_tables(ctx, None).await?;

                // Missing tables are reported but not fatal (matches the
                // best-effort data-generation step above).
                if !missing_tables.is_empty() {
                    println!("Missing TPCDS tables: {:?}", missing_tables);
                }

                // Turbofish pins the error type so `?` above resolves.
                Ok::<Vec<String>, datafusion::error::DataFusionError>(registered_tables)
            })
        });
        result
    }).await {
        Ok(db) => {
            if db.registered_tables.is_empty() {
                println!("⚠️ Warning: No tables were registered. Some queries may fail.");
            } else {
                println!("✅ Successfully registered {} TPCDS tables", db.registered_tables.len());
                if cli.verbose {
                    println!(" Registered tables: {:?}", db.registered_tables);
                }
            }
            db
        }
        Err(e) => {
            return Err(datafusion::error::DataFusionError::Execution(format!(
                "Failed to create FuzzDB: {}", e
            )));
        }
    };

    // Discover queries (honors the --queries filter via `cli`).
    let queries = discover_tpcds_queries(&cli)?;
    println!("🔍 Found {} TPC-DS queries to execute", queries.len());

    if queries.is_empty() {
        return Err(datafusion::error::DataFusionError::Execution(
            "No queries found to execute".to_string()
        ));
    }

    // Run each query sequentially, tallying successes and failures so a
    // single failing query does not stop the rest of the run.
    let mut successful = 0;
    let mut failed = 0;
    let start_time = Instant::now();

    for (query_name, query_sql) in queries {
        if cli.verbose {
            println!();
            println!("🔄 Executing query: {}", query_name);
            println!("Executing fuzz query: {}", query_sql.trim());
            println!();
        }

        let query_start = Instant::now();
        match fuzz_db.run(&query_sql).await {
            Ok(results) => {
                let query_duration = query_start.elapsed();
                // Total row count across all returned record batches.
                let total_rows: usize = results.iter().map(|b| b.num_rows()).sum();
                successful += 1;
                println!("✅ {} completed successfully", query_name);
                if cli.verbose {
                    println!(" Duration: {:?}", query_duration);
                    println!(" Results: {} rows in {} batch(es)", total_rows, results.len());
                }
            }
            Err(e) => {
                failed += 1;
                println!("❌ {} failed: {}", query_name, e);
                if cli.verbose {
                    println!(" Error details: {:?}", e);
                }
            }
        }
    }

    // Final summary; any failure makes the whole workload fail so the
    // process exits non-zero (see main).
    let total_duration = start_time.elapsed();
    println!();
    println!("📈 Summary:");
    println!(" Total queries: {}", successful + failed);
    println!(" Successful: {}", successful);
    println!(" Failed: {}", failed);
    println!(" Total duration: {:?}", total_duration);

    if failed > 0 {
        return Err(datafusion::error::DataFusionError::Execution(format!(
            "{} out of {} queries failed", failed, successful + failed
        )));
    }

    Ok(())
}
| 203 | + |
| 204 | +/// Discover TPC-DS queries from the queries directory |
| 205 | +fn discover_tpcds_queries(cli: &Cli) -> Result<Vec<(String, String)>> { |
| 206 | + let queries_dir = get_queries_dir(); |
| 207 | + |
| 208 | + if !queries_dir.exists() { |
| 209 | + return Err(datafusion::error::DataFusionError::Execution(format!( |
| 210 | + "TPC-DS queries directory not found: {}", queries_dir.display() |
| 211 | + ))); |
| 212 | + } |
| 213 | + |
| 214 | + let mut queries = Vec::new(); |
| 215 | + let entries = fs::read_dir(&queries_dir).map_err(|e| { |
| 216 | + datafusion::error::DataFusionError::Execution(format!( |
| 217 | + "Failed to read queries directory {}: {}", queries_dir.display(), e |
| 218 | + )) |
| 219 | + })?; |
| 220 | + |
| 221 | + // Get list of specific queries if provided |
| 222 | + let specific_queries: Option<std::collections::HashSet<String>> = cli.queries.as_ref().map(|q| { |
| 223 | + q.split(',') |
| 224 | + .map(|s| s.trim().to_lowercase()) |
| 225 | + .collect() |
| 226 | + }); |
| 227 | + |
| 228 | + for entry in entries { |
| 229 | + let entry = entry.map_err(|e| { |
| 230 | + datafusion::error::DataFusionError::Execution(format!( |
| 231 | + "Failed to read directory entry: {}", e |
| 232 | + )) |
| 233 | + })?; |
| 234 | + |
| 235 | + let path = entry.path(); |
| 236 | + if path.extension().and_then(|s| s.to_str()) == Some("sql") { |
| 237 | + if let Some(file_stem) = path.file_stem().and_then(|s| s.to_str()) { |
| 238 | + let query_name = file_stem.to_lowercase(); |
| 239 | + |
| 240 | + // Filter specific queries if requested |
| 241 | + if let Some(ref specific) = specific_queries { |
| 242 | + if !specific.contains(&query_name) { |
| 243 | + continue; |
| 244 | + } |
| 245 | + } |
| 246 | + |
| 247 | + let query_sql = fs::read_to_string(&path).map_err(|e| { |
| 248 | + datafusion::error::DataFusionError::Execution(format!( |
| 249 | + "Failed to read query file {}: {}", path.display(), e |
| 250 | + )) |
| 251 | + })?; |
| 252 | + |
| 253 | + queries.push((query_name, query_sql)); |
| 254 | + } |
| 255 | + } |
| 256 | + } |
| 257 | + |
| 258 | + // Sort queries for consistent execution order |
| 259 | + queries.sort_by(|a, b| { |
| 260 | + // Extract query number for natural sorting (q1, q2, ..., q10, etc.) |
| 261 | + let extract_num = |name: &str| { |
| 262 | + name.strip_prefix('q') |
| 263 | + .and_then(|s| s.parse::<u32>().ok()) |
| 264 | + .unwrap_or(0) |
| 265 | + }; |
| 266 | + |
| 267 | + extract_num(&a.0).cmp(&extract_num(&b.0)) |
| 268 | + }); |
| 269 | + |
| 270 | + Ok(queries) |
| 271 | +} |
0 commit comments