diff --git a/Cargo.lock b/Cargo.lock
index c0a40ea..f71585e 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1088,6 +1088,7 @@ version = "0.1.0"
 dependencies = [
  "arrow",
  "arrow-flight",
+ "arrow-select",
  "async-trait",
  "bytes",
  "chrono",
diff --git a/Cargo.toml b/Cargo.toml
index c60857d..6b51920 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -15,6 +15,7 @@ chrono = { version = "0.4.42" }
 datafusion = { workspace = true }
 datafusion-proto = { workspace = true }
 arrow-flight = "56.1.0"
+arrow-select = "56.1.0"
 async-trait = "0.1.88"
 tokio = { version = "1.46.1", features = ["full"] }
 # Updated to 0.13.1 to match arrow-flight 56.1.0
diff --git a/src/execution_plans/network_coalesce.rs b/src/execution_plans/network_coalesce.rs
index 9db7f59..02111c7 100644
--- a/src/execution_plans/network_coalesce.rs
+++ b/src/execution_plans/network_coalesce.rs
@@ -14,12 +14,11 @@ use arrow_flight::error::FlightError;
 use bytes::Bytes;
 use dashmap::DashMap;
 use datafusion::common::{exec_err, internal_err, plan_err};
-use datafusion::datasource::schema_adapter::DefaultSchemaAdapterFactory;
 use datafusion::error::DataFusionError;
 use datafusion::execution::{SendableRecordBatchStream, TaskContext};
 use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
 use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
-use futures::{StreamExt, TryFutureExt, TryStreamExt};
+use futures::{TryFutureExt, TryStreamExt};
 use http::Extensions;
 use prost::Message;
 use std::any::Any;
@@ -296,8 +295,6 @@ impl ExecutionPlan for NetworkCoalesceExec {
         };

         let metrics_collection_capture = self_ready.metrics_collection.clone();
-        let adapter = DefaultSchemaAdapterFactory::from_schema(self.schema());
-        let (mapper, _indices) = adapter.map_schema(&self.schema())?;
         let stream = async move {
             let mut client = channel_resolver.get_flight_client_for_url(&url).await?;
             let stream = client
@@ -312,12 +309,7 @@ impl ExecutionPlan for NetworkCoalesceExec {

             Ok(
                 FlightRecordBatchStream::new_from_flight_data(metrics_collecting_stream)
-                    .map_err(map_flight_to_datafusion_error)
-                    .map(move |batch| {
-                        let batch = batch?;
-
-                        mapper.map_batch(batch)
-                    }),
+                    .map_err(map_flight_to_datafusion_error),
             )
         }
         .try_flatten_stream();
diff --git a/src/execution_plans/network_shuffle.rs b/src/execution_plans/network_shuffle.rs
index 5769bfc..a4fcba8 100644
--- a/src/execution_plans/network_shuffle.rs
+++ b/src/execution_plans/network_shuffle.rs
@@ -14,7 +14,6 @@ use arrow_flight::error::FlightError;
 use bytes::Bytes;
 use dashmap::DashMap;
 use datafusion::common::{exec_err, internal_datafusion_err, plan_err};
-use datafusion::datasource::schema_adapter::DefaultSchemaAdapterFactory;
 use datafusion::error::DataFusionError;
 use datafusion::execution::{SendableRecordBatchStream, TaskContext};
 use datafusion::physical_expr::Partitioning;
@@ -318,12 +317,8 @@ impl ExecutionPlan for NetworkShuffleExec {
         let task_context = DistributedTaskContext::from_ctx(&context);
         let off = self_ready.properties.partitioning.partition_count() * task_context.task_index;

-        let adapter = DefaultSchemaAdapterFactory::from_schema(self.schema());
-        let (mapper, _indices) = adapter.map_schema(&self.schema())?;
-
         let stream = input_stage_tasks.into_iter().enumerate().map(|(i, task)| {
             let channel_resolver = Arc::clone(&channel_resolver);
-            let mapper = mapper.clone();

             let ticket = Request::from_parts(
                 MetadataMap::from_headers(context_headers.clone()),
@@ -364,12 +359,7 @@ impl ExecutionPlan for NetworkShuffleExec {

                 Ok(
                     FlightRecordBatchStream::new_from_flight_data(metrics_collecting_stream)
-                        .map_err(map_flight_to_datafusion_error)
-                        .map(move |batch| {
-                            let batch = batch?;
-
-                            mapper.map_batch(batch)
-                        }),
+                        .map_err(map_flight_to_datafusion_error),
                 )
             }
             .try_flatten_stream()
diff --git a/src/flight_service/do_get.rs b/src/flight_service/do_get.rs
index a60d43d..e51aecd 100644
--- a/src/flight_service/do_get.rs
+++ b/src/flight_service/do_get.rs
@@ -11,10 +11,12 @@ use crate::protobuf::{
 };
 use arrow_flight::FlightData;
 use arrow_flight::Ticket;
-use arrow_flight::encode::FlightDataEncoderBuilder;
+use arrow_flight::encode::{DictionaryHandling, FlightDataEncoderBuilder};
 use arrow_flight::error::FlightError;
 use arrow_flight::flight_service_server::FlightService;
+use arrow_select::dictionary::garbage_collect_any_dictionary;
 use bytes::Bytes;
+use datafusion::arrow::array::{Array, AsArray, RecordBatch};
 use datafusion::common::exec_datafusion_err;
 use datafusion::error::DataFusionError;
@@ -134,8 +136,22 @@
             .execute(doget.target_partition as usize, session_state.task_ctx())
             .map_err(|err| Status::internal(format!("Error executing stage plan: {err:#?}")))?;

+        let schema = stream.schema().clone();
+
+        // Apply garbage collection of dictionary and view arrays before sending over the network
+        let stream = stream.and_then(|rb| std::future::ready(garbage_collect_arrays(rb)));
+
         let stream = FlightDataEncoderBuilder::new()
-            .with_schema(stream.schema().clone())
+            .with_schema(schema)
+            // This tells the encoder to send dictionaries across the wire as-is.
+            // The alternative (`DictionaryHandling::Hydrate`) would expand the dictionaries
+            // into their value types, which can potentially blow up the size of the data transfer.
+            // The main reason to use `DictionaryHandling::Hydrate` is for compatibility with clients
+            // that do not support dictionaries, but since we are using the same server/client on both
+            // sides, we can safely use `DictionaryHandling::Resend`.
+            // Note that we do garbage collection of unused dictionary values above, so we are not sending
+            // unused dictionary values over the wire.
+            .with_dictionary_handling(DictionaryHandling::Resend)
             .build(stream.map_err(|err| {
                 FlightError::Tonic(Box::new(datafusion_error_to_tonic_status(&err)))
             }));
@@ -210,6 +226,34 @@ fn collect_and_create_metrics_flight_data(
     Ok(incoming.with_app_metadata(buf))
 }

+/// Garbage collects values sub-arrays.
+///
+/// We apply this before sending RecordBatches over the network to avoid sending
+/// values that are not referenced by any dictionary keys or buffers that are not used.
+///
+/// Unused values can arise from operations such as filtering, where
+/// some keys may no longer be referenced in the filtered result.
+fn garbage_collect_arrays(batch: RecordBatch) -> Result<RecordBatch, DataFusionError> {
+    let (schema, arrays, _row_count) = batch.into_parts();
+
+    let arrays = arrays
+        .into_iter()
+        .map(|array| {
+            if let Some(array) = array.as_any_dictionary_opt() {
+                garbage_collect_any_dictionary(array)
+            } else if let Some(array) = array.as_string_view_opt() {
+                Ok(Arc::new(array.gc()) as Arc<dyn Array>)
+            } else if let Some(array) = array.as_binary_view_opt() {
+                Ok(Arc::new(array.gc()) as Arc<dyn Array>)
+            } else {
+                Ok(array)
+            }
+        })
+        .collect::<Result<Vec<_>, _>>()?;
+
+    Ok(RecordBatch::try_new(schema, arrays)?)
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
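Below is a small standalone sketch (not part of the diff) of why the `garbage_collect_arrays` step in do_get.rs matters once the encoder switches to `DictionaryHandling::Resend`: filtering a dictionary column only rewrites the keys, so unreferenced values would otherwise be shipped with every batch. It assumes the arrow/arrow-select 56.x versions pinned in Cargo.toml, and the signature of `garbage_collect_any_dictionary` is inferred from its usage in the hunk above.

// Sketch only; `garbage_collect_any_dictionary`'s signature is assumed from
// the do_get.rs hunk (takes `&dyn AnyDictionaryArray`, returns a Result<ArrayRef, _>).
use std::sync::Arc;

use arrow::array::{Array, ArrayRef, AsArray, BooleanArray, DictionaryArray, Int32Array, StringArray};
use arrow::compute::filter;
use arrow::datatypes::Int32Type;
use arrow::error::ArrowError;
use arrow_select::dictionary::garbage_collect_any_dictionary;

fn main() -> Result<(), ArrowError> {
    // A dictionary column with three distinct values.
    let keys = Int32Array::from(vec![0, 1, 2, 0, 1, 2]);
    let values: ArrayRef = Arc::new(StringArray::from(vec!["small", "medium", "large"]));
    let dict = DictionaryArray::<Int32Type>::try_new(keys, values)?;

    // Filtering keeps only the rows that reference "small". The filter kernel
    // only rewrites the keys, so "medium" and "large" stay in the values array
    // even though no surviving key points at them.
    let mask = BooleanArray::from(vec![true, false, false, true, false, false]);
    let filtered = filter(&dict, &mask)?;
    assert_eq!(filtered.as_dictionary::<Int32Type>().values().len(), 3);

    // Garbage collection remaps the keys and drops the unreferenced values, so
    // a `DictionaryHandling::Resend` encoder does not ship dead entries.
    let collected = garbage_collect_any_dictionary(filtered.as_any_dictionary())?;
    assert_eq!(collected.as_dictionary::<Int32Type>().values().len(), 1);
    assert_eq!(collected.len(), filtered.len());

    Ok(())
}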
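A similar sketch (also not part of the diff) for the string/binary view branch of `garbage_collect_arrays`: view arrays keep offsets into shared data buffers, so a heavily filtered batch can still drag the full buffers over the wire until `gc()` compacts them. Assumes arrow 56.x.

// Sketch only: demonstrates that filtering a StringViewArray keeps the original
// data buffers attached, and that gc() copies only the reachable bytes.
use arrow::array::{Array, AsArray, BooleanArray, StringViewBuilder};
use arrow::compute::filter;

fn main() {
    // 1000 rows of 1 KiB strings, all stored in the array's data buffers.
    let long = "x".repeat(1024);
    let mut builder = StringViewBuilder::new();
    for _ in 0..1000 {
        builder.append_value(&long);
    }
    let array = builder.finish();

    // Keep a single row. The filtered array still references the original
    // ~1 MiB of buffers, because filtering only rewrites the views.
    let mask = BooleanArray::from((0..1000).map(|i| i == 0).collect::<Vec<bool>>());
    let filtered_ref = filter(&array, &mask).unwrap();
    let filtered = filtered_ref.as_string_view();

    // `gc()` copies only the reachable bytes into fresh, compact buffers.
    let compacted = filtered.gc();
    assert_eq!(filtered.len(), compacted.len());
    assert!(compacted.get_buffer_memory_size() < filtered.get_buffer_memory_size());
}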