diff --git a/Cargo.lock b/Cargo.lock index e3f8d344..cd0574d8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1754,6 +1754,7 @@ dependencies = [ "tokio-stream", "tonic", "tonic-build", + "tower 0.5.2", "url", "uuid", ] diff --git a/Cargo.toml b/Cargo.toml index 132d0926..b2c7a6b1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -55,6 +55,7 @@ kube = { version = "1.1", features = ["derive", "runtime"] } log = "0.4" rand = "0.8" uuid = { version = "1.6", features = ["v4"] } +tower = "0.5" object_store = { version = "0.12.0", features = [ "aws", diff --git a/src/bin/distributed-datafusion.rs b/src/bin/distributed-datafusion.rs index 738235b0..0f4a0bff 100644 --- a/src/bin/distributed-datafusion.rs +++ b/src/bin/distributed-datafusion.rs @@ -2,7 +2,8 @@ use anyhow::Result; use clap::Parser; use distributed_datafusion::{ friendly::new_friendly_name, proxy_service::DDProxyService, setup, - worker_service::DDWorkerService, + worker_service::DDWorkerService, worker_discovery::{EnvDiscovery, K8sDiscovery, WorkerDiscovery}, }; +use std::sync::Arc; #[derive(Parser)] @@ -30,10 +30,16 @@ async fn main() -> Result<()> { setup(); let args = Args::parse(); - + match args.mode.as_str() { "proxy" => { - let service = DDProxyService::new(new_friendly_name()?, args.port, None).await?; + // TODO: put the k8s-or-env decision behind a flag. WARNING: this kicks off discovery immediately, so workers must already be up + let discovery: Arc<dyn WorkerDiscovery> = if std::env::var("DD_WORKER_ADDRESSES").is_ok() { + Arc::new(EnvDiscovery::new().await?) + } else { + Arc::new(K8sDiscovery::new()?) + }; + let service = DDProxyService::new(new_friendly_name()?, args.port, discovery, None).await?; service.serve().await?; } "worker" => { diff --git a/src/lib.rs b/src/lib.rs index 76299b97..3ef6bbd6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -43,6 +43,9 @@ pub mod util; pub mod vocab; pub mod worker_discovery; pub mod worker_service; +pub mod transport_traits; +pub mod transport; +pub mod test_worker; #[cfg(not(target_env = "msvc"))] use tikv_jemallocator::Jemalloc; diff --git a/src/planning.rs b/src/planning.rs index 461435ec..9083d7bc 100644 --- a/src/planning.rs +++ b/src/planning.rs @@ -1,7 +1,5 @@ use std::{ - collections::HashMap, - env, - sync::{Arc, LazyLock}, + collections::HashMap, env, sync::{Arc, LazyLock} }; use anyhow::{anyhow, Context}; @@ -30,19 +28,10 @@ use itertools::Itertools; use prost::Message; use crate::{ - analyze::{DistributedAnalyzeExec, DistributedAnalyzeRootExec}, - isolator::PartitionIsolatorExec, - logging::{debug, error, info, trace}, - max_rows::MaxRowsExec, - physical::DDStageOptimizerRule, - result::{DDError, Result}, - stage::DDStageExec, - stage_reader::{DDStageReaderExec, QueryId}, - util::{display_plan_with_partition_counts, get_client, physical_plan_to_bytes, wait_for}, - vocab::{ + analyze::{DistributedAnalyzeExec, DistributedAnalyzeRootExec}, isolator::PartitionIsolatorExec, logging::{debug, error, info, trace}, max_rows::MaxRowsExec, physical::DDStageOptimizerRule, result::{DDError, Result}, stage::DDStageExec, stage_reader::{DDStageReaderExec, QueryId}, transport::WorkerTransport, util::{display_plan_with_partition_counts, physical_plan_to_bytes, wait_for}, vocab::{ Addrs, CtxAnnotatedOutputs, CtxHost, CtxPartitionGroup, CtxStageAddrs, CtxStageId, DDTask, Host, Hosts, PartitionAddrs, StageAddrs, - }, + } }; #[derive(Debug)] @@ -439,109 +428,84 @@ pub fn add_distributed_analyze( pub async fn distribute_stages( query_id: &str, stages: Vec<DDStage>, - worker_addrs: Vec<Host>, + workers: Vec<(Host, Arc<dyn WorkerTransport>)>, codec: &dyn PhysicalExtensionCodec, ) -> Result<(Addrs, Vec<DDTask>)> {
- // map of worker name to address - // FIXME: use types over tuples of strings, as we can accidently swap them and - // not know - - // a map of worker name to host - let mut workers: HashMap<String, Host> = worker_addrs - .iter() - .map(|host| (host.name.clone(), host.clone())) + // materialise a name-keyed map so we can remove "bad" workers on each retry + let mut valid_workers: HashMap<_, _> = workers + .into_iter() + .map(|(h, tx)| (h.name.clone(), (h, tx))) .collect(); for attempt in 0..3 { - if workers.is_empty() { + if valid_workers.is_empty() { return Err(anyhow!("No workers available to distribute stages").into()); } - // all stages to workers - let (task_datas, final_addrs) = - assign_to_workers(query_id, &stages, workers.values().collect(), codec)?; + let current: Vec<_> = valid_workers.values().cloned().collect(); + let (tasks, final_addrs, tx_host_pairs) = + assign_to_workers(query_id, &stages, current, codec)?; + + match try_distribute_tasks(&tasks, &tx_host_pairs).await { + Ok(_) => return Ok((final_addrs, tasks)), - // we retry this a few times to ensure that the workers are ready - // and can accept the stages - match try_distribute_tasks(&task_datas).await { - Ok(_) => return Ok((final_addrs, task_datas)), - Err(DDError::WorkerCommunicationError(bad_worker)) => { + // remove the poisoned worker and retry on the non-poisoned workers + Err(DDError::WorkerCommunicationError(bad_host)) => { error!( - "distribute stages for query {query_id} attempt {attempt} failed removing \ - worker {bad_worker}. Retrying..." + "distribute_stages: attempt {attempt} - \ + worker {} failed; will retry without it", + bad_host.name ); - // if we cannot communicate with a worker, we remove it from the list of workers - workers.remove(&bad_worker.name); + valid_workers.remove(&bad_host.name); } + + // any other error is terminal Err(e) => return Err(e), } - if attempt == 2 { - return Err( - anyhow!("Failed to distribute query {query_id} stages after 3 attempts").into(), - ); - } } - unreachable!() + + // all retry attempts exhausted without success + Err(anyhow!("Failed to distribute query {query_id} stages after 3 attempts").into()) } /// try to distribute the stages to the workers, if we cannot communicate with a /// worker return it as the element in the Err -async fn try_distribute_tasks(task_datas: &[DDTask]) -> Result<()> { - // we can use the stage data to distribute the stages to workers - for task_data in task_datas { +async fn try_distribute_tasks( + tasks: &[DDTask], + tx_host_pairs: &[(Arc<dyn WorkerTransport>, Host)], +) -> Result<()> { + for (task, (tx, host)) in tasks.iter().zip(tx_host_pairs) { trace!( - "Distributing Task: stage_id {}, pg: {:?} to worker: {:?}", - task_data.stage_id, - task_data.partition_group, - task_data.assigned_host + "Sending stage {} / pg {:?} to {}", + task.stage_id, + task.partition_group, + host ); - // populate its child stages - let mut stage_data = task_data.clone(); - stage_data.stage_addrs = Some(get_stage_addrs_from_tasks( - &stage_data.child_stage_ids, - task_datas, + // embed the StageAddrs of all children before shipping + let mut stage = task.clone(); + stage.stage_addrs = Some(get_stage_addrs_from_tasks( - &stage.child_stage_ids, + tasks, )?); - let host = stage_data - .assigned_host - .clone() - .context("Assigned host is missing for task data")?; - - let mut client = match get_client(&host) { - Ok(client) => client, - Err(e) => { - error!("Couldn't not communicate with worker {e:#?}"); - return Err(DDError::WorkerCommunicationError( - host.clone(), // here - )); - } - }; - - let mut buf = vec![]; - stage_data - .encode(&mut buf)
.context("Failed to encode stage data to buf")?; + let mut buf = Vec::new(); + stage.encode(&mut buf).map_err(anyhow::Error::from)?; let action = Action { - r#type: "add_plan".to_string(), + r#type: "add_plan".into(), body: buf.into(), }; - let mut response = client + // gRPC call, if it fails, transport poisons itself on failure and removes the address from the registry + let mut stream = tx .do_action(action) .await - .context("Failed to send action to worker")?; + .map_err(|_| DDError::WorkerCommunicationError(host.clone()))?; - // consume this empty response to ensure the action was successful - while let Some(_res) = response - .try_next() - .await - .context("error consuming do_action response")? - { - // we don't care about the response, just that it was successful - } - trace!("do action success for stage_id: {}", stage_data.stage_id); + // drain the (empty) response – ensures the worker actually accepted it + while stream.try_next().await? != None {} + + trace!("stage {} delivered to {}", stage.stage_id, host); } Ok(()) } @@ -552,40 +516,35 @@ async fn try_distribute_tasks(task_datas: &[DDTask]) -> Result<()> { fn assign_to_workers( query_id: &str, stages: &[DDStage], - worker_addrs: Vec<&Host>, + workers: Vec<(Host, Arc)>, codec: &dyn PhysicalExtensionCodec, -) -> Result<(Vec, Addrs)> { - let mut task_datas = vec![]; - let mut worker_idx = 0; +) -> Result<(Vec, Addrs, Vec<(Arc, Host)>)> { + let mut task_datas = Vec::new(); + let mut tx_host_pairs = Vec::new(); - trace!( - "assigning stages: {:?}", - stages - .iter() - .map(|s| format!("stage_id: {}, pgs:{:?}", s.stage_id, s.partition_groups)) - .join(",\n") - ); + // round-robin scheduler + let mut idx = 0; + let n_workers = workers.len(); - // keep track of which worker has the root of the plan tree (highest stage - // number) - let mut max_stage_id = -1; + // keep track of where the root of the plan will live (highest stage id) + let mut max_stage_id: i64 = -1; let mut final_addrs = Addrs::default(); for stage in stages { - for partition_group in stage.partition_groups.iter() { + for pg in &stage.partition_groups { let plan_bytes = physical_plan_to_bytes(stage.plan.clone(), codec)?; - let host = worker_addrs[worker_idx].clone(); - worker_idx = (worker_idx + 1) % worker_addrs.len(); + // pick next worker + let (host, tx) = workers[idx].clone(); + idx = (idx + 1) % n_workers; - if stage.stage_id as isize > max_stage_id { - // this wasn't the last stage - max_stage_id = stage.stage_id as isize; + // remember which host serves the final stage + if stage.stage_id as i64 > max_stage_id { + max_stage_id = stage.stage_id as i64; final_addrs.clear(); } - if stage.stage_id as isize == max_stage_id { - for part in partition_group.iter() { - // we are the final stage, so we will be the one to serve this partition + if stage.stage_id as i64 == max_stage_id { + for part in pg { final_addrs .entry(stage.stage_id) .or_default() @@ -595,22 +554,24 @@ fn assign_to_workers( } } - let task_data = DDTask { - query_id: query_id.to_string(), + task_datas.push(DDTask { + query_id: query_id.to_owned(), stage_id: stage.stage_id, plan_bytes, - partition_group: partition_group.to_vec(), - child_stage_ids: stage.child_stage_ids().unwrap_or_default().to_vec(), - stage_addrs: None, // will be calculated and filled in later + partition_group: pg.clone(), + child_stage_ids: stage.child_stage_ids().unwrap_or_default(), + stage_addrs: None, // filled in later num_output_partitions: stage.plan.output_partitioning().partition_count() as u64, full_partitions: 
stage.full_partitions, - assigned_host: Some(host), - }; - task_datas.push(task_data); + assigned_host: Some(host.clone()), + }); + + // keep the order **exactly** aligned with task_datas + tx_host_pairs.push((tx, host)); } } - Ok((task_datas, final_addrs)) + Ok((task_datas, final_addrs, tx_host_pairs)) } fn get_stage_addrs_from_tasks(target_stage_ids: &[u64], stages: &[DDTask]) -> Result { diff --git a/src/proxy_service.rs b/src/proxy_service.rs index 12ac59d0..19e057ba 100644 --- a/src/proxy_service.rs +++ b/src/proxy_service.rs @@ -50,7 +50,7 @@ use crate::{ stage_reader::DDStageReaderExec, util::{display_plan_with_partition_counts, get_addrs, start_up}, vocab::{Addrs, Host}, - worker_discovery::get_worker_addresses, + worker_discovery::{ WorkerDiscovery}, }; pub struct DDProxyHandler { @@ -58,23 +58,22 @@ pub struct DDProxyHandler { pub host: Host, pub planner: QueryPlanner, + pub discovery: Arc, /// Optional customizer for our context and proto serde pub customizer: Option>, } impl DDProxyHandler { - pub fn new(name: String, addr: String, customizer: Option>) -> Self { - // call this function to bootstrap the worker discovery mechanism - get_worker_addresses().expect("Could not get worker addresses upon startup"); - + pub fn new(name: String, addr: String, discovery: Arc, customizer: Option>) -> Self { let host = Host { name: name.clone(), addr: addr.clone(), }; Self { host: host.clone(), - planner: QueryPlanner::new(customizer.clone()), + planner: QueryPlanner::new(customizer.clone(), discovery.clone()), + discovery, customizer, } } @@ -292,6 +291,7 @@ impl DDProxyService { pub async fn new( name: String, port: usize, + discovery: Arc, ctx_customizer: Option>, ) -> Result { debug!("Creating DDProxyService!"); @@ -305,7 +305,7 @@ impl DDProxyService { info!("DDProxyService bound to {addr}"); - let handler = Arc::new(DDProxyHandler::new(name, addr.clone(), ctx_customizer)); + let handler = Arc::new(DDProxyHandler::new(name, addr.clone(), discovery, ctx_customizer)); Ok(Self { listener, diff --git a/src/query_planner.rs b/src/query_planner.rs index 46af7cad..e31af502 100644 --- a/src/query_planner.rs +++ b/src/query_planner.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use anyhow::{anyhow, Context as AnyhowContext}; +use anyhow::{anyhow}; use arrow::{compute::concat_batches, datatypes::SchemaRef}; use datafusion::{ logical_expr::LogicalPlan, @@ -21,8 +21,7 @@ use crate::{ }, record_batch_exec::RecordBatchExec, result::Result, - vocab::{Addrs, DDTask}, - worker_discovery::get_worker_addresses, + vocab::{Addrs, DDTask}, worker_discovery::{EnvDiscovery, K8sDiscovery, WorkerDiscovery}, }; /// Result of query preparation for execution of both query and its EXPLAIN @@ -61,16 +60,22 @@ impl std::fmt::Debug for QueryPlan { pub struct QueryPlanner { customizer: Option>, codec: Arc, + discovery: Arc, } impl Default for QueryPlanner { fn default() -> Self { - Self::new(None) + let discovery = if std::env::var("DD_WORKER_DEPLOYMENT").is_ok() { + Arc::new(K8sDiscovery::new().expect("k8s discovery")) as _ + } else { + Arc::new(EnvDiscovery::new().expect("env discovery")) as _ + }; + Self::new(discovery, None) } } impl QueryPlanner { - pub fn new(customizer: Option>) -> Self { + pub fn new(customizer: Option>, discovery: Arc) -> Self { let codec = Arc::new(DDCodec::new( customizer .clone() @@ -79,7 +84,7 @@ impl QueryPlanner { .unwrap(), )); - Self { customizer, codec } + Self { customizer, codec, discovery } } /// Common planning steps shared by both query and its EXPLAIN @@ -196,7 +201,7 @@ impl 
QueryPlanner { let (distributed_plan, distributed_stages) = execution_planning(physical_plan.clone(), 8192, Some(2)).await?; - let worker_addrs = get_worker_addresses()?; + let worker_addrs = self.discovery.workers().await?; // gather some information we need to send back such that // we can send a ticket to the client diff --git a/src/stage_reader.rs b/src/stage_reader.rs index 989ef524..1e3c790e 100644 --- a/src/stage_reader.rs +++ b/src/stage_reader.rs @@ -18,10 +18,7 @@ use futures::{stream::TryStreamExt, StreamExt}; use prost::Message; use crate::{ - logging::{error, trace}, - protobuf::{FlightDataMetadata, FlightTicketData}, - util::{get_client, CombinedRecordBatchStream}, - vocab::{CtxAnnotatedOutputs, CtxHost, CtxStageAddrs}, + logging::{error, trace}, protobuf::{FlightDataMetadata, FlightTicketData}, transport, util::CombinedRecordBatchStream, vocab::{CtxAnnotatedOutputs, CtxHost, CtxStageAddrs} }; pub(crate) struct QueryId(pub String); @@ -148,7 +145,7 @@ impl ExecutionPlan for DDStageReaderExec { )) })?? .iter() - .map(get_client) + .map(transport::get) .collect::>>()?; trace!("got clients. {name} num clients: {}", clients.len()); diff --git a/src/test_worker.rs b/src/test_worker.rs new file mode 100644 index 00000000..f0aedcd7 --- /dev/null +++ b/src/test_worker.rs @@ -0,0 +1,96 @@ +use arrow_flight::{flight_service_server::FlightService, FlightDescriptor, PollInfo}; +use futures::stream::BoxStream; +use tonic::{async_trait, Request, Response, Status}; +use futures::stream::{self}; + + +pub struct TestWorker; + +impl Default for TestWorker { + fn default() -> Self { + TestWorker + } +} + +#[async_trait] +impl FlightService for TestWorker { + type HandshakeStream = BoxStream<'static, Result>; + type ListFlightsStream = BoxStream<'static, Result>; + type DoGetStream = BoxStream<'static, Result>; + type DoPutStream = BoxStream<'static, Result>; + type DoExchangeStream = BoxStream<'static, Result>; + type DoActionStream = BoxStream<'static, Result>; + type ListActionsStream = BoxStream<'static, Result>; + + async fn handshake( + &self, + _request: Request>, + ) -> Result, Status> { + Err(Status::unimplemented("handshake not implemented")) + } + + async fn poll_flight_info( + &self, + _request: Request, + ) -> Result, Status> { + // For tests we don’t need the long-running behaviour yet. + // A simple “unimplemented” stub keeps the compiler happy. 
+ Err(Status::unimplemented("poll_flight_info not implemented")) + } + + async fn list_flights( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("list_flights not implemented")) + } + + async fn get_flight_info( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("get_flight_info not implemented")) + } + + async fn get_schema( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("get_schema not implemented")) + } + + async fn do_get( + &self, + _req: Request, + ) -> Result, Status> { + Err(Status::unimplemented("do_get not implemented")) + } + + async fn do_put( + &self, + _request: Request>, + ) -> Result, Status> { + Err(Status::unimplemented("do_put not implemented")) + } + + async fn do_exchange( + &self, + _request: Request>, + ) -> Result, Status> { + Err(Status::unimplemented("do_exchange not implemented")) + } + + async fn do_action( + &self, + _req: Request, + ) -> Result, Status> { + Ok(Response::new(Box::pin(stream::empty()))) + } + + async fn list_actions( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("list_actions not implemented")) + } +} \ No newline at end of file diff --git a/src/transport.rs b/src/transport.rs new file mode 100644 index 00000000..858080e6 --- /dev/null +++ b/src/transport.rs @@ -0,0 +1,50 @@ + +use std::sync::Arc; + +use async_trait::async_trait; +use arrow_flight::{decode::FlightRecordBatchStream, Action, Ticket}; +use bytes::Bytes; +use futures::stream::BoxStream; +use crate::result::Result; +use std::collections::HashMap; +use parking_lot::RwLock; +use std::sync::OnceLock; + +use crate::vocab::Host; + + +/// Global “cache” – **exactly one Arc per host** (backed by a Grpc/duplex channel etc.) +static REGISTRY: OnceLock>>> = OnceLock::new(); + +pub fn register(host: &Host, tx: Arc) { + REGISTRY.write().insert(host.addr.clone(), tx); +} + +pub fn get(host: &Host) -> anyhow::Result> { + REGISTRY + .read() + .get(&host.addr) + .cloned() + .ok_or_else(|| anyhow::anyhow!("no transport registered for {}", host)) +} + +/// Poison the entry when gRPC tells us the channel is broken. +pub fn poison(host: &Host) { + REGISTRY.write().remove(&host.addr); +} + +#[async_trait] +pub trait WorkerTransport: Send + Sync { + /// Execute a Flight `Action` (e.g. `add_plan`) and + /// stream back the raw payloads. + async fn do_action( + &self, + action: Action, + ) -> Result>>; + + /// Fetch a stream of record batches identified by `Ticket`. 
+ async fn do_get( + &self, + ticket: Ticket, + ) -> Result; +} \ No newline at end of file diff --git a/src/transport_traits.rs b/src/transport_traits.rs new file mode 100644 index 00000000..08c45d0f --- /dev/null +++ b/src/transport_traits.rs @@ -0,0 +1,124 @@ +use anyhow::Result; +use arrow_flight::decode::FlightRecordBatchStream; +use arrow_flight::flight_service_server::{FlightService, FlightServiceServer}; +use arrow_flight::{Action, FlightClient, Ticket}; +use async_trait::async_trait; +use bytes::Bytes; +use futures::stream::BoxStream; +use std::sync::Arc; +use std::time::Duration; +use tokio::io::{duplex, DuplexStream}; +use tokio_stream::wrappers::ReceiverStream; +use tonic::transport::{Channel, Endpoint, Server, Uri}; +use tonic::server::NamedService; +use tonic::{Status}; + +use crate::transport::{self, WorkerTransport}; + + +pub struct InMemTransport { + chan: Channel, +} + +impl InMemTransport { + /// Build a transport pair: `(client_half, server_half)` to feed into WorkerDiscovery + pub async fn pair( + svc: S, + ) -> Result<(Arc, ())> + where + S: FlightService + Send + 'static, + { + // 32 KiB buffer is usually enough; adjust if you stream large batches + let (cli_io, srv_io) = duplex(32 * 1024); + + // 1a. spin up the server on one half + tokio::spawn(run_flight_server(svc, srv_io)); + + // 1b. build a tonic `Channel` over the client half + let chan = Endpoint::try_from("http://inmem") + .unwrap() // URI never used on-wire + .connect_with_connector(tower::service_fn(move |_: Uri| { + // each connect() gets its own copy of the duplex half + async move { Ok::<_, std::io::Error>(cli_io) } + })) + .await?; + + Ok((Arc::new(Self { chan }), ())) + } +} + +// implement the WorkerTransport for in memory requests +#[async_trait] +impl WorkerTransport for InMemTransport { + async fn do_action( + &self, + action: Action, + ) -> Result>> { + let mut client = FlightClient::new(self.chan.clone()); + Ok(client.do_action(action).await?) + } + + async fn do_get( + &self, + ticket: Ticket, + ) -> Result { + let mut client = FlightClient::new(self.chan.clone()); + Ok(client.do_get(ticket).await?) + } +} + +// --------------------------------- +// Helper: run a tiny Flight server +// --------------------------------- +async fn run_flight_server(svc: S, io: DuplexStream) +where + S: Flight + NamedService + 'static, +{ + // tonic lets you serve **one** IO object with `Server::builder().serve_with_incoming_stream` + // We adapt our single DuplexStream into a stream-of-exactly-one. + let incoming = ReceiverStream::new(tokio_stream::once(Ok::<_, std::io::Error>(io))); + + Server::builder() + .add_service(FlightServiceServer::new(svc)) + .serve_with_incoming(incoming) + .await + .ok(); // ignore shutdown errors in tests +} + + +pub struct GrpcTransport { + chan: Channel, +} + +impl GrpcTransport { + pub async fn connect(addr: &str) -> Result> { + let chan = Channel::from_shared(format!("http://{addr}"))? 
+ .connect_timeout(Duration::from_secs(2)) + .connect() + .await?; + Ok(Arc::new(Self { chan })) + } +} + +#[async_trait] +impl WorkerTransport for GrpcTransport { + + async fn do_action( + &self, + action: Action, + ) -> Result>> { + let res = self.client.do_action(action).await; + // if the request fails & we fail to connect, we remove the transport from the registry + if res.is_err() { transport::poison(&self.host) } + res.map_err(Into::into) + } + + async fn do_get( + &self, + ticket: Ticket, + ) -> Result { + let res = self.client.do_get(ticket).await; + if res.is_err() { transport::poison(&self.host) } + res.map_err(Into::into) + } +} \ No newline at end of file diff --git a/src/util.rs b/src/util.rs index 0d9bd97b..87dc10fe 100644 --- a/src/util.rs +++ b/src/util.rs @@ -22,9 +22,8 @@ use arrow::{ MetadataVersion, }, }; -use arrow_flight::{decode::FlightRecordBatchStream, FlightClient, FlightData, Ticket}; +use arrow_flight::FlightData; use async_stream::stream; -use bytes::Bytes; use datafusion::{ common::{ internal_datafusion_err, @@ -39,27 +38,24 @@ use datafusion::{ prelude::SessionContext, }; use datafusion_proto::physical_plan::{AsExecutionPlan, PhysicalExtensionCodec}; -use futures::{stream::BoxStream, Stream, StreamExt}; +use futures::{Stream, StreamExt}; use object_store::{ aws::AmazonS3Builder, gcp::GoogleCloudStorageBuilder, http::HttpBuilder, ObjectStore, }; -use parking_lot::RwLock; use prost::Message; use tokio::{ macros::support::thread_rng_n, net::TcpListener, runtime::{Handle, Runtime}, }; -use tonic::transport::Channel; use url::Url; use crate::{ - codec::DDCodec, logging::{debug, error, trace}, protobuf::StageAddrs, result::Result, stage_reader::DDStageReaderExec, - vocab::{Addrs, Host}, + vocab::{Addrs} }; struct Spawner { @@ -327,129 +323,6 @@ pub fn reporting_stream( Box::pin(RecordBatchStreamAdapter::new(schema, out_stream)) as SendableRecordBatchStream } -pub struct WorkerClient { - /// the host we are connecting to - pub(crate) host: Host, - /// The flight client to the worker - inner: FlightClient, - /// the channel cache in the factory - channels: Arc>>, -} - -impl WorkerClient { - pub fn new( - host: Host, - inner: FlightClient, - channels: Arc>>, - ) -> Self { - Self { - host, - inner, - channels, - } - } - - pub async fn do_get( - &mut self, - ticket: Ticket, - ) -> arrow_flight::error::Result { - let stream = self.inner.do_get(ticket).await.inspect_err(|e| { - error!( - "Error in do_get for worker {}: {e:?}. - Considering this channel poisoned and removing it from WorkerClientFactory cache", - self.host - ); - self.channels.write().remove(&self.host.addr); - })?; - - Ok(stream) - } - - pub async fn do_action( - &mut self, - action: arrow_flight::Action, - ) -> arrow_flight::error::Result>> { - let result = self.inner.do_action(action).await.inspect_err(|e| { - error!( - "Error in do_action for worker {}: {e:?}. 
- Considering this channel poisoned and removing it from WorkerClientFactory \ - cache", - self.host - ); - self.channels.write().remove(&self.host.addr); - })?; - - Ok(result) - } -} - -struct WorkerClientFactory { - channels: Arc>>, -} - -impl WorkerClientFactory { - fn new() -> Self { - Self { - channels: Arc::new(RwLock::new(HashMap::new())), - } - } - - pub fn get_client(&self, host: &Host) -> Result { - let url = format!("http://{}", host.addr); - - let maybe_chan = self.channels.read().get(&host.addr).cloned(); - let chan = match maybe_chan { - Some(chan) => { - debug!("WorkerFactory using cached channel for {host}"); - chan - } - None => { - let host_str = host.to_string(); - let fut = async move { - trace!("WorkerFactory connecting to {host_str}"); - Channel::from_shared(url.clone()) - .map_err(|e| internal_datafusion_err!("WorkerFactory invalid url {e:#?}"))? - // FIXME: update timeout value to not be a magic number - .connect_timeout(Duration::from_secs(2)) - .connect() - .await - .map_err(|e| { - internal_datafusion_err!("WorkerFactory cannot connect {e:#?}") - }) - }; - - let chan = wait_for(fut, "WorkerFactory::get_client").map_err(|e| { - internal_datafusion_err!( - "WorkerFactory Cannot wait for channel connect future {e:#?}" - ) - })??; - trace!("WorkerFactory connected to {host}"); - self.channels - .write() - .insert(host.addr.to_string(), chan.clone()); - - chan - } - }; - debug!("WorkerFactory have channel now for {host}"); - - let flight_client = FlightClient::new(chan); - debug!("WorkerFactory made flight client for {host}"); - Ok(WorkerClient::new( - host.clone(), - flight_client, - self.channels.clone(), - )) - } -} - -static FACTORY: OnceLock = OnceLock::new(); - -pub fn get_client(host: &Host) -> Result { - let factory = FACTORY.get_or_init(WorkerClientFactory::new); - factory.get_client(host) -} - /// Copied from datafusion_physical_plan::union as its useful and not public pub struct CombinedRecordBatchStream { /// Schema wrapped by Arc diff --git a/src/worker_discovery.rs b/src/worker_discovery.rs index 36a5dee2..2d8744b1 100644 --- a/src/worker_discovery.rs +++ b/src/worker_discovery.rs @@ -1,291 +1,182 @@ -use std::{ - collections::HashMap, - sync::{Arc, OnceLock}, - time::Duration, -}; +use std::{ env, sync::{Arc}}; use anyhow::{anyhow, Context}; -use arrow_flight::{Action, FlightClient}; use futures::{StreamExt, TryStreamExt}; use k8s_openapi::api::{apps::v1::Deployment, core::v1::Pod}; use kube::{ - api::{Api, ResourceExt, WatchEvent, WatchParams}, + api::{Api, ListParams, ResourceExt}, Client, }; -use parking_lot::RwLock; -use prost::Message; -use tonic::transport::Channel; +use tonic::{async_trait}; use crate::{ - logging::{debug, error, trace}, - result::Result, - vocab::Host, + logging::trace, result::Result, transport::{self, WorkerTransport}, transport_traits::{GrpcTransport, InMemTransport}, vocab::Host }; +use crate::test_worker::TestWorker; -static WORKER_DISCOVERY: OnceLock> = OnceLock::new(); - -pub fn get_worker_addresses() -> Result> { - match WORKER_DISCOVERY.get_or_init(WorkerDiscovery::new) { - Ok(wd) => { - let worker_addrs = wd.get_addresses(); - debug!( - "Worker addresses found:\n{}", - worker_addrs - .iter() - .map(|host| format!("{host}")) - .collect::>() - .join("\n") - ); - Ok(worker_addrs) - } - Err(e) => Err(anyhow!("Failed to initialize WorkerDiscovery: {}", e).into()), - } +#[async_trait] +pub trait WorkerDiscovery: Send + Sync { + async fn workers( + &self, + ) -> Result)>>; } -struct WorkerDiscovery { - addresses: Arc>>, -} 
+pub struct TestDiscovery { workers: Vec<(Host, Arc)> } -impl WorkerDiscovery { - pub fn new() -> Result { - let wd = WorkerDiscovery { - addresses: Arc::new(RwLock::new(HashMap::new())), - }; - wd.start()?; - Ok(wd) - } +impl TestDiscovery { + /// Spin up `n` duplex-backed Flight servers and return their transports. + pub async fn new(n: usize) -> Result { + let mut workers = Vec::with_capacity(n); - fn get_addresses(&self) -> Vec { - let guard = self.addresses.read(); - guard.iter().map(|(_ip, host)| host.clone()).collect() - } + for i in 0..n { + // 1. Build the pair (client transport & background server task). + let (transport, _server_task) = InMemTransport::pair(TestWorker::default()).await?; - fn start(&self) -> Result<()> { - let worker_addrs_env = std::env::var("DD_WORKER_ADDRESSES"); - let worker_deployment_env = std::env::var("DD_WORKER_DEPLOYMENT"); - let worker_deployment_namespace_env = std::env::var("DD_WORKER_DEPLOYMENT_NAMESPACE"); + let transport: Arc = transport; // TODO: I dont like this upcast - if worker_addrs_env.is_ok() { - let addresses = self.addresses.clone(); - tokio::spawn(async move { - // if the env var is set, use it - set_worker_addresses_from_env(addresses, worker_addrs_env.unwrap().as_str()) - .await - .expect("Could not set worker addresses from env"); - }); - } else if worker_deployment_namespace_env.is_ok() && worker_deployment_env.is_ok() { - let addresses = self.addresses.clone(); - let deployment = worker_deployment_env.unwrap(); - let namespace = worker_deployment_namespace_env.unwrap(); - tokio::spawn(async move { - match watch_deployment_hosts_continuous(addresses, &deployment, &namespace).await { - Ok(_) => {} - Err(e) => error!("Error starting worker watcher: {:?}", e), - } - }); - } else { - // if neither env var is set, return an error - return Err(anyhow!( - "Either DD_WORKER_ADDRESSES or both DD_WORKER_DEPLOYMENT and \ - DD_WORKER_DEPLOYMENT_NAMESPACE must be set" - ) - .into()); + // 2. Give the worker a human-friendly name (handy for debug logs). + let host = Host { + name: format!("test-{i}"), + addr: format!("inmem://{i}"), + }; + + workers.push((host.clone(), transport.clone())); // TODO: I dont like this clone + transport::register(&host, transport); } - Ok(()) + + Ok(Self { workers }) } } -async fn set_worker_addresses_from_env( - addresses: Arc>>, - env_str: &str, -) -> Result<()> { - // get addresss from an env var where addresses are split by comans - // and in the form of name/address,name/address - - for addr in env_str.split(',') { - let host = get_worker_host(addr.to_string()) - .await - .context(format!("Failed to get worker host for address: {}", addr))?; - addresses.write().insert(addr.to_owned(), host); +#[async_trait] +impl WorkerDiscovery for TestDiscovery { + async fn workers(&self) -> Result)>> { + // This is trivial, the addresses are dummy + Ok(self.workers.clone()) } - Ok(()) } -/// Continuously watch for changes to pods in a Kubernetes deployment and call a -/// handler function whenever the list of hosts changes. 
-/// -/// # Arguments -/// * `deployment_name` - Name of the deployment -/// * `namespace` - Kubernetes namespace where the deployment is located -/// * `handler` - A function to call when the host list changes -/// -/// # Returns -/// This function runs indefinitely until an error occurs -/// -/// # Errors -/// Returns an error if there's an issue connecting to the Kubernetes API -/// or if the deployment or its pods cannot be found -async fn watch_deployment_hosts_continuous( - addresses: Arc>>, - deployment_name: &str, - namespace: &str, -) -> Result<()> { - debug!( - "Starting to watch deployment {} in namespace {}", - deployment_name, namespace - ); - // Initialize the Kubernetes client - let client = Client::try_default() - .await - .context("Failed to create Kubernetes client")?; - - // Access the Deployments API - let deployments: Api = Api::namespaced(client.clone(), namespace); - - // Get the specific deployment - let deployment = deployments - .get(deployment_name) - .await - .context(format!("Failed to get deployment {}", deployment_name))?; - - // Extract the selector labels from the deployment - let selector = deployment - .spec - .as_ref() - .and_then(|spec| spec.selector.match_labels.as_ref()) - .context("Deployment has no selector labels")?; - - // Convert selector to a string format for the label selector - let label_selector = selector - .iter() - .map(|(k, v)| format!("{}={}", k, v)) - .collect::>() - .join(","); - - // Access the Pods API - let pods: Api = Api::namespaced(client, namespace); - - debug!( - "Watching deployment {} in namespace {} with label selector: {}", - deployment_name, namespace, label_selector - ); - - let wp = WatchParams::default().labels(&label_selector); - - // Start watching for pod changes - let mut watcher = pods - .watch(&wp, "0") - .await - .context("could not build watcher")? - .boxed(); +pub struct EnvDiscovery { + cached: Vec<(Host, Arc)>, +} - while let Some(event_result) = watcher - .try_next() - .await - .context("could not get next event from watcher")? 
- { - match &event_result { - WatchEvent::Added(pod) | WatchEvent::Modified(pod) => { - trace!( - "Pod event: {:?}, added or modified: {:#?}", - event_result, - pod - ); - if let Some(Some(_ip)) = pod.status.as_ref().map(|s| s.pod_ip.as_ref()) { - let (pod_ip, host) = get_worker_info_from_pod(pod).await?; - debug!( - "Pod {} has IP address {}, host {}", - pod.name_any(), - pod_ip, - host - ); - addresses.write().insert(pod_ip, host); - } else { - trace!("Pod {} has no IP address, skipping", pod.name_any()); - } - } - WatchEvent::Deleted(pod) => { - debug!("Pod deleted: {}", pod.name_any()); - if let Some(status) = &pod.status { - if let Some(pod_ip) = &status.pod_ip { - if !pod_ip.is_empty() { - debug!("Removing pod IP: {}", pod_ip); - addresses.write().remove(pod_ip); - } - } - } - } - WatchEvent::Bookmark(_) => {} - WatchEvent::Error(e) => { - eprintln!("Watch error: {}", e); - } +impl EnvDiscovery { + pub async fn new() -> Result { + let raw = env::var("DD_WORKER_ADDRESSES") + .context("DD_WORKER_ADDRESSES must be set for EnvDiscovery")?; + + let mut cached = Vec::new(); + for token in raw.split(',').filter(|s| !s.is_empty()) { + let (name, addr) = match token.split_once('/') { + Some((n, a)) => (n.to_string(), a.to_string()), + None => (token.to_string(), token.to_string()), + }; + + let host = Host { name, addr: addr.clone() }; + let transport: Arc = + GrpcTransport::connect(&addr).await?; + transport::register(&host, transport.clone()); + cached.push((host, transport)); } + Ok(Self { cached }) } - - Ok(()) } -async fn get_worker_host(addr: String) -> Result { - let mut client = Channel::from_shared(format!("http://{addr}")) - .context("Failed to create channel")? - .connect_timeout(Duration::from_secs(2)) - .connect() - .await - .map(FlightClient::new) - .context("Failed to connect to worker")?; - - let action = Action { - r#type: "get_host".to_string(), - body: vec![].into(), - }; - - let mut response = client - .do_action(action) - .await - .context("Failed to send action to worker")?; +#[async_trait] +impl WorkerDiscovery for EnvDiscovery { + async fn workers(&self) -> Result)>> { + // Static list of addresses + Ok(self.cached.clone()) + } +} - Ok(response - .try_next() - .await - .transpose() - .context("error consuming do_action response")? - .map(Host::decode)? - .context("Failed to decode Host from worker response")?) 
+pub struct K8sDiscovery { + deployment: String, + namespace: String, + port_name: String, //defaults to first containerPort } -async fn get_worker_info_from_pod(pod: &Pod) -> Result<(String, Host)> { - let status = pod.status.as_ref().context("Pod has no status")?; - let pod_ip = status.pod_ip.as_ref().context("Pod has no IP address")?; +impl K8sDiscovery { + pub fn new() -> anyhow::Result { + let deployment = env::var("DD_WORKER_DEPLOYMENT") + .context("DD_WORKER_DEPLOYMENT must be set for K8sDiscovery")?; + let namespace = env::var("DD_WORKER_DEPLOYMENT_NAMESPACE") + .context("DD_WORKER_DEPLOYMENT_NAMESPACE must be set for K8sDiscovery")?; + + Ok(Self { + deployment, + namespace, + port_name: "dd-worker".into(), + }) + } - // filter on container name - let port = pod - .spec - .as_ref() - .and_then(|spec| { - spec.containers - .iter() - .find(|c| c.name == "dd-worker") - .and_then(|c| { - c.ports - .as_ref() - .and_then(|ports| ports.iter().next().map(|p| p.container_port)) - }) - }) - .ok_or_else(|| { - anyhow::anyhow!( - "No could not find container port for container named dd-worker found in pod {}", - pod.name_any() - ) - })?; + async fn list_pods(&self) -> Result> { + let client = Client::try_default().await.context("failed to create kube client")?; + let pods: Api = Api::namespaced(client, &self.namespace); - if pod_ip.is_empty() { - Err(anyhow::anyhow!("Pod {} has no IP address", pod.name_any()).into()) - } else { - let host_str = format!("{}:{}", pod_ip, port); - let host = get_worker_host(host_str.clone()).await.context(format!( - "Failed to get worker host for pod {}", - pod.name_any() - ))?; - Ok((pod_ip.to_owned(), host)) + // Re-use the selector of the deployment – robust & efficient + let dep_api: Api = Api::namespaced(pods.clone().into_client(), &self.namespace); + let dep = dep_api + .get(&self.deployment) + .await + .context("failed to get deployment")?; + let selector = dep.spec + .and_then(|s| s.selector.match_labels) + .ok_or_else(|| anyhow!("deployment has no selector"))?; + let selector_string = selector.into_iter() + .map(|(k,v)| format!("{k}={v}")) + .collect::>() + .join(","); + + let lp = ListParams::default().labels(&selector_string); + Ok(pods.list(&lp).await.context("failed to list pods")?.items) + } + + async fn pod_to_worker(&self, pod: &Pod) -> Result<(Host, Arc)> { + let ip = pod.status + .as_ref() + .and_then(|s| s.pod_ip.clone()) + .ok_or_else(|| anyhow!("pod {} has no IP yet", pod.name_any()))?; + + // find the port labelled `dd-worker` or just take the first one + let port = pod.spec + .as_ref() + .and_then(|s| { + s.containers.iter().flat_map(|c| c.ports.as_ref()) + .flatten() + .find(|p| p.name.as_deref() == Some(&self.port_name)) + .or_else(|| s.containers.iter() + .flat_map(|c| c.ports.as_ref()).flatten().next()) + .map(|p| p.container_port) + }) + .ok_or_else(|| anyhow!("pod {} has no container port", pod.name_any()))?; + + let addr = format!("{ip}:{port}"); + let host = Host { name: pod.name_any(), addr: addr.clone() }; + let tx = GrpcTransport::connect(&addr).await?; + transport::register(&host, tx.clone()); + Ok((host, tx)) } } + +#[async_trait] +impl WorkerDiscovery for K8sDiscovery { + async fn workers(&self) -> Result)>> { + let pods = self.list_pods().await?; + let mut out = Vec::with_capacity(pods.len()); + for pod in pods { + match self.pod_to_worker(&pod).await { + Ok(pair) => out.push(pair), + Err(e) => trace!("skip pod: {e:#}, pod={}", pod.name_any()), + } + } + if out.is_empty() { + Err(anyhow!( + "no ready pods found for deployment 
{} in {}", + self.deployment, self.namespace + )) + } else { + Ok(out) + } + } +} \ No newline at end of file