|
| 1 | +use arrow::util::pretty::pretty_format_batches; |
| 2 | +use async_trait::async_trait; |
| 3 | +use dashmap::{DashMap, Entry}; |
| 4 | +use datafusion::common::DataFusionError; |
| 5 | +use datafusion::execution::SessionStateBuilder; |
| 6 | +use datafusion::physical_plan::displayable; |
| 7 | +use datafusion::prelude::{ParquetReadOptions, SessionContext}; |
| 8 | +use datafusion_distributed::{ |
| 9 | + BoxCloneSyncChannel, ChannelResolver, DistributedExt, DistributedPhysicalOptimizerRule, |
| 10 | +}; |
| 11 | +use futures::TryStreamExt; |
| 12 | +use std::error::Error; |
| 13 | +use std::sync::Arc; |
| 14 | +use structopt::StructOpt; |
| 15 | +use tonic::transport::Channel; |
| 16 | +use url::Url; |
| 17 | + |
/// Command-line arguments for the localhost Distributed DataFusion runner.
#[derive(StructOpt)]
#[structopt(name = "run", about = "A localhost Distributed DataFusion runner")]
struct Args {
    /// SQL query to execute against the registered tables (positional).
    #[structopt()]
    query: String,

    /// Ports of the worker processes forming the localhost cluster.
    // --cluster-ports 8080,8081,8082
    #[structopt(long = "cluster-ports", use_delimiter = true)]
    cluster_ports: Vec<u16>,

    /// When set, print the physical plan instead of executing the query.
    #[structopt(long)]
    explain: bool,
}
| 31 | + |
| 32 | +#[tokio::main] |
| 33 | +async fn main() -> Result<(), Box<dyn Error>> { |
| 34 | + let args = Args::from_args(); |
| 35 | + |
| 36 | + let localhost_resolver = LocalhostChannelResolver { |
| 37 | + ports: args.cluster_ports, |
| 38 | + cached: DashMap::new(), |
| 39 | + }; |
| 40 | + |
| 41 | + let state = SessionStateBuilder::new() |
| 42 | + .with_default_features() |
| 43 | + .with_distributed_channel_resolver(localhost_resolver) |
| 44 | + .with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule::new())) |
| 45 | + .build(); |
| 46 | + |
| 47 | + let ctx = SessionContext::from(state); |
| 48 | + |
| 49 | + ctx.register_parquet( |
| 50 | + "flights_1m", |
| 51 | + "testdata/flights-1m.parquet", |
| 52 | + ParquetReadOptions::default(), |
| 53 | + ) |
| 54 | + .await?; |
| 55 | + |
| 56 | + ctx.register_parquet( |
| 57 | + "weather", |
| 58 | + "testdata/weather.parquet", |
| 59 | + ParquetReadOptions::default(), |
| 60 | + ) |
| 61 | + .await?; |
| 62 | + |
| 63 | + let df = ctx.sql(&args.query).await?; |
| 64 | + if args.explain { |
| 65 | + let plan = df.create_physical_plan().await?; |
| 66 | + let display = displayable(plan.as_ref()).indent(true).to_string(); |
| 67 | + println!("{display}"); |
| 68 | + } else { |
| 69 | + let stream = df.execute_stream().await?; |
| 70 | + let batches = stream.try_collect::<Vec<_>>().await?; |
| 71 | + let formatted = pretty_format_batches(&batches)?; |
| 72 | + println!("{formatted}"); |
| 73 | + } |
| 74 | + Ok(()) |
| 75 | +} |
| 76 | + |
/// Channel resolver that maps every configured port to a worker on localhost.
#[derive(Clone)]
struct LocalhostChannelResolver {
    // Worker ports supplied via `--cluster-ports`.
    ports: Vec<u16>,
    // Channels already created, keyed by worker URL, so each worker gets at
    // most one underlying channel.
    cached: DashMap<Url, BoxCloneSyncChannel>,
}
| 82 | + |
| 83 | +#[async_trait] |
| 84 | +impl ChannelResolver for LocalhostChannelResolver { |
| 85 | + fn get_urls(&self) -> Result<Vec<Url>, DataFusionError> { |
| 86 | + Ok(self |
| 87 | + .ports |
| 88 | + .iter() |
| 89 | + .map(|port| Url::parse(&format!("http://localhost:{port}")).unwrap()) |
| 90 | + .collect()) |
| 91 | + } |
| 92 | + |
| 93 | + async fn get_channel_for_url(&self, url: &Url) -> Result<BoxCloneSyncChannel, DataFusionError> { |
| 94 | + match self.cached.entry(url.clone()) { |
| 95 | + Entry::Occupied(v) => Ok(v.get().clone()), |
| 96 | + Entry::Vacant(v) => { |
| 97 | + let channel = Channel::from_shared(url.to_string()) |
| 98 | + .unwrap() |
| 99 | + .connect_lazy(); |
| 100 | + let channel = BoxCloneSyncChannel::new(channel); |
| 101 | + v.insert(channel.clone()); |
| 102 | + Ok(channel) |
| 103 | + } |
| 104 | + } |
| 105 | + } |
| 106 | +} |
0 commit comments