use axum::{Json, Router, extract::Query, http::StatusCode, routing::get};
use ballista::datafusion::common::instant::Instant;
use ballista::datafusion::execution::SessionStateBuilder;
use ballista::datafusion::execution::runtime_env::RuntimeEnv;
use ballista::datafusion::physical_plan::displayable;
use ballista::datafusion::physical_plan::execute_stream;
use ballista::datafusion::prelude::SessionConfig;
use ballista::datafusion::prelude::SessionContext;
use ballista::prelude::*;
use futures::{StreamExt, TryFutureExt};
use log::{error, info};
use object_store::aws::AmazonS3Builder;
use serde::Serialize;
use std::collections::HashMap;
use std::error::Error;
use std::fmt::Display;
use std::sync::Arc;
use structopt::StructOpt;
use url::Url;

#[derive(Serialize)]
struct QueryResult {
    plan: String,
    count: usize,
}

#[derive(Debug, StructOpt, Clone)]
#[structopt(about = "worker spawn command")]
struct Cmd {
    /// The bucket name.
    #[structopt(long, default_value = "datafusion-distributed-benchmarks")]
    bucket: String,
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    env_logger::builder()
        .filter_level(log::LevelFilter::Info)
        .parse_default_env()
        .init();

    let cmd = Cmd::from_args();

    const LISTENER_ADDR: &str = "0.0.0.0:9002";

    info!("Starting HTTP listener on {LISTENER_ADDR}...");
    let listener = tokio::net::TcpListener::bind(LISTENER_ADDR).await?;

    // Register S3 object store
    let s3_url = Url::parse(&format!("s3://{}", cmd.bucket))?;

    info!("Building shared SessionContext for the whole lifetime of the HTTP listener...");
    let s3 = Arc::new(
        AmazonS3Builder::from_env()
            .with_bucket_name(s3_url.host().unwrap().to_string())
            .build()?,
    );
    let runtime_env = Arc::new(RuntimeEnv::default());
    runtime_env.register_object_store(&s3_url, s3);

    let config = SessionConfig::new_with_ballista().with_ballista_job_name("Benchmarks");

    let state = SessionStateBuilder::new()
        .with_config(config)
        .with_default_features()
        .with_runtime_env(Arc::clone(&runtime_env))
        .build();
    let ctx = SessionContext::remote_with_state("df://localhost:50050", state).await?;

    let http_server = axum::serve(
        listener,
        Router::new().route(
            "/",
            get(move |Query(params): Query<HashMap<String, String>>| {
                let ctx = ctx.clone();

                async move {
                    let sql = params.get("sql").ok_or(err("Missing 'sql' parameter"))?;

                    // Plan each ';'-separated statement in order; only the last
                    // non-empty statement's DataFrame is executed below.
                    let mut df_opt = None;
                    for sql in sql.split(";") {
                        if sql.trim().is_empty() {
                            continue;
                        }
                        let df = ctx.sql(sql).await.map_err(err)?;
                        df_opt = Some(df);
                    }
                    let Some(df) = df_opt else {
                        return Err(err("Empty 'sql' parameter"));
                    };

                    let start = Instant::now();

                    info!("Executing query...");
                    let physical = df.create_physical_plan().await.map_err(err)?;
                    let mut stream =
                        execute_stream(physical.clone(), ctx.task_ctx()).map_err(err)?;
                    let mut count = 0;
                    while let Some(batch) = stream.next().await {
                        count += batch.map_err(err)?.num_rows();
                        info!("Gathered {count} rows, query still in progress...")
                    }
                    let plan = displayable(physical.as_ref()).indent(true).to_string();
                    let elapsed = start.elapsed();
                    let ms = elapsed.as_secs_f64() * 1000.0;
                    info!("Returned {count} rows in {ms} ms");

                    Ok::<_, (StatusCode, String)>(Json(QueryResult { count, plan }))
                }
                .inspect_err(|(_, msg)| {
                    error!("Error executing query: {msg}");
                })
            }),
        ),
    );

    info!("Started HTTP listener server on {LISTENER_ADDR}");
    http_server.await?;
    Ok(())
}

fn err(s: impl Display) -> (StatusCode, String) {
    (StatusCode::INTERNAL_SERVER_ERROR, s.to_string())
}
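
For reference, a minimal client sketch (not part of the file above) showing how the endpoint could be called, assuming the reqwest and tokio crates are available as dependencies; the worker expects the SQL text in the `sql` query parameter and replies with the `{ "plan": ..., "count": ... }` JSON defined by `QueryResult`.

// Hypothetical client example; `reqwest` is an assumed dependency, not used by the worker itself.
use std::error::Error;

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    // The worker listens on 0.0.0.0:9002 and reads the statement from `?sql=...`.
    let body = reqwest::Client::new()
        .get("http://localhost:9002/")
        .query(&[("sql", "SELECT 1")])
        .send()
        .await?
        .text()
        .await?;
    // Prints the JSON response containing the physical plan and the row count.
    println!("{body}");
    Ok(())
}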