diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f93d292..6d5411e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -39,7 +39,7 @@ jobs: - uses: actions/checkout@v4 - uses: ./.github/actions/setup - run: cargo test --features tpch --test tpch_validation_test - + format-check: runs-on: ubuntu-latest steps: @@ -48,3 +48,4 @@ jobs: with: components: rustfmt - run: cargo fmt --all -- --check + diff --git a/Cargo.lock b/Cargo.lock index f71585e..3f642dc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -680,12 +680,52 @@ dependencies = [ "ansi_term", "atty", "bitflags 1.3.2", - "strsim", + "strsim 0.8.0", "textwrap", "unicode-width 0.1.14", "vec_map", ] +[[package]] +name = "clap" +version = "4.5.51" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c26d721170e0295f191a69bd9a1f93efcdb0aff38684b61ab5750468972e5f5" +dependencies = [ + "clap_builder", + "clap_derive", +] + +[[package]] +name = "clap_builder" +version = "4.5.51" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75835f0c7bf681bfd05abe44e965760fea999a5286c6eb2d59883634fd02011a" +dependencies = [ + "anstream", + "anstyle", + "clap_lex", + "strsim 0.11.1", +] + +[[package]] +name = "clap_derive" +version = "4.5.49" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a0b5487afeab2deb2ff4e03a807ad1a03ac532ff5a2cee5d86884440c7f7671" +dependencies = [ + "heck 0.5.0", + "proc-macro2", + "quote", + "syn 2.0.106", +] + +[[package]] +name = "clap_lex" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1d728cc89cf3aee9ff92b05e62b19ee65a02b5702cff7d5a377e32c6ae29d8d" + [[package]] name = "colorchoice" version = "1.0.4" @@ -765,6 +805,24 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crossbeam-channel" +version = "0.5.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" version = "0.8.21" @@ -1581,6 +1639,27 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "distributed-datafusion-controller" +version = "0.1.0" +dependencies = [ + "arrow-flight", + "async-trait", + "clap 4.5.51", + "datafusion", + "datafusion-distributed", + "futures", + "log", + "moka", + "serde", + "serde_yaml", + "tempfile", + "tokio", + "tonic", + "tower", + "url", +] + [[package]] name = "either" version = "1.15.0" @@ -2154,7 +2233,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4b0f83760fb341a774ed326568e19f5a863af4a952def8c39f9ab92fd95b88e5" dependencies = [ "equivalent", - "hashbrown 0.15.4", + "hashbrown 0.16.0", ] [[package]] @@ -2372,9 +2451,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.27" +version = "0.4.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" +checksum = "34080505efa8e45a4b816c349525ebe327ceaa8559756f0356cba97ef3bf7432" [[package]] name = "lz4_flex" @@ -2444,6 +2523,24 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "moka" +version = "0.12.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8261cd88c312e0004c1d51baad2980c66528dfdb2bee62003e643a4d8f86b077" +dependencies = [ + "crossbeam-channel", + "crossbeam-epoch", + "crossbeam-utils", + "equivalent", + "parking_lot", + "portable-atomic", + "rustc_version", + "smallvec", + "tagptr", + "uuid", +] + [[package]] name = "num" version = "0.4.3" @@ -3064,6 +3161,19 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_yaml" +version = "0.9.34+deprecated" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" +dependencies = [ + "indexmap", + "itoa", + "ryu", + "serde", + "unsafe-libyaml", +] + [[package]] name = "sha2" version = "0.10.9" @@ -3199,13 +3309,19 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a" +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + [[package]] name = "structopt" version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c6b5c64445ba8094a6ab0c3cd2ad323e07171012d9c98b0b15651daf1787a10" dependencies = [ - "clap", + "clap 2.34.0", "lazy_static", "structopt-derive", ] @@ -3287,6 +3403,12 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "tagptr" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" + [[package]] name = "tempfile" version = "3.20.0" @@ -3561,6 +3683,12 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4a1a07cc7db3810833284e8d372ccdc6da29741639ecc70c9ec107df0fa6154c" +[[package]] +name = "unsafe-libyaml" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861" + [[package]] name = "untrusted" version = "0.9.0" diff --git a/Cargo.toml b/Cargo.toml index 6b51920..22d6255 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,9 +1,16 @@ [workspace] -members = ["benchmarks"] - +members = ["benchmarks", "datafusion_distributed_controller"] [workspace.dependencies] +arrow-flight = "56.1.0" +async-trait = "0.1.88" datafusion = { version = "50.0.0", default-features = false } datafusion-proto = { version = "50.0.0" } +tokio = { version = "1.46.1", features = ["full"] } +tonic = { version = "0.13.1", features = ["transport"] } +# Updated to 0.13.1 to match arrow-flight 56.1.0 +tower = "0.5.2" +url = "2.5.4" +futures = "0.3.31" [package] name = "datafusion-distributed" @@ -11,20 +18,19 @@ version = "0.1.0" edition = "2024" [dependencies] +arrow-flight = { workspace = true } +async-trait = { workspace = true } chrono = { version = "0.4.42" } datafusion = { workspace = true } datafusion-proto = { workspace = true } -arrow-flight = "56.1.0" arrow-select = "56.1.0" -async-trait = "0.1.88" -tokio = { version = "1.46.1", features = ["full"] } -# Updated to 0.13.1 to match arrow-flight 56.1.0 -tonic = { version = "0.13.1", features = ["transport"] } -tower = "0.5.2" +tokio = { workspace = true } +tonic = { workspace = true } +tower = { workspace = true } +url = { workspace = true } http = "1.3.1" itertools = "0.14.0" futures = "0.3.31" -url = "2.5.4" uuid = "1.17.0" delegate = "0.13.4" dashmap = "6.1.0" diff --git a/datafusion_distributed_controller/Cargo.toml b/datafusion_distributed_controller/Cargo.toml new file mode 100644 index 0000000..a1a24d9 --- /dev/null +++ b/datafusion_distributed_controller/Cargo.toml @@ -0,0 +1,31 @@ +[package] +name = "distributed-datafusion-controller" +version = "0.1.0" +edition = "2024" + +[[bin]] +name = "ddf_test" +path = "src/main.rs" + +[[bin]] +name = "agent" +path = "src/bin/agent.rs" + +[dependencies] +clap = { version = "4.5.51", features = ["derive"] } +datafusion-distributed = { path = "..", features = ["integration"] } +moka = {version = "0.12.11", features = ["sync"]} +tokio = { workspace = true } +tonic = { workspace = true } +tower = { workspace = true } +url = { workspace = true } +futures = { workspace = true } +datafusion = { workspace = true } +arrow-flight = "56.1.0" +async-trait = "0.1.88" +serde = { version = "1.0", features = ["derive"] } +serde_yaml = "0.9" +tempfile = "3.8" +log = "0.4.28" + + diff --git a/datafusion_distributed_controller/src/bin/agent.rs b/datafusion_distributed_controller/src/bin/agent.rs new file mode 100644 index 0000000..735baa9 --- /dev/null +++ b/datafusion_distributed_controller/src/bin/agent.rs @@ -0,0 +1,447 @@ +use clap::Args as ClapArgs; +use clap::{Parser, Subcommand}; +use datafusion::common::{DataFusionError, Result, internal_datafusion_err}; +use datafusion::execution::{SessionState, SessionStateBuilder}; +use datafusion::physical_plan::collect; +use datafusion::prelude::{ParquetReadOptions, SessionContext}; +use datafusion_distributed::{ + ArrowFlightEndpoint, DistributedExt, DistributedPhysicalOptimizerRule, + DistributedSessionBuilder, DistributedSessionBuilderContext, display_plan_ascii, + explain_analyze, +}; +use distributed_datafusion_controller::channel_resolver::DistributedChannelResolver; +use log::{info, warn}; +use serde::{Deserialize, Serialize}; +use std::fmt::Display; +use std::fs; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::path::PathBuf; +use std::sync::Arc; +use tonic::async_trait; +use tonic::transport::Server; + +#[derive(Parser)] +#[command(name = "DataFusion Distributed Agent")] +#[command(about = "Runs Code on a Remote Machines")] +struct Args { + #[command(subcommand)] + command: Commands, +} + +#[derive(Debug, Serialize, Deserialize)] +/// YAML containing all the endpoints (IP + port pairs) of the datafusion distributed cluster. +struct ClusterConfig { + endpoints: Vec, +} + +const DEFAULT_CONFIG_FILE: &str = "/etc/datafusion-distributed/config.yaml"; +const DEFAULT_FIXTURES_DIR: &str = "/var/lib/datafusion-distributed/data"; + +#[derive(ClapArgs, Debug)] +struct CommonArgs { + #[arg(long, default_value = DEFAULT_CONFIG_FILE, help = "YAML config file containing cluster IPs")] + config_path: PathBuf, + + #[arg(long, default_value = DEFAULT_FIXTURES_DIR, help = "Directory containing data files")] + fixtures_dir: PathBuf, +} + +#[derive(ClapArgs, Debug)] +struct QueryArgs { + #[arg(long, help = "SQL Query")] + sql: String, + #[arg( + short, + long, + default_value = "2", + help = "Number of shuffle tasks to use" + )] + num_shuffle_tasks: usize, + #[arg( + short, + long, + default_value = "2", + help = "Number of coalesce tasks to use" + )] + num_coalesce_tasks: usize, +} + +#[derive(Subcommand)] +enum Commands { + #[command(about = "Start a datafusion distributed worker on the provided port")] + StartServer { + #[command(flatten)] + common: CommonArgs, + #[arg(short, long)] + port: u16, + }, + #[command(about = "Run a SQL query against distributed cluster")] + Query { + #[command(flatten)] + common: CommonArgs, + #[command(flatten)] + query: QueryArgs, + }, + #[command(about = "Show execution plan for a SQL query")] + Explain { + #[command(flatten)] + common: CommonArgs, + #[command(flatten)] + query: QueryArgs, + }, + #[command(about = "Show execution plan with metrics for a SQL query")] + ExplainAnalyze { + #[command(flatten)] + common: CommonArgs, + #[command(flatten)] + query: QueryArgs, + }, +} + +#[tokio::main] +async fn main() -> Result<()> { + let args = Args::parse(); + + match args.command { + Commands::StartServer { common, port } => start_server(port, common.config_path).await, + Commands::Query { common, query } => run_query(common, query).await.map(|_| ()), + Commands::Explain { common, query } => run_explain(common, query, false).await.map(|_| ()), + Commands::ExplainAnalyze { common, query } => { + run_explain(common, query, true).await.map(|_| ()) + } + } +} + +/// Starts an arrow flight endpoint server on the specified port. +async fn start_server(port: u16, config_path: PathBuf) -> Result<()> { + let endpoint = ArrowFlightEndpoint::try_new(SessionBuilder::new(config_path, 0, 0))?; + + info!("Starting ArrowFlightEndpoint server on port {}", port); + + Server::builder() + .add_service(endpoint.into_flight_server()) + .serve(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port)) + .await + .map_err(|e| internal_datafusion_err!("Error staring server: {e}"))?; + + Ok(()) +} + +/// SessionBuilder which creates a datafusion distributed session state. +struct SessionBuilder { + config_path: PathBuf, + num_shuffle_tasks: usize, + num_coalesce_tasks: usize, +} + +impl SessionBuilder { + fn new(config_path: PathBuf, num_shuffle_tasks: usize, num_coalesce_tasks: usize) -> Self { + Self { + config_path, + num_shuffle_tasks, + num_coalesce_tasks, + } + } +} + +#[async_trait] +impl DistributedSessionBuilder for SessionBuilder { + async fn build_session_state( + &self, + ctx: DistributedSessionBuilderContext, + ) -> Result { + let config_content = std::fs::read_to_string(self.config_path.clone()) + .map_err(|e| internal_datafusion_err!("Error reading config file: {e}"))?; + let config: ClusterConfig = serde_yaml::from_str(&config_content) + .map_err(|e| internal_datafusion_err!("Error parsing YAML config: {e}"))?; + + let channel_resolver = DistributedChannelResolver::try_new(config.endpoints).await?; + + Ok(SessionStateBuilder::new() + .with_runtime_env(ctx.runtime_env.clone()) + .with_default_features() + .with_distributed_channel_resolver(channel_resolver) + .with_distributed_network_shuffle_tasks(self.num_shuffle_tasks) + .with_distributed_network_coalesce_tasks(self.num_coalesce_tasks) + .with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule)) + .build()) + } +} + +/// Creates a [SessionContext] for distributed execution. +async fn new_session_context( + config_path: PathBuf, + fixtures_dir: PathBuf, + num_shuffle_tasks: usize, + num_coalesce_tasks: usize, +) -> Result { + let state = SessionBuilder::new(config_path, num_shuffle_tasks, num_coalesce_tasks) + .build_session_state(DistributedSessionBuilderContext::default()) + .await?; + + let ctx = SessionContext::new_with_state(state); + + register_tables_from_fixtures(&ctx, &fixtures_dir).await?; + + Ok(ctx) +} + +/// Run query command on the cluster denoted by the config file. +async fn run_query(common: CommonArgs, query: QueryArgs) -> Result { + let ctx = new_session_context( + common.config_path, + common.fixtures_dir, + query.num_shuffle_tasks, + query.num_coalesce_tasks, + ) + .await?; + + let df = ctx.sql(&query.sql).await?; + let batches = df.collect().await?; + + let results = + datafusion::common::arrow::util::pretty::pretty_format_batches(batches.as_slice()) + .map_err(|e| internal_datafusion_err!("Error formatting batch: {e}"))?; + + Ok(results) +} + +/// Run explain command on the cluster denoted by the config file. +async fn run_explain( + common: CommonArgs, + query: QueryArgs, + show_metrics: bool, +) -> Result { + let ctx = new_session_context( + common.config_path, + common.fixtures_dir, + query.num_shuffle_tasks, + query.num_coalesce_tasks, + ) + .await?; + + let df = ctx.sql(&query.sql).await?; + let physical_plan = df.create_physical_plan().await?; + + if show_metrics { + collect(physical_plan.clone(), ctx.task_ctx()).await?; + return explain_analyze(physical_plan); + } + + Ok(display_plan_ascii(physical_plan.as_ref(), false)) +} + +async fn register_tables_from_fixtures(ctx: &SessionContext, fixtures_dir: &PathBuf) -> Result<()> { + if !fixtures_dir.exists() { + println!( + "Warning: Fixtures directory {:?} does not exist", + fixtures_dir + ); + return Ok(()); + } + + let entries = fs::read_dir(fixtures_dir) + .map_err(|e| internal_datafusion_err!("Error reading fixtures directory: {e}"))?; + + for entry in entries { + let entry = + entry.map_err(|e| internal_datafusion_err!("Error reading directory entry: {e}"))?; + let path = entry.path(); + + if path.is_file() { + let file_name = path + .file_name() + .and_then(|n| n.to_str()) + .unwrap_or("unknown"); + let table_name = path + .file_stem() + .and_then(|n| n.to_str()) + .unwrap_or("unknown"); + + if let Some(extension) = path.extension().and_then(|ext| ext.to_str()) { + match extension.to_lowercase().as_str() { + "parquet" => { + info!("Registering parquet table '{}' from {:?}", table_name, path); + ctx.register_parquet( + table_name, + path.to_str().unwrap(), + ParquetReadOptions::default(), + ) + .await?; + } + "csv" => { + info!("Registering CSV table '{}' from {:?}", table_name, path); + ctx.register_csv( + table_name, + path.to_str().unwrap(), + datafusion::prelude::CsvReadOptions::default(), + ) + .await?; + } + _ => { + warn!("Skipping unsupported file type: {}", file_name); + } + } + } + } else if path.is_dir() { + // Handle directories as parquet datasets (like weather/) + let dir_name = path + .file_name() + .and_then(|n| n.to_str()) + .unwrap_or("unknown"); + info!( + "Registering parquet directory '{}' from {:?}", + dir_name, path + ); + ctx.register_parquet( + dir_name, + path.to_str().unwrap(), + ParquetReadOptions::default(), + ) + .await?; + } + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::io::Write; + use std::time::Duration; + use tempfile::NamedTempFile; + use tokio::task::JoinHandle; + + use datafusion_distributed::assert_snapshot; + + struct TestCluster { + server_handles: Vec>, + config_file: NamedTempFile, + testdata_path: PathBuf, + } + + async fn setup_test_cluster(base_port: u16) -> Result { + // Create YAML config file in temp directory + let mut temp_file = NamedTempFile::new().unwrap(); + let config = ClusterConfig { + endpoints: vec![ + format!("127.0.0.1:{}", base_port), + format!("127.0.0.1:{}", base_port + 1), + format!("127.0.0.1:{}", base_port + 2), + ], + }; + let yaml_content = serde_yaml::to_string(&config).unwrap(); + temp_file.write_all(yaml_content.as_bytes()).unwrap(); + temp_file.flush().unwrap(); + + // Start 3 servers on different ports + let ports = vec![base_port, base_port + 1, base_port + 2]; + let mut server_handles = Vec::new(); + + for port in &ports { + let config_path = temp_file.path().to_path_buf(); + let port = *port; + let handle = tokio::spawn(async move { + start_server(port, config_path) + .await + .expect("Server failed"); + }); + server_handles.push(handle); + } + + // Wait a bit for servers to start + tokio::time::sleep(Duration::from_millis(100)).await; + + let testdata_path = PathBuf::from("../testdata"); // Relative to the controller crate + + Ok(TestCluster { + server_handles, + config_file: temp_file, + testdata_path, + }) + } + + impl Drop for TestCluster { + fn drop(&mut self) { + // Abort server tasks when test cluster is dropped + for handle in &self.server_handles { + handle.abort(); + } + } + } + + #[tokio::test] + async fn test_sql_command() { + let cluster = setup_test_cluster(18080).await.unwrap(); + + let result = run_query( + CommonArgs { + config_path: cluster.config_file.path().to_path_buf(), + fixtures_dir: cluster.testdata_path.clone(), + }, + QueryArgs { + sql: "SELECT \"RainToday\", count(*) FROM weather GROUP BY \"RainToday\" ORDER BY count(*)" + .to_string(), + num_shuffle_tasks: 2, + num_coalesce_tasks: 2, + }, + ) + .await.unwrap(); + + assert_snapshot!(result.to_string(), @r" + +-----------+----------+ + | RainToday | count(*) | + +-----------+----------+ + | Yes | 66 | + | No | 300 | + +-----------+----------+ + "); + } + + #[tokio::test] + async fn test_explain_command() { + let cluster = setup_test_cluster(18100).await.unwrap(); + + let result = run_explain( + CommonArgs { + config_path: cluster.config_file.path().to_path_buf(), + fixtures_dir: cluster.testdata_path.clone(), + }, + QueryArgs { + sql: "SELECT \"RainToday\", count(*) FROM weather GROUP BY \"RainToday\" ORDER BY count(*)" + .to_string(), + num_shuffle_tasks: 2, + num_coalesce_tasks: 2, + }, + false, // show_metrics = false + ) + .await.unwrap(); + + assert!(result.to_string().contains("DistributedExec")); + assert!(!result.to_string().contains("output_rows")); + } + + #[tokio::test] + async fn test_explain_analyze_command() { + let cluster = setup_test_cluster(18120).await.unwrap(); + + let result = run_explain( + CommonArgs { + config_path: cluster.config_file.path().to_path_buf(), + fixtures_dir: cluster.testdata_path.clone(), + }, + QueryArgs { + sql: "SELECT \"RainToday\", count(*) FROM weather GROUP BY \"RainToday\" ORDER BY count(*)" + .to_string(), + num_shuffle_tasks: 2, + num_coalesce_tasks: 2, + }, + true, // show_metrics = true + ) + .await.unwrap(); + + assert!(result.to_string().contains("DistributedExec")); + assert!(result.to_string().contains("output_rows=2")); + } +} diff --git a/datafusion_distributed_controller/src/channel_resolver.rs b/datafusion_distributed_controller/src/channel_resolver.rs new file mode 100644 index 0000000..78b76fc --- /dev/null +++ b/datafusion_distributed_controller/src/channel_resolver.rs @@ -0,0 +1,136 @@ +use arrow_flight::flight_service_client::FlightServiceClient; +use async_trait::async_trait; +use datafusion::common::internal_datafusion_err; +use datafusion::error::DataFusionError; +use datafusion_distributed::{BoxCloneSyncChannel, ChannelResolver}; +use futures::FutureExt; +use futures::future::{BoxFuture, Shared}; +use std::sync::Arc; +use tonic::transport::Channel; +use url::Url; + +const MAX_DECODING_MSG_SIZE: usize = 2 * 1024 * 1024 * 1024; + +/// [ChannelResolver] implementation that uses grpc for resolving nodes hosted at the +/// provided ip addresses. +#[derive(Clone)] +pub struct DistributedChannelResolver { + addresses: Vec, + channels: Arc>, +} + +type ChannelCacheValue = Shared< + BoxFuture<'static, Result, Arc>>, +>; + +impl DistributedChannelResolver { + /// Builds a [DistributedChannelResolver] from a list of endpoint URLs or IP:port pairs. + pub async fn try_new(endpoints: Vec) -> Result { + let mut addresses = Vec::new(); + + for endpoint in endpoints { + // Add http:// prefix if missing + let url_str = if endpoint.starts_with("http://") || endpoint.starts_with("https://") { + endpoint.clone() + } else { + format!("http://{}", endpoint) + }; + + let url = Url::parse(&url_str).map_err(|e| { + internal_datafusion_err!("Error parsing endpoint URL '{}': {}", endpoint, e) + })?; + addresses.push(url); + } + + Ok(Self { + addresses, + channels: Arc::new(moka::sync::Cache::new(1000)), + }) + } +} + +#[async_trait] +impl ChannelResolver for DistributedChannelResolver { + fn get_urls(&self) -> Result, DataFusionError> { + Ok(self.addresses.clone()) + } + + async fn get_flight_client_for_url( + &self, + url: &url::Url, + ) -> Result, DataFusionError> { + let url_clone = url.to_string(); + let result = self + .channels + .get_with(url.clone(), move || { + async move { + let unconnected_channel = Channel::from_shared(url_clone) + .map_err(|e| internal_datafusion_err!("Invalid URI: {e}"))?; + let channel = unconnected_channel.connect().await.map_err(|e| { + DataFusionError::Execution(format!("Error connecting to url: {e}")) + })?; + + // Apply layers to the channel. + let channel = tower::ServiceBuilder::new().service(channel); + + let client = FlightServiceClient::new(BoxCloneSyncChannel::new(channel)) + .max_decoding_message_size(MAX_DECODING_MSG_SIZE); + + Ok(client) + } + .boxed() + .shared() + }) + .await; + + match result { + Ok(result) => Ok(result), + Err(err) => { + self.channels.remove(url); + Err(DataFusionError::Shared(err)) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio; + + #[tokio::test] + async fn test_endpoint_parsing() { + let endpoints = vec![ + "127.0.0.1:8080".to_string(), + "192.168.1.1:9090".to_string(), + "http://example.com:8081".to_string(), + ]; + + let resolver = DistributedChannelResolver::try_new(endpoints) + .await + .unwrap(); + let urls = resolver.get_urls().unwrap(); + + assert_eq!(urls.len(), 3); + assert_eq!(urls[0].as_str(), "http://127.0.0.1:8080/"); + assert_eq!(urls[1].as_str(), "http://192.168.1.1:9090/"); + assert_eq!(urls[2].as_str(), "http://example.com:8081/"); + } + + #[tokio::test] + async fn test_ipv6_endpoint_parsing() { + let endpoints = vec![ + "[::1]:8080".to_string(), + "https://[2001:db8::1]:9090".to_string(), + ]; + + let resolver = DistributedChannelResolver::try_new(endpoints) + .await + .unwrap(); + let urls = resolver.get_urls().unwrap(); + + assert_eq!(urls.len(), 2); + assert_eq!(urls[0].as_str(), "http://[::1]:8080/"); + assert_eq!(urls[1].as_str(), "https://[2001:db8::1]:9090/"); + } +} diff --git a/datafusion_distributed_controller/src/lib.rs b/datafusion_distributed_controller/src/lib.rs new file mode 100644 index 0000000..9272837 --- /dev/null +++ b/datafusion_distributed_controller/src/lib.rs @@ -0,0 +1 @@ +pub mod channel_resolver; diff --git a/datafusion_distributed_controller/src/main.rs b/datafusion_distributed_controller/src/main.rs new file mode 100644 index 0000000..f328e4d --- /dev/null +++ b/datafusion_distributed_controller/src/main.rs @@ -0,0 +1 @@ +fn main() {}