diff --git a/Cargo.lock b/Cargo.lock index dd71c0fb6..a3fcce576 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -952,6 +952,7 @@ dependencies = [ "serde_json", "snmalloc-rs", "structopt", + "tempfile", "tokio", ] diff --git a/ballista/client/tests/sort_shuffle.rs b/ballista/client/tests/sort_shuffle.rs new file mode 100644 index 000000000..dc09f3235 --- /dev/null +++ b/ballista/client/tests/sort_shuffle.rs @@ -0,0 +1,532 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! End-to-end integration tests for sort-based shuffle. +//! +//! These tests verify that the sort-based shuffle implementation produces +//! correct results for various query patterns that involve shuffling. +//! +//! Tests are parameterized to run with both local reads and remote reads +//! (via the flight service) to ensure both paths work correctly. + +mod common; + +#[cfg(test)] +#[cfg(feature = "standalone")] +mod sort_shuffle_tests { + use ballista::prelude::{SessionConfigExt, SessionContextExt}; + use ballista_core::config::{ + BALLISTA_SHUFFLE_READER_FORCE_REMOTE_READ, + BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT, + BALLISTA_SHUFFLE_SORT_BASED_BUFFER_SIZE, BALLISTA_SHUFFLE_SORT_BASED_ENABLED, + BALLISTA_SHUFFLE_SORT_BASED_MEMORY_LIMIT, + }; + use datafusion::arrow::util::pretty::pretty_format_batches; + use datafusion::common::Result; + use datafusion::execution::SessionStateBuilder; + use datafusion::prelude::{ParquetReadOptions, SessionConfig, SessionContext}; + use rstest::rstest; + use std::collections::HashSet; + + /// Read mode for shuffle data + #[derive(Debug, Clone, Copy)] + enum ReadMode { + /// Read shuffle data locally (default) + Local, + /// Read shuffle data via flight service (remote read) + RemoteFlight, + } + + /// Creates a standalone session context with sort-based shuffle enabled. + async fn create_sort_shuffle_context(read_mode: ReadMode) -> SessionContext { + let mut config = SessionConfig::new_with_ballista() + .set_str(BALLISTA_SHUFFLE_SORT_BASED_ENABLED, "true") + .set_str(BALLISTA_SHUFFLE_SORT_BASED_BUFFER_SIZE, "1048576") // 1MB + .set_str(BALLISTA_SHUFFLE_SORT_BASED_MEMORY_LIMIT, "268435456"); // 256MB + + // Configure read mode + match read_mode { + ReadMode::Local => {} + ReadMode::RemoteFlight => { + config = config + .set_str(BALLISTA_SHUFFLE_READER_FORCE_REMOTE_READ, "true") + .set_str(BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT, "true"); + } + } + + let state = SessionStateBuilder::new() + .with_config(config) + .with_default_features() + .build(); + + SessionContext::standalone_with_state(state).await.unwrap() + } + + /// Creates a standalone session context with hash-based shuffle (default). + async fn create_hash_shuffle_context() -> SessionContext { + SessionContext::standalone().await.unwrap() + } + + /// Registers test data in the context. + async fn register_test_data(ctx: &SessionContext) { + ctx.register_parquet( + "test", + "testdata/alltypes_plain.parquet", + ParquetReadOptions::default(), + ) + .await + .unwrap(); + } + + fn assert_result_eq( + expected: Vec<&str>, + results: &[datafusion::arrow::record_batch::RecordBatch], + ) { + assert_eq!( + expected, + pretty_format_batches(results) + .unwrap() + .to_string() + .trim() + .lines() + .collect::>() + ); + } + + /// Extracts values from a result set, ignoring order. + fn extract_values_unordered( + results: &[datafusion::arrow::record_batch::RecordBatch], + ) -> HashSet { + pretty_format_batches(results) + .unwrap() + .to_string() + .trim() + .lines() + .skip(3) // Skip header lines + .filter(|line| !line.starts_with('+')) + .map(|s| s.to_string()) + .collect() + } + + // ==================== Basic Aggregation Tests ==================== + + #[rstest] + #[case::local(ReadMode::Local)] + #[case::remote_flight(ReadMode::RemoteFlight)] + #[tokio::test] + async fn test_sort_shuffle_group_by_single_column( + #[case] read_mode: ReadMode, + ) -> Result<()> { + let ctx = create_sort_shuffle_context(read_mode).await; + register_test_data(&ctx).await; + + let df = ctx + .sql("SELECT bool_col, COUNT(*) as cnt FROM test GROUP BY bool_col ORDER BY bool_col") + .await?; + let results = df.collect().await?; + + let expected = vec![ + "+----------+-----+", + "| bool_col | cnt |", + "+----------+-----+", + "| false | 4 |", + "| true | 4 |", + "+----------+-----+", + ]; + assert_result_eq(expected, &results); + Ok(()) + } + + #[rstest] + #[case::local(ReadMode::Local)] + #[case::remote_flight(ReadMode::RemoteFlight)] + #[tokio::test] + async fn test_sort_shuffle_group_by_multiple_columns( + #[case] read_mode: ReadMode, + ) -> Result<()> { + let ctx = create_sort_shuffle_context(read_mode).await; + register_test_data(&ctx).await; + + let df = ctx + .sql( + "SELECT bool_col, tinyint_col, COUNT(*) as cnt + FROM test + GROUP BY bool_col, tinyint_col + ORDER BY bool_col, tinyint_col", + ) + .await?; + let results = df.collect().await?; + + // Verify we got results with correct grouping + assert!(!results.is_empty()); + let total_rows: usize = results.iter().map(|b| b.num_rows()).sum(); + assert!(total_rows > 0); + Ok(()) + } + + #[rstest] + #[case::local(ReadMode::Local)] + #[case::remote_flight(ReadMode::RemoteFlight)] + #[tokio::test] + async fn test_sort_shuffle_aggregate_sum(#[case] read_mode: ReadMode) -> Result<()> { + let ctx = create_sort_shuffle_context(read_mode).await; + register_test_data(&ctx).await; + + let df = ctx.sql("SELECT SUM(id) FROM test").await?; + let results = df.collect().await?; + + let expected = vec![ + "+--------------+", + "| sum(test.id) |", + "+--------------+", + "| 28 |", + "+--------------+", + ]; + assert_result_eq(expected, &results); + Ok(()) + } + + #[rstest] + #[case::local(ReadMode::Local)] + #[case::remote_flight(ReadMode::RemoteFlight)] + #[tokio::test] + async fn test_sort_shuffle_aggregate_avg(#[case] read_mode: ReadMode) -> Result<()> { + let ctx = create_sort_shuffle_context(read_mode).await; + register_test_data(&ctx).await; + + let df = ctx.sql("SELECT AVG(id) FROM test").await?; + let results = df.collect().await?; + + let expected = vec![ + "+--------------+", + "| avg(test.id) |", + "+--------------+", + "| 3.5 |", + "+--------------+", + ]; + assert_result_eq(expected, &results); + Ok(()) + } + + #[rstest] + #[case::local(ReadMode::Local)] + #[case::remote_flight(ReadMode::RemoteFlight)] + #[tokio::test] + async fn test_sort_shuffle_aggregate_count( + #[case] read_mode: ReadMode, + ) -> Result<()> { + let ctx = create_sort_shuffle_context(read_mode).await; + register_test_data(&ctx).await; + + let df = ctx.sql("SELECT COUNT(*) FROM test").await?; + let results = df.collect().await?; + + let expected = vec![ + "+----------+", + "| count(*) |", + "+----------+", + "| 8 |", + "+----------+", + ]; + assert_result_eq(expected, &results); + Ok(()) + } + + #[rstest] + #[case::local(ReadMode::Local)] + #[case::remote_flight(ReadMode::RemoteFlight)] + #[tokio::test] + async fn test_sort_shuffle_aggregate_min_max( + #[case] read_mode: ReadMode, + ) -> Result<()> { + let ctx = create_sort_shuffle_context(read_mode).await; + register_test_data(&ctx).await; + + let df = ctx.sql("SELECT MIN(id), MAX(id) FROM test").await?; + let results = df.collect().await?; + + let expected = vec![ + "+--------------+--------------+", + "| min(test.id) | max(test.id) |", + "+--------------+--------------+", + "| 0 | 7 |", + "+--------------+--------------+", + ]; + assert_result_eq(expected, &results); + Ok(()) + } + + // ==================== Comparison with Hash Shuffle ==================== + + #[tokio::test] + async fn test_sort_vs_hash_shuffle_group_by() -> Result<()> { + // Test with sort shuffle (local read is sufficient for comparison) + let sort_ctx = create_sort_shuffle_context(ReadMode::Local).await; + register_test_data(&sort_ctx).await; + let sort_results = sort_ctx + .sql("SELECT bool_col, SUM(id) as total FROM test GROUP BY bool_col") + .await? + .collect() + .await?; + + // Test with hash shuffle + let hash_ctx = create_hash_shuffle_context().await; + register_test_data(&hash_ctx).await; + let hash_results = hash_ctx + .sql("SELECT bool_col, SUM(id) as total FROM test GROUP BY bool_col") + .await? + .collect() + .await?; + + // Results should be equivalent (order may differ) + let sort_values = extract_values_unordered(&sort_results); + let hash_values = extract_values_unordered(&hash_results); + assert_eq!(sort_values, hash_values); + Ok(()) + } + + #[tokio::test] + async fn test_sort_vs_hash_shuffle_distinct() -> Result<()> { + // Test with sort shuffle (local read is sufficient for comparison) + let sort_ctx = create_sort_shuffle_context(ReadMode::Local).await; + register_test_data(&sort_ctx).await; + let sort_results = sort_ctx + .sql("SELECT DISTINCT bool_col FROM test") + .await? + .collect() + .await?; + + // Test with hash shuffle + let hash_ctx = create_hash_shuffle_context().await; + register_test_data(&hash_ctx).await; + let hash_results = hash_ctx + .sql("SELECT DISTINCT bool_col FROM test") + .await? + .collect() + .await?; + + // Results should be equivalent + let sort_values = extract_values_unordered(&sort_results); + let hash_values = extract_values_unordered(&hash_results); + assert_eq!(sort_values, hash_values); + Ok(()) + } + + // ==================== Edge Cases ==================== + + #[rstest] + #[case::local(ReadMode::Local)] + #[case::remote_flight(ReadMode::RemoteFlight)] + #[tokio::test] + async fn test_sort_shuffle_empty_result(#[case] read_mode: ReadMode) -> Result<()> { + let ctx = create_sort_shuffle_context(read_mode).await; + register_test_data(&ctx).await; + + let df = ctx.sql("SELECT id FROM test WHERE id > 100").await?; + let results = df.collect().await?; + + // Should return empty result without error + let total_rows: usize = results.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows, 0); + Ok(()) + } + + #[rstest] + #[case::local(ReadMode::Local)] + #[case::remote_flight(ReadMode::RemoteFlight)] + #[tokio::test] + async fn test_sort_shuffle_single_partition( + #[case] read_mode: ReadMode, + ) -> Result<()> { + let ctx = create_sort_shuffle_context(read_mode).await; + register_test_data(&ctx).await; + + // Query that results in single partition output + let df = ctx.sql("SELECT COUNT(*) FROM test").await?; + let results = df.collect().await?; + + assert!(!results.is_empty()); + Ok(()) + } + + #[rstest] + #[case::local(ReadMode::Local)] + #[case::remote_flight(ReadMode::RemoteFlight)] + #[tokio::test] + async fn test_sort_shuffle_multiple_aggregates( + #[case] read_mode: ReadMode, + ) -> Result<()> { + let ctx = create_sort_shuffle_context(read_mode).await; + register_test_data(&ctx).await; + + let df = ctx + .sql( + "SELECT + bool_col, + COUNT(*) as cnt, + SUM(id) as sum_id, + AVG(id) as avg_id, + MIN(id) as min_id, + MAX(id) as max_id + FROM test + GROUP BY bool_col + ORDER BY bool_col", + ) + .await?; + let results = df.collect().await?; + + assert!(!results.is_empty()); + // Verify we have 2 groups (true and false) + let total_rows: usize = results.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows, 2); + Ok(()) + } + + #[rstest] + #[case::local(ReadMode::Local)] + #[case::remote_flight(ReadMode::RemoteFlight)] + #[tokio::test] + async fn test_sort_shuffle_having_clause(#[case] read_mode: ReadMode) -> Result<()> { + let ctx = create_sort_shuffle_context(read_mode).await; + register_test_data(&ctx).await; + + let df = ctx + .sql( + "SELECT bool_col, COUNT(*) as cnt + FROM test + GROUP BY bool_col + HAVING COUNT(*) > 2 + ORDER BY bool_col", + ) + .await?; + let results = df.collect().await?; + + // Both groups should have count > 2 + let total_rows: usize = results.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows, 2); + Ok(()) + } + + // ==================== Subquery and Complex Queries ==================== + + #[rstest] + #[case::local(ReadMode::Local)] + #[case::remote_flight(ReadMode::RemoteFlight)] + #[tokio::test] + async fn test_sort_shuffle_subquery(#[case] read_mode: ReadMode) -> Result<()> { + let ctx = create_sort_shuffle_context(read_mode).await; + register_test_data(&ctx).await; + + let df = ctx + .sql( + "SELECT * FROM ( + SELECT bool_col, COUNT(*) as cnt + FROM test + GROUP BY bool_col + ) sub + WHERE cnt > 0", + ) + .await?; + let results = df.collect().await?; + + assert!(!results.is_empty()); + Ok(()) + } + + #[rstest] + #[case::local(ReadMode::Local)] + #[case::remote_flight(ReadMode::RemoteFlight)] + #[tokio::test] + async fn test_sort_shuffle_union(#[case] read_mode: ReadMode) -> Result<()> { + let ctx = create_sort_shuffle_context(read_mode).await; + register_test_data(&ctx).await; + + let df = ctx + .sql( + "SELECT bool_col, COUNT(*) as cnt FROM test WHERE id < 4 GROUP BY bool_col + UNION ALL + SELECT bool_col, COUNT(*) as cnt FROM test WHERE id >= 4 GROUP BY bool_col", + ) + .await?; + let results = df.collect().await?; + + // Should have results from both parts of the union + let total_rows: usize = results.iter().map(|b| b.num_rows()).sum(); + assert!(total_rows >= 2); + Ok(()) + } + + // ==================== Order By with Shuffle ==================== + + #[rstest] + #[case::local(ReadMode::Local)] + #[case::remote_flight(ReadMode::RemoteFlight)] + #[tokio::test] + async fn test_sort_shuffle_order_by(#[case] read_mode: ReadMode) -> Result<()> { + let ctx = create_sort_shuffle_context(read_mode).await; + register_test_data(&ctx).await; + + let df = ctx.sql("SELECT id FROM test ORDER BY id").await?; + let results = df.collect().await?; + + // Verify ordering is correct + let expected = vec![ + "+----+", "| id |", "+----+", "| 0 |", "| 1 |", "| 2 |", "| 3 |", + "| 4 |", "| 5 |", "| 6 |", "| 7 |", "+----+", + ]; + assert_result_eq(expected, &results); + Ok(()) + } + + #[rstest] + #[case::local(ReadMode::Local)] + #[case::remote_flight(ReadMode::RemoteFlight)] + #[tokio::test] + async fn test_sort_shuffle_order_by_desc(#[case] read_mode: ReadMode) -> Result<()> { + let ctx = create_sort_shuffle_context(read_mode).await; + register_test_data(&ctx).await; + + let df = ctx.sql("SELECT id FROM test ORDER BY id DESC").await?; + let results = df.collect().await?; + + // Verify descending order + let expected = vec![ + "+----+", "| id |", "+----+", "| 7 |", "| 6 |", "| 5 |", "| 4 |", + "| 3 |", "| 2 |", "| 1 |", "| 0 |", "+----+", + ]; + assert_result_eq(expected, &results); + Ok(()) + } + + #[rstest] + #[case::local(ReadMode::Local)] + #[case::remote_flight(ReadMode::RemoteFlight)] + #[tokio::test] + async fn test_sort_shuffle_limit(#[case] read_mode: ReadMode) -> Result<()> { + let ctx = create_sort_shuffle_context(read_mode).await; + register_test_data(&ctx).await; + + let df = ctx.sql("SELECT id FROM test ORDER BY id LIMIT 3").await?; + let results = df.collect().await?; + + let expected = vec![ + "+----+", "| id |", "+----+", "| 0 |", "| 1 |", "| 2 |", "+----+", + ]; + assert_result_eq(expected, &results); + Ok(()) + } +} diff --git a/ballista/core/proto/ballista.proto b/ballista/core/proto/ballista.proto index 34fd43811..3f28cd26e 100644 --- a/ballista/core/proto/ballista.proto +++ b/ballista/core/proto/ballista.proto @@ -35,6 +35,7 @@ message BallistaPhysicalPlanNode { ShuffleWriterExecNode shuffle_writer = 1; ShuffleReaderExecNode shuffle_reader = 2; UnresolvedShuffleExecNode unresolved_shuffle = 3; + SortShuffleWriterExecNode sort_shuffle_writer = 4; } } @@ -47,6 +48,18 @@ message ShuffleWriterExecNode { datafusion.PhysicalHashRepartition output_partitioning = 4; } +// Sort-based shuffle writer that produces consolidated files with index +message SortShuffleWriterExecNode { + string job_id = 1; + uint32 stage_id = 2; + datafusion.PhysicalPlanNode input = 3; + datafusion.PhysicalHashRepartition output_partitioning = 4; + // Configuration for sort shuffle + uint64 buffer_size = 5; + uint64 memory_limit = 6; + double spill_threshold = 7; +} + message UnresolvedShuffleExecNode { uint32 stage_id = 1; datafusion_common.Schema schema = 2; diff --git a/ballista/core/src/config.rs b/ballista/core/src/config.rs index 1e6a593fd..86f486262 100644 --- a/ballista/core/src/config.rs +++ b/ballista/core/src/config.rs @@ -57,6 +57,19 @@ pub const BALLISTA_GRPC_CLIENT_TCP_KEEPALIVE_SECONDS: &str = pub const BALLISTA_GRPC_CLIENT_HTTP2_KEEPALIVE_INTERVAL_SECONDS: &str = "ballista.grpc.client.http2_keepalive_interval_seconds"; +/// Configuration key for enabling sort-based shuffle. +pub const BALLISTA_SHUFFLE_SORT_BASED_ENABLED: &str = + "ballista.shuffle.sort_based.enabled"; +/// Configuration key for sort shuffle per-partition buffer size in bytes. +pub const BALLISTA_SHUFFLE_SORT_BASED_BUFFER_SIZE: &str = + "ballista.shuffle.sort_based.buffer_size"; +/// Configuration key for sort shuffle total memory limit in bytes. +pub const BALLISTA_SHUFFLE_SORT_BASED_MEMORY_LIMIT: &str = + "ballista.shuffle.sort_based.memory_limit"; +/// Configuration key for sort shuffle spill threshold (0.0-1.0). +pub const BALLISTA_SHUFFLE_SORT_BASED_SPILL_THRESHOLD: &str = + "ballista.shuffle.sort_based.spill_threshold"; + /// Result type for configuration parsing operations. pub type ParseResult = result::Result; use std::sync::LazyLock; @@ -100,7 +113,23 @@ static CONFIG_ENTRIES: LazyLock> = LazyLock::new(|| ConfigEntry::new(BALLISTA_GRPC_CLIENT_HTTP2_KEEPALIVE_INTERVAL_SECONDS.to_string(), "HTTP/2 keep-alive interval for gRPC client in seconds".to_string(), DataType::UInt64, - Some((300).to_string())) + Some((300).to_string())), + ConfigEntry::new(BALLISTA_SHUFFLE_SORT_BASED_ENABLED.to_string(), + "Enable sort-based shuffle which writes consolidated files with index".to_string(), + DataType::Boolean, + Some((false).to_string())), + ConfigEntry::new(BALLISTA_SHUFFLE_SORT_BASED_BUFFER_SIZE.to_string(), + "Per-partition buffer size in bytes for sort shuffle".to_string(), + DataType::UInt64, + Some((1024 * 1024).to_string())), + ConfigEntry::new(BALLISTA_SHUFFLE_SORT_BASED_MEMORY_LIMIT.to_string(), + "Total memory limit in bytes for sort shuffle buffers".to_string(), + DataType::UInt64, + Some((256 * 1024 * 1024).to_string())), + ConfigEntry::new(BALLISTA_SHUFFLE_SORT_BASED_SPILL_THRESHOLD.to_string(), + "Spill threshold as decimal fraction (0.0-1.0) of memory limit".to_string(), + DataType::Utf8, + Some("0.8".to_string())) ]; entries .into_iter() @@ -264,6 +293,29 @@ impl BallistaConfig { self.get_bool_setting(BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT) } + /// Returns whether sort-based shuffle is enabled. + /// + /// When enabled, shuffle writes produce a single consolidated file per input + /// partition with an index file, rather than one file per output partition. + pub fn shuffle_sort_based_enabled(&self) -> bool { + self.get_bool_setting(BALLISTA_SHUFFLE_SORT_BASED_ENABLED) + } + + /// Returns the per-partition buffer size for sort-based shuffle in bytes. + pub fn shuffle_sort_based_buffer_size(&self) -> usize { + self.get_usize_setting(BALLISTA_SHUFFLE_SORT_BASED_BUFFER_SIZE) + } + + /// Returns the total memory limit for sort-based shuffle buffers in bytes. + pub fn shuffle_sort_based_memory_limit(&self) -> usize { + self.get_usize_setting(BALLISTA_SHUFFLE_SORT_BASED_MEMORY_LIMIT) + } + + /// Returns the spill threshold for sort-based shuffle (0.0-1.0). + pub fn shuffle_sort_based_spill_threshold(&self) -> f64 { + self.get_f64_setting(BALLISTA_SHUFFLE_SORT_BASED_SPILL_THRESHOLD) + } + fn get_usize_setting(&self, key: &str) -> usize { if let Some(v) = self.settings.get(key) { // infallible because we validate all configs in the constructor @@ -300,6 +352,16 @@ impl BallistaConfig { v.to_string() } } + + fn get_f64_setting(&self, key: &str) -> f64 { + if let Some(v) = self.settings.get(key) { + v.parse::().unwrap() + } else { + let entries = Self::valid_entries(); + let v = entries.get(key).unwrap().default_value.as_ref().unwrap(); + v.parse::().unwrap() + } + } } impl datafusion::config::ExtensionOptions for BallistaConfig { diff --git a/ballista/core/src/diagram.rs b/ballista/core/src/diagram.rs index 7f69e8a94..e1ec42681 100644 --- a/ballista/core/src/diagram.rs +++ b/ballista/core/src/diagram.rs @@ -16,7 +16,9 @@ // under the License. use crate::error::Result; -use crate::execution_plans::{ShuffleWriterExec, UnresolvedShuffleExec}; +use crate::execution_plans::{ + ShuffleWriter, ShuffleWriterExec, SortShuffleWriterExec, UnresolvedShuffleExec, +}; use datafusion::datasource::source::DataSourceExec; use datafusion::physical_plan::ExecutionPlan; @@ -37,7 +39,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; /// /// Writes a DOT file to the specified filename that visualizes the relationships /// between execution stages and their operators. -pub fn produce_diagram(filename: &str, stages: &[Arc]) -> Result<()> { +pub fn produce_diagram(filename: &str, stages: &[Arc]) -> Result<()> { let write_file = File::create(filename)?; let mut w = BufWriter::new(&write_file); writeln!(w, "digraph G {{")?; @@ -94,6 +96,12 @@ fn build_exec_plan_diagram( "FilterExec" } else if plan.as_any().downcast_ref::().is_some() { "ShuffleWriterExec" + } else if plan + .as_any() + .downcast_ref::() + .is_some() + { + "SortShuffleWriterExec" } else if plan .as_any() .downcast_ref::() diff --git a/ballista/core/src/execution_plans/mod.rs b/ballista/core/src/execution_plans/mod.rs index 7a5e105c6..c6b6b2976 100644 --- a/ballista/core/src/execution_plans/mod.rs +++ b/ballista/core/src/execution_plans/mod.rs @@ -21,9 +21,13 @@ mod distributed_query; mod shuffle_reader; mod shuffle_writer; +mod shuffle_writer_trait; +pub mod sort_shuffle; mod unresolved_shuffle; pub use distributed_query::DistributedQueryExec; pub use shuffle_reader::ShuffleReaderExec; pub use shuffle_writer::ShuffleWriterExec; +pub use shuffle_writer_trait::ShuffleWriter; +pub use sort_shuffle::SortShuffleWriterExec; pub use unresolved_shuffle::UnresolvedShuffleExec; diff --git a/ballista/core/src/execution_plans/shuffle_reader.rs b/ballista/core/src/execution_plans/shuffle_reader.rs index 5063d1d75..6776d3597 100644 --- a/ballista/core/src/execution_plans/shuffle_reader.rs +++ b/ballista/core/src/execution_plans/shuffle_reader.rs @@ -29,6 +29,9 @@ use std::sync::Arc; use std::task::{Context, Poll}; use crate::client::BallistaClient; +use crate::execution_plans::sort_shuffle::{ + get_index_path, is_sort_shuffle_output, stream_sort_shuffle_partition, +}; use crate::extension::SessionConfigExt; use crate::serde::scheduler::{PartitionLocation, PartitionStats}; @@ -525,7 +528,31 @@ async fn fetch_partition_local( let path = &location.path; let metadata = &location.executor_meta; let partition_id = &location.partition_id; + let data_path = std::path::Path::new(path); + + // Check if this is a sort-based shuffle output (has index file) + if is_sort_shuffle_output(data_path) { + debug!( + "Reading sort-based shuffle for partition {} from {:?}", + partition_id.partition_id, data_path + ); + let index_path = get_index_path(data_path); + return stream_sort_shuffle_partition( + data_path, + &index_path, + partition_id.partition_id, + ) + .map_err(|e| { + BallistaError::FetchFailed( + metadata.id.clone(), + partition_id.stage_id, + partition_id.partition_id, + e.to_string(), + ) + }); + } + // Standard hash-based shuffle - read the file directly let reader = fetch_partition_local_inner(path).map_err(|e| { // return BallistaError::FetchFailed may let scheduler retry this task. BallistaError::FetchFailed( diff --git a/ballista/core/src/execution_plans/shuffle_writer.rs b/ballista/core/src/execution_plans/shuffle_writer.rs index e5193c2e1..f0d9a9c74 100644 --- a/ballista/core/src/execution_plans/shuffle_writer.rs +++ b/ballista/core/src/execution_plans/shuffle_writer.rs @@ -63,6 +63,8 @@ use datafusion::physical_plan::repartition::BatchPartitioner; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use log::{debug, info}; +use super::shuffle_writer_trait::ShuffleWriter; + /// ShuffleWriterExec represents a section of a query plan that has consistent partitioning and /// can be executed as one unit with each partition being executed in parallel. The output of each /// partition is re-partitioned and streamed to disk in Arrow IPC format. Future stages of the query @@ -502,6 +504,31 @@ impl ExecutionPlan for ShuffleWriterExec { } } +impl ShuffleWriter for ShuffleWriterExec { + fn job_id(&self) -> &str { + &self.job_id + } + + fn stage_id(&self) -> usize { + self.stage_id + } + + fn shuffle_output_partitioning(&self) -> Option<&Partitioning> { + self.shuffle_output_partitioning.as_ref() + } + + fn input_partition_count(&self) -> usize { + self.plan + .properties() + .output_partitioning() + .partition_count() + } + + fn clone_box(&self) -> Arc { + Arc::new(self.clone()) + } +} + fn result_schema() -> SchemaRef { let stats = PartitionStats::default(); Arc::new(Schema::new(vec![ diff --git a/ballista/core/src/execution_plans/shuffle_writer_trait.rs b/ballista/core/src/execution_plans/shuffle_writer_trait.rs new file mode 100644 index 000000000..dde69c11a --- /dev/null +++ b/ballista/core/src/execution_plans/shuffle_writer_trait.rs @@ -0,0 +1,49 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Common trait for shuffle writer execution plans. +//! +//! This trait provides a common interface for both standard hash-based shuffle +//! (`ShuffleWriterExec`) and sort-based shuffle (`SortShuffleWriterExec`). + +use datafusion::physical_plan::{ExecutionPlan, Partitioning}; +use std::fmt::Debug; +use std::sync::Arc; + +/// Trait for shuffle writer execution plans. +/// +/// This trait defines the common interface needed by the distributed planner +/// and execution graph to work with different shuffle implementations. +pub trait ShuffleWriter: ExecutionPlan + Debug + Send + Sync { + /// Get the Job ID for this query stage. + fn job_id(&self) -> &str; + + /// Get the Stage ID for this query stage. + fn stage_id(&self) -> usize; + + /// Get the shuffle output partitioning, if any. + /// + /// Returns `Some(partitioning)` for repartitioning stages, + /// `None` for stages that preserve the input partitioning. + fn shuffle_output_partitioning(&self) -> Option<&Partitioning>; + + /// Get the number of input partitions. + fn input_partition_count(&self) -> usize; + + /// Clone this shuffle writer as an Arc'd trait object. + fn clone_box(&self) -> Arc; +} diff --git a/ballista/core/src/execution_plans/sort_shuffle/buffer.rs b/ballista/core/src/execution_plans/sort_shuffle/buffer.rs new file mode 100644 index 000000000..61292af2c --- /dev/null +++ b/ballista/core/src/execution_plans/sort_shuffle/buffer.rs @@ -0,0 +1,175 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! In-memory partition buffer for sort-based shuffle. +//! +//! Each output partition has a buffer that accumulates record batches +//! until the buffer is full or needs to be spilled to disk. + +use datafusion::arrow::datatypes::SchemaRef; +use datafusion::arrow::record_batch::RecordBatch; + +/// Buffer for accumulating record batches for a single output partition. +/// +/// When the buffer exceeds its maximum size, it signals that it should be +/// spilled to disk. +#[derive(Debug)] +pub struct PartitionBuffer { + /// Partition ID this buffer is for + partition_id: usize, + /// Buffered record batches + batches: Vec, + /// Current memory usage in bytes + memory_used: usize, + /// Number of rows in the buffer + num_rows: usize, + /// Schema for this partition's data + schema: SchemaRef, +} + +impl PartitionBuffer { + /// Creates a new partition buffer. + pub fn new(partition_id: usize, schema: SchemaRef) -> Self { + Self { + partition_id, + batches: Vec::new(), + memory_used: 0, + num_rows: 0, + schema, + } + } + + /// Returns the partition ID for this buffer. + pub fn partition_id(&self) -> usize { + self.partition_id + } + + /// Returns the schema for this buffer's data. + pub fn schema(&self) -> &SchemaRef { + &self.schema + } + + /// Returns the current memory usage in bytes. + pub fn memory_used(&self) -> usize { + self.memory_used + } + + /// Returns the number of rows in the buffer. + pub fn num_rows(&self) -> usize { + self.num_rows + } + + /// Returns the number of batches in the buffer. + pub fn num_batches(&self) -> usize { + self.batches.len() + } + + /// Returns true if the buffer is empty. + pub fn is_empty(&self) -> bool { + self.batches.is_empty() + } + + /// Appends a record batch to the buffer. + /// + /// Returns the new total memory usage after appending. + pub fn append(&mut self, batch: RecordBatch) -> usize { + let batch_size = batch.get_array_memory_size(); + self.num_rows += batch.num_rows(); + self.memory_used += batch_size; + self.batches.push(batch); + self.memory_used + } + + /// Drains all batches from the buffer, resetting it to empty. + /// + /// Returns the drained batches. + pub fn drain(&mut self) -> Vec { + self.memory_used = 0; + self.num_rows = 0; + std::mem::take(&mut self.batches) + } + + /// Takes all batches from the buffer without resetting memory tracking. + /// + /// This is useful when the caller wants to handle the batches but the + /// buffer will be discarded anyway. + pub fn take_batches(&mut self) -> Vec { + std::mem::take(&mut self.batches) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion::arrow::array::Int32Array; + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use std::sync::Arc; + + fn create_test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])) + } + + fn create_test_batch(schema: &SchemaRef, values: Vec) -> RecordBatch { + let array = Int32Array::from(values); + RecordBatch::try_new(schema.clone(), vec![Arc::new(array)]).unwrap() + } + + #[test] + fn test_new_buffer() { + let schema = create_test_schema(); + let buffer = PartitionBuffer::new(0, schema); + + assert_eq!(buffer.partition_id(), 0); + assert!(buffer.is_empty()); + assert_eq!(buffer.memory_used(), 0); + assert_eq!(buffer.num_rows(), 0); + assert_eq!(buffer.num_batches(), 0); + } + + #[test] + fn test_append() { + let schema = create_test_schema(); + let mut buffer = PartitionBuffer::new(0, schema.clone()); + + let batch = create_test_batch(&schema, vec![1, 2, 3]); + buffer.append(batch); + + assert!(!buffer.is_empty()); + assert!(buffer.memory_used() > 0); + assert_eq!(buffer.num_rows(), 3); + assert_eq!(buffer.num_batches(), 1); + } + + #[test] + fn test_drain() { + let schema = create_test_schema(); + let mut buffer = PartitionBuffer::new(0, schema.clone()); + + buffer.append(create_test_batch(&schema, vec![1, 2, 3])); + buffer.append(create_test_batch(&schema, vec![4, 5])); + + assert_eq!(buffer.num_batches(), 2); + assert_eq!(buffer.num_rows(), 5); + + let batches = buffer.drain(); + + assert_eq!(batches.len(), 2); + assert!(buffer.is_empty()); + assert_eq!(buffer.memory_used(), 0); + assert_eq!(buffer.num_rows(), 0); + } +} diff --git a/ballista/core/src/execution_plans/sort_shuffle/config.rs b/ballista/core/src/execution_plans/sort_shuffle/config.rs new file mode 100644 index 000000000..7c323f78d --- /dev/null +++ b/ballista/core/src/execution_plans/sort_shuffle/config.rs @@ -0,0 +1,100 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Configuration for sort-based shuffle. + +use datafusion::arrow::ipc::CompressionType; + +/// Configuration for sort-based shuffle. +/// +/// Controls memory buffering, spilling behavior, and compression settings +/// for the sort-based shuffle writer. +#[derive(Debug, Clone)] +pub struct SortShuffleConfig { + /// Whether sort-based shuffle is enabled (default: false) + pub enabled: bool, + /// Per-partition buffer size in bytes before considering spill (default: 1MB) + pub buffer_size: usize, + /// Total memory limit for all shuffle buffers combined (default: 256MB) + pub memory_limit: usize, + /// Spill threshold as fraction of memory limit (default: 0.8) + /// When total memory usage exceeds `memory_limit * spill_threshold`, + /// the largest buffers are spilled to disk. + pub spill_threshold: f64, + /// Compression codec for shuffle data (default: LZ4_FRAME) + pub compression: CompressionType, +} + +impl Default for SortShuffleConfig { + fn default() -> Self { + Self { + enabled: false, + buffer_size: 1024 * 1024, // 1 MB + memory_limit: 256 * 1024 * 1024, // 256 MB + spill_threshold: 0.8, + compression: CompressionType::LZ4_FRAME, + } + } +} + +impl SortShuffleConfig { + /// Creates a new configuration with the specified settings. + pub fn new( + enabled: bool, + buffer_size: usize, + memory_limit: usize, + spill_threshold: f64, + compression: CompressionType, + ) -> Self { + Self { + enabled, + buffer_size, + memory_limit, + spill_threshold, + compression, + } + } + + /// Returns the memory threshold at which spilling should occur. + pub fn spill_memory_threshold(&self) -> usize { + (self.memory_limit as f64 * self.spill_threshold) as usize + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_default_config() { + let config = SortShuffleConfig::default(); + assert!(!config.enabled); + assert_eq!(config.buffer_size, 1024 * 1024); + assert_eq!(config.memory_limit, 256 * 1024 * 1024); + assert!((config.spill_threshold - 0.8).abs() < f64::EPSILON); + } + + #[test] + fn test_spill_memory_threshold() { + let config = SortShuffleConfig { + memory_limit: 100, + spill_threshold: 0.8, + ..Default::default() + }; + assert_eq!(config.spill_memory_threshold(), 80); + } +} diff --git a/ballista/core/src/execution_plans/sort_shuffle/index.rs b/ballista/core/src/execution_plans/sort_shuffle/index.rs new file mode 100644 index 000000000..c1e630873 --- /dev/null +++ b/ballista/core/src/execution_plans/sort_shuffle/index.rs @@ -0,0 +1,211 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Shuffle index file for sort-based shuffle. +//! +//! The index file stores byte offsets for each partition in the consolidated +//! shuffle data file. Format: +//! +//! ```text +//! [i64: offset_0][i64: offset_1]...[i64: offset_n][i64: total_length] +//! ``` +//! +//! - All values are little-endian i64 +//! - `offset_i` = byte offset where partition `i` starts +//! - Last entry is total file length +//! - Partition `i` data spans `[offset_i, offset_{i+1})` + +use crate::error::{BallistaError, Result}; +use std::fs::File; +use std::io::{BufReader, BufWriter, Read, Write}; +use std::path::Path; + +/// Shuffle index that maps partition IDs to byte offsets in the data file. +#[derive(Debug, Clone)] +pub struct ShuffleIndex { + /// Byte offsets for each partition. Length is partition_count + 1, + /// where the last entry is the total file length. + offsets: Vec, +} + +impl ShuffleIndex { + /// Creates a new shuffle index for the given number of partitions. + /// + /// All offsets are initialized to 0. + pub fn new(partition_count: usize) -> Self { + Self { + offsets: vec![0i64; partition_count + 1], + } + } + + /// Returns the number of partitions in this index. + pub fn partition_count(&self) -> usize { + self.offsets.len().saturating_sub(1) + } + + /// Sets the byte offset for a partition. + /// + /// # Panics + /// Panics if `partition_id >= partition_count`. + pub fn set_offset(&mut self, partition_id: usize, offset: i64) { + self.offsets[partition_id] = offset; + } + + /// Sets the total file length (stored as the last entry). + pub fn set_total_length(&mut self, length: i64) { + if let Some(last) = self.offsets.last_mut() { + *last = length; + } + } + + /// Returns the byte range `(start, end)` for the given partition. + /// + /// The partition data spans `[start, end)` bytes in the data file. + /// + /// # Panics + /// Panics if `partition_id >= partition_count`. + pub fn get_partition_range(&self, partition_id: usize) -> (i64, i64) { + (self.offsets[partition_id], self.offsets[partition_id + 1]) + } + + /// Returns the size in bytes for the given partition. + pub fn get_partition_size(&self, partition_id: usize) -> i64 { + let (start, end) = self.get_partition_range(partition_id); + end - start + } + + /// Returns true if the partition has data (size > 0). + pub fn partition_has_data(&self, partition_id: usize) -> bool { + self.get_partition_size(partition_id) > 0 + } + + /// Writes the index to a file. + pub fn write_to_file(&self, path: &Path) -> Result<()> { + let file = File::create(path).map_err(BallistaError::IoError)?; + let mut writer = BufWriter::new(file); + + for &offset in &self.offsets { + writer + .write_all(&offset.to_le_bytes()) + .map_err(BallistaError::IoError)?; + } + + writer.flush().map_err(BallistaError::IoError)?; + Ok(()) + } + + /// Reads an index from a file. + pub fn read_from_file(path: &Path) -> Result { + let file = File::open(path).map_err(BallistaError::IoError)?; + let metadata = file.metadata().map_err(BallistaError::IoError)?; + let file_size = metadata.len() as usize; + + // Each offset is 8 bytes (i64) + if !file_size.is_multiple_of(8) { + return Err(BallistaError::General(format!( + "Invalid index file size: {file_size} (must be multiple of 8)" + ))); + } + + let entry_count = file_size / 8; + if entry_count < 2 { + return Err(BallistaError::General(format!( + "Index file too small: {entry_count} entries (need at least 2)" + ))); + } + + let mut reader = BufReader::new(file); + let mut offsets = Vec::with_capacity(entry_count); + let mut buf = [0u8; 8]; + + for _ in 0..entry_count { + reader + .read_exact(&mut buf) + .map_err(BallistaError::IoError)?; + offsets.push(i64::from_le_bytes(buf)); + } + + Ok(Self { offsets }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + #[test] + fn test_new_index() { + let index = ShuffleIndex::new(4); + assert_eq!(index.partition_count(), 4); + assert_eq!(index.offsets.len(), 5); + } + + #[test] + fn test_set_offsets() { + let mut index = ShuffleIndex::new(3); + index.set_offset(0, 0); + index.set_offset(1, 100); + index.set_offset(2, 250); + index.set_total_length(500); + + assert_eq!(index.get_partition_range(0), (0, 100)); + assert_eq!(index.get_partition_range(1), (100, 250)); + assert_eq!(index.get_partition_range(2), (250, 500)); + + assert_eq!(index.get_partition_size(0), 100); + assert_eq!(index.get_partition_size(1), 150); + assert_eq!(index.get_partition_size(2), 250); + } + + #[test] + fn test_partition_has_data() { + let mut index = ShuffleIndex::new(3); + index.set_offset(0, 0); + index.set_offset(1, 0); // Empty partition + index.set_offset(2, 100); + index.set_total_length(200); + + assert!(!index.partition_has_data(0)); + assert!(index.partition_has_data(1)); + assert!(index.partition_has_data(2)); + } + + #[test] + fn test_write_and_read() -> Result<()> { + let temp_dir = TempDir::new().unwrap(); + let index_path = temp_dir.path().join("test.index"); + + // Create and write index + let mut index = ShuffleIndex::new(3); + index.set_offset(0, 0); + index.set_offset(1, 100); + index.set_offset(2, 300); + index.set_total_length(500); + index.write_to_file(&index_path)?; + + // Read it back + let loaded = ShuffleIndex::read_from_file(&index_path)?; + + assert_eq!(loaded.partition_count(), 3); + assert_eq!(loaded.get_partition_range(0), (0, 100)); + assert_eq!(loaded.get_partition_range(1), (100, 300)); + assert_eq!(loaded.get_partition_range(2), (300, 500)); + + Ok(()) + } +} diff --git a/ballista/core/src/execution_plans/sort_shuffle/mod.rs b/ballista/core/src/execution_plans/sort_shuffle/mod.rs new file mode 100644 index 000000000..fa5634b6c --- /dev/null +++ b/ballista/core/src/execution_plans/sort_shuffle/mod.rs @@ -0,0 +1,47 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Sort-based shuffle implementation for Ballista. +//! +//! This module provides an alternative to the hash-based shuffle. It writes +//! a single consolidated file per input partition (sorted by output partition ID) +//! along with an index file mapping partition IDs to batch ranges. +//! +//! This approach reduces file count from `N × M` (N input partitions × M output partitions) +//! to `2 × N` files (one data + one index per input partition). +//! +//! The algorithm follows the approach used by Apache Spark: internally, results from +//! individual map tasks are kept in memory until they can't fit. Then, these are +//! sorted based on the target partition and written to a single file. On the reduce +//! side, tasks read the relevant sorted blocks. + +mod buffer; +mod config; +mod index; +mod reader; +mod spill; +mod writer; + +pub use buffer::PartitionBuffer; +pub use config::SortShuffleConfig; +pub use index::ShuffleIndex; +pub use reader::{ + get_index_path, is_sort_shuffle_output, read_all_batches, + read_sort_shuffle_partition, stream_sort_shuffle_partition, +}; +pub use spill::SpillManager; +pub use writer::SortShuffleWriterExec; diff --git a/ballista/core/src/execution_plans/sort_shuffle/reader.rs b/ballista/core/src/execution_plans/sort_shuffle/reader.rs new file mode 100644 index 000000000..3fa4c9fcc --- /dev/null +++ b/ballista/core/src/execution_plans/sort_shuffle/reader.rs @@ -0,0 +1,315 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Reader for sort-based shuffle output files. +//! +//! Reads partition data from the consolidated data file using the index +//! file to locate partition boundaries. Uses Arrow IPC FileReader for +//! efficient random access to specific batches. + +use crate::error::{BallistaError, Result}; +use datafusion::arrow::datatypes::SchemaRef; +use datafusion::arrow::ipc::reader::FileReader; +use datafusion::arrow::record_batch::RecordBatch; +use datafusion::common::DataFusionError; +use datafusion::physical_plan::SendableRecordBatchStream; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; +use std::fs::File; +use std::path::Path; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use super::index::ShuffleIndex; + +/// Checks if a shuffle output uses the sort-based format by looking for +/// the index file. +pub fn is_sort_shuffle_output(data_path: &Path) -> bool { + let index_path = data_path.with_extension("arrow.index"); + index_path.exists() +} + +/// Gets the index file path for a data file. +pub fn get_index_path(data_path: &Path) -> std::path::PathBuf { + data_path.with_extension("arrow.index") +} + +/// Reads all batches for a specific partition from a sort shuffle data file. +/// +/// Uses Arrow IPC FileReader for efficient random access - directly reads +/// only the batches belonging to the requested partition without scanning +/// through preceding data. +/// +/// # Arguments +/// * `data_path` - Path to the consolidated data file (Arrow IPC File format) +/// * `index_path` - Path to the index file +/// * `partition_id` - The partition to read +/// +/// # Returns +/// Vector of record batches for the requested partition. +pub fn read_sort_shuffle_partition( + data_path: &Path, + index_path: &Path, + partition_id: usize, +) -> Result> { + // Load the index + let index = ShuffleIndex::read_from_file(index_path)?; + + if partition_id >= index.partition_count() { + return Err(BallistaError::General(format!( + "Partition {partition_id} not found in index (max: {})", + index.partition_count() + ))); + } + + // Check if partition has data + if !index.partition_has_data(partition_id) { + return Ok(Vec::new()); + } + + // Get the batch range for this partition from the index + // The index stores cumulative batch counts: + // - offset[i] = starting batch index for partition i + // - offset[i+1] (or total_length for last partition) = ending batch index (exclusive) + let (start_batch, end_batch) = index.get_partition_range(partition_id); + let start_batch = start_batch as usize; + let end_batch = end_batch as usize; + + // Open the data file with FileReader for random access + let file = File::open(data_path).map_err(BallistaError::IoError)?; + let mut reader = FileReader::try_new(file, None)?; + + let mut batches = Vec::with_capacity(end_batch - start_batch); + + // Use FileReader's set_index() for random access to specific batches + // This positions the reader directly at the starting batch index + reader.set_index(start_batch)?; + + // Read only the batches we need for this partition + for _ in start_batch..end_batch { + match reader.next() { + Some(Ok(batch)) => batches.push(batch), + Some(Err(e)) => return Err(e.into()), + None => break, + } + } + + Ok(batches) +} + +/// A stream that reads batches from a sort shuffle partition lazily. +/// +/// Wraps an Arrow FileReader and yields batches one at a time without +/// loading them all into memory upfront. +struct SortShufflePartitionStream { + reader: FileReader, + schema: SchemaRef, + remaining: usize, +} + +impl SortShufflePartitionStream { + fn new(reader: FileReader, schema: SchemaRef, num_batches: usize) -> Self { + Self { + reader, + schema, + remaining: num_batches, + } + } +} + +impl futures::Stream for SortShufflePartitionStream { + type Item = std::result::Result; + + fn poll_next( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + if self.remaining == 0 { + return Poll::Ready(None); + } + + match self.reader.next() { + Some(Ok(batch)) => { + self.remaining -= 1; + Poll::Ready(Some(Ok(batch))) + } + Some(Err(e)) => { + self.remaining = 0; + Poll::Ready(Some(Err(DataFusionError::ArrowError(Box::new(e), None)))) + } + None => { + self.remaining = 0; + Poll::Ready(None) + } + } + } +} + +impl datafusion::physical_plan::RecordBatchStream for SortShufflePartitionStream { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +/// Returns a stream of record batches for a specific partition from a sort shuffle data file. +/// +/// Unlike `read_sort_shuffle_partition`, this returns a lazy stream that reads batches +/// on-demand rather than loading all batches into memory upfront. +/// +/// # Arguments +/// * `data_path` - Path to the consolidated data file (Arrow IPC File format) +/// * `index_path` - Path to the index file +/// * `partition_id` - The partition to read +/// +/// # Returns +/// A stream of record batches for the requested partition. +pub fn stream_sort_shuffle_partition( + data_path: &Path, + index_path: &Path, + partition_id: usize, +) -> Result { + // Load the index + let index = ShuffleIndex::read_from_file(index_path)?; + + if partition_id >= index.partition_count() { + return Err(BallistaError::General(format!( + "Partition {partition_id} not found in index (max: {})", + index.partition_count() + ))); + } + + // Open the data file to get the schema + let file = File::open(data_path).map_err(BallistaError::IoError)?; + let reader = FileReader::try_new(file, None)?; + let schema = reader.schema(); + + // Check if partition has data + if !index.partition_has_data(partition_id) { + // Return empty stream with the schema + let empty_stream = futures::stream::empty(); + return Ok(Box::pin(RecordBatchStreamAdapter::new( + schema, + empty_stream, + ))); + } + + // Get the batch range for this partition + let (start_batch, end_batch) = index.get_partition_range(partition_id); + let start_batch = start_batch as usize; + let end_batch = end_batch as usize; + let num_batches = end_batch - start_batch; + + // Re-open and position the reader at the start batch + let file = File::open(data_path).map_err(BallistaError::IoError)?; + let mut reader = FileReader::try_new(file, None)?; + reader.set_index(start_batch)?; + + Ok(Box::pin(SortShufflePartitionStream::new( + reader, + schema, + num_batches, + ))) +} + +/// Reads all batches from a sort shuffle data file. +/// +/// # Arguments +/// * `data_path` - Path to the consolidated data file (Arrow IPC File format) +/// +/// # Returns +/// Vector of all record batches in the file. +pub fn read_all_batches(data_path: &Path) -> Result> { + let file = File::open(data_path).map_err(BallistaError::IoError)?; + let reader = FileReader::try_new(file, None)?; + + let mut batches = Vec::with_capacity(reader.num_batches()); + for batch_result in reader { + batches.push(batch_result?); + } + + Ok(batches) +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion::arrow::array::Int32Array; + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use datafusion::arrow::ipc::CompressionType; + use datafusion::arrow::ipc::writer::{FileWriter, IpcWriteOptions}; + use std::io::BufWriter; + use std::sync::Arc; + use tempfile::TempDir; + + fn create_test_schema() -> datafusion::arrow::datatypes::SchemaRef { + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])) + } + + fn create_test_batch( + schema: &datafusion::arrow::datatypes::SchemaRef, + values: Vec, + ) -> RecordBatch { + let array = Int32Array::from(values); + RecordBatch::try_new(schema.clone(), vec![Arc::new(array)]).unwrap() + } + + #[test] + fn test_is_sort_shuffle_output() { + let temp_dir = TempDir::new().unwrap(); + let data_path = temp_dir.path().join("data.arrow"); + let index_path = temp_dir.path().join("data.arrow.index"); + + // No index file + std::fs::write(&data_path, b"test").unwrap(); + assert!(!is_sort_shuffle_output(&data_path)); + + // With index file + std::fs::write(&index_path, b"test").unwrap(); + assert!(is_sort_shuffle_output(&data_path)); + } + + #[test] + fn test_read_all_batches() -> Result<()> { + let temp_dir = TempDir::new().unwrap(); + let schema = create_test_schema(); + let data_path = temp_dir.path().join("data.arrow"); + + // Write test data using FileWriter (IPC File format) + let file = File::create(&data_path).unwrap(); + let mut buffered = BufWriter::new(file); + let options = IpcWriteOptions::default() + .try_with_compression(Some(CompressionType::LZ4_FRAME)) + .unwrap(); + let mut writer = + FileWriter::try_new_with_options(&mut buffered, &schema, options).unwrap(); + + writer + .write(&create_test_batch(&schema, vec![1, 2, 3])) + .unwrap(); + writer + .write(&create_test_batch(&schema, vec![4, 5])) + .unwrap(); + writer.finish().unwrap(); + + // Read back + let batches = read_all_batches(&data_path)?; + assert_eq!(batches.len(), 2); + assert_eq!(batches[0].num_rows(), 3); + assert_eq!(batches[1].num_rows(), 2); + + Ok(()) + } +} diff --git a/ballista/core/src/execution_plans/sort_shuffle/spill.rs b/ballista/core/src/execution_plans/sort_shuffle/spill.rs new file mode 100644 index 000000000..254775a1f --- /dev/null +++ b/ballista/core/src/execution_plans/sort_shuffle/spill.rs @@ -0,0 +1,313 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Spill manager for sort-based shuffle. +//! +//! Handles writing partition buffers to disk when memory pressure is high, +//! and reading them back during the finalization phase. + +use crate::error::{BallistaError, Result}; +use datafusion::arrow::datatypes::SchemaRef; +use datafusion::arrow::ipc::reader::StreamReader; +use datafusion::arrow::ipc::writer::StreamWriter; +use datafusion::arrow::ipc::{CompressionType, writer::IpcWriteOptions}; +use datafusion::arrow::record_batch::RecordBatch; +use log::debug; +use std::collections::HashMap; +use std::fs::File; +use std::io::BufWriter; +use std::path::PathBuf; + +/// Manages spill files for sort-based shuffle. +/// +/// When partition buffers exceed memory limits, they are spilled to disk +/// as Arrow IPC files. During finalization, these spill files are read +/// back and merged into the consolidated output file. +#[derive(Debug)] +pub struct SpillManager { + /// Base directory for spill files + spill_dir: PathBuf, + /// Spill files per output partition: partition_id -> Vec + spill_files: HashMap>, + /// Counter for generating unique spill file names + spill_counter: usize, + /// Compression codec for spill files + compression: CompressionType, + /// Total number of spills performed + total_spills: usize, + /// Total bytes spilled + total_bytes_spilled: u64, +} + +impl SpillManager { + /// Creates a new spill manager. + /// + /// # Arguments + /// * `work_dir` - Base work directory + /// * `job_id` - Job identifier + /// * `stage_id` - Stage identifier + /// * `input_partition` - Input partition number + /// * `compression` - Compression codec for spill files + pub fn new( + work_dir: &str, + job_id: &str, + stage_id: usize, + input_partition: usize, + compression: CompressionType, + ) -> Result { + let mut spill_dir = PathBuf::from(work_dir); + spill_dir.push(job_id); + spill_dir.push(format!("{stage_id}")); + spill_dir.push(format!("{input_partition}")); + spill_dir.push("spill"); + + // Create spill directory + std::fs::create_dir_all(&spill_dir).map_err(BallistaError::IoError)?; + + Ok(Self { + spill_dir, + spill_files: HashMap::new(), + spill_counter: 0, + compression, + total_spills: 0, + total_bytes_spilled: 0, + }) + } + + /// Spills batches for a partition to disk. + /// + /// Returns the number of bytes written. + pub fn spill( + &mut self, + partition_id: usize, + batches: Vec, + schema: &SchemaRef, + ) -> Result { + if batches.is_empty() { + return Ok(0); + } + + let spill_path = self.next_spill_path(partition_id); + debug!( + "Spilling {} batches for partition {} to {:?}", + batches.len(), + partition_id, + spill_path + ); + + let file = File::create(&spill_path).map_err(BallistaError::IoError)?; + let buffered = BufWriter::new(file); + + let options = + IpcWriteOptions::default().try_with_compression(Some(self.compression))?; + + let mut writer = StreamWriter::try_new_with_options(buffered, schema, options)?; + + for batch in &batches { + writer.write(batch)?; + } + + writer.finish()?; + + let bytes_written = std::fs::metadata(&spill_path) + .map_err(BallistaError::IoError)? + .len(); + + // Track the spill file + self.spill_files + .entry(partition_id) + .or_default() + .push(spill_path); + + self.total_spills += 1; + self.total_bytes_spilled += bytes_written; + + Ok(bytes_written) + } + + /// Returns the spill files for a partition. + pub fn get_spill_files(&self, partition_id: usize) -> &[PathBuf] { + self.spill_files + .get(&partition_id) + .map(|v| v.as_slice()) + .unwrap_or(&[]) + } + + /// Returns true if the partition has spill files. + pub fn has_spill_files(&self, partition_id: usize) -> bool { + self.spill_files + .get(&partition_id) + .is_some_and(|v| !v.is_empty()) + } + + /// Reads all spill files for a partition and returns the batches. + pub fn read_spill_files(&self, partition_id: usize) -> Result> { + let mut all_batches = Vec::new(); + + for spill_path in self.get_spill_files(partition_id) { + let file = File::open(spill_path).map_err(BallistaError::IoError)?; + let reader = StreamReader::try_new(file, None)?; + + for batch_result in reader { + all_batches.push(batch_result?); + } + } + + Ok(all_batches) + } + + /// Cleans up all spill files. + pub fn cleanup(&self) -> Result<()> { + if self.spill_dir.exists() { + std::fs::remove_dir_all(&self.spill_dir).map_err(BallistaError::IoError)?; + } + Ok(()) + } + + /// Returns the total number of spills performed. + pub fn total_spills(&self) -> usize { + self.total_spills + } + + /// Returns the total bytes spilled to disk. + pub fn total_bytes_spilled(&self) -> u64 { + self.total_bytes_spilled + } + + /// Generates the next spill file path for a partition. + fn next_spill_path(&mut self, partition_id: usize) -> PathBuf { + let path = self.spill_dir.join(format!( + "part-{partition_id}-spill-{}.arrow", + self.spill_counter + )); + self.spill_counter += 1; + path + } +} + +impl Drop for SpillManager { + fn drop(&mut self) { + // Best-effort cleanup on drop + if let Err(e) = self.cleanup() { + debug!("Failed to cleanup spill files: {e:?}"); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion::arrow::array::Int32Array; + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use std::sync::Arc; + use tempfile::TempDir; + + fn create_test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])) + } + + fn create_test_batch(schema: &SchemaRef, values: Vec) -> RecordBatch { + let array = Int32Array::from(values); + RecordBatch::try_new(schema.clone(), vec![Arc::new(array)]).unwrap() + } + + #[test] + fn test_spill_and_read() -> Result<()> { + let temp_dir = TempDir::new().unwrap(); + let schema = create_test_schema(); + + let mut manager = SpillManager::new( + temp_dir.path().to_str().unwrap(), + "job1", + 1, + 0, + CompressionType::LZ4_FRAME, + )?; + + // Spill some batches + let batches = vec![ + create_test_batch(&schema, vec![1, 2, 3]), + create_test_batch(&schema, vec![4, 5]), + ]; + let bytes = manager.spill(0, batches, &schema)?; + assert!(bytes > 0); + + // Verify spill tracking + assert!(manager.has_spill_files(0)); + assert!(!manager.has_spill_files(1)); + assert_eq!(manager.total_spills(), 1); + + // Read back + let read_batches = manager.read_spill_files(0)?; + assert_eq!(read_batches.len(), 2); + assert_eq!(read_batches[0].num_rows(), 3); + assert_eq!(read_batches[1].num_rows(), 2); + + Ok(()) + } + + #[test] + fn test_multiple_spills() -> Result<()> { + let temp_dir = TempDir::new().unwrap(); + let schema = create_test_schema(); + + let mut manager = SpillManager::new( + temp_dir.path().to_str().unwrap(), + "job1", + 1, + 0, + CompressionType::LZ4_FRAME, + )?; + + // Multiple spills for same partition + manager.spill(0, vec![create_test_batch(&schema, vec![1, 2])], &schema)?; + manager.spill(0, vec![create_test_batch(&schema, vec![3, 4])], &schema)?; + + assert_eq!(manager.get_spill_files(0).len(), 2); + assert_eq!(manager.total_spills(), 2); + + // Read all back + let batches = manager.read_spill_files(0)?; + assert_eq!(batches.len(), 2); + + Ok(()) + } + + #[test] + fn test_cleanup() -> Result<()> { + let temp_dir = TempDir::new().unwrap(); + let schema = create_test_schema(); + + let mut manager = SpillManager::new( + temp_dir.path().to_str().unwrap(), + "job1", + 1, + 0, + CompressionType::LZ4_FRAME, + )?; + + manager.spill(0, vec![create_test_batch(&schema, vec![1, 2])], &schema)?; + + let spill_dir = manager.spill_dir.clone(); + assert!(spill_dir.exists()); + + manager.cleanup()?; + assert!(!spill_dir.exists()); + + Ok(()) + } +} diff --git a/ballista/core/src/execution_plans/sort_shuffle/writer.rs b/ballista/core/src/execution_plans/sort_shuffle/writer.rs new file mode 100644 index 000000000..debdad15c --- /dev/null +++ b/ballista/core/src/execution_plans/sort_shuffle/writer.rs @@ -0,0 +1,699 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Sort-based shuffle writer execution plan. +//! +//! This execution plan writes shuffle output as a single consolidated file +//! per input partition, along with an index file mapping partition IDs to +//! byte offsets. + +use std::any::Any; +use std::fs::File; +use std::future::Future; +use std::io::BufWriter; +use std::path::PathBuf; +use std::sync::Arc; +use std::time::Instant; + +use super::super::shuffle_writer_trait::ShuffleWriter; +use super::buffer::PartitionBuffer; +use super::config::SortShuffleConfig; +use super::index::ShuffleIndex; +use super::spill::SpillManager; +use crate::serde::protobuf::ShuffleWritePartition; + +use datafusion::arrow::array::{ + ArrayBuilder, ArrayRef, StringBuilder, StructBuilder, UInt32Builder, UInt64Builder, +}; +use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion::arrow::error::ArrowError; +use datafusion::arrow::ipc::writer::{FileWriter, IpcWriteOptions}; +use datafusion::arrow::record_batch::RecordBatch; +use datafusion::error::{DataFusionError, Result}; +use datafusion::execution::context::TaskContext; +use datafusion::physical_plan::memory::MemoryStream; +use datafusion::physical_plan::metrics::{ + self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet, +}; +use datafusion::physical_plan::repartition::BatchPartitioner; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; +use datafusion::physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, + SendableRecordBatchStream, Statistics, +}; +use futures::{StreamExt, TryFutureExt, TryStreamExt}; +use log::{debug, info}; + +use crate::serde::scheduler::PartitionStats; + +/// Result of finalizing shuffle output: (data_path, index_path, partition_write_stats) +/// where partition_write_stats is (partition_id, num_batches, num_rows, num_bytes) +type FinalizeResult = (PathBuf, PathBuf, Vec<(usize, u64, u64, u64)>); + +/// Sort-based shuffle writer that produces a single consolidated output file +/// per input partition with an index file for partition offsets. +#[derive(Debug, Clone)] +pub struct SortShuffleWriterExec { + /// Unique ID for the job (query) that this stage is a part of + job_id: String, + /// Unique query stage ID within the job + stage_id: usize, + /// Physical execution plan for this query stage + plan: Arc, + /// Path to write output streams to + work_dir: String, + /// Shuffle output partitioning (must be Hash partitioning) + shuffle_output_partitioning: Partitioning, + /// Sort shuffle configuration + config: SortShuffleConfig, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, + /// Plan properties + properties: PlanProperties, +} + +#[derive(Debug, Clone)] +struct SortShuffleWriteMetrics { + /// Time spent writing batches to the output file + write_time: metrics::Time, + /// Time spent partitioning input batches + repart_time: metrics::Time, + /// Time spent spilling to disk + spill_time: metrics::Time, + /// Number of input rows + input_rows: metrics::Count, + /// Number of output rows + output_rows: metrics::Count, + /// Number of spills + spill_count: metrics::Count, + /// Bytes spilled to disk + spill_bytes: metrics::Count, +} + +impl SortShuffleWriteMetrics { + fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self { + Self { + write_time: MetricBuilder::new(metrics).subset_time("write_time", partition), + repart_time: MetricBuilder::new(metrics) + .subset_time("repart_time", partition), + spill_time: MetricBuilder::new(metrics).subset_time("spill_time", partition), + input_rows: MetricBuilder::new(metrics).counter("input_rows", partition), + output_rows: MetricBuilder::new(metrics).output_rows(partition), + spill_count: MetricBuilder::new(metrics).counter("spill_count", partition), + spill_bytes: MetricBuilder::new(metrics).counter("spill_bytes", partition), + } + } +} + +impl SortShuffleWriterExec { + /// Create a new sort-based shuffle writer. + pub fn try_new( + job_id: String, + stage_id: usize, + plan: Arc, + work_dir: String, + shuffle_output_partitioning: Partitioning, + config: SortShuffleConfig, + ) -> Result { + // Sort shuffle only supports hash partitioning + match &shuffle_output_partitioning { + Partitioning::Hash(_, _) => {} + other => { + return Err(DataFusionError::Plan(format!( + "SortShuffleWriterExec only supports Hash partitioning, got: {other:?}" + ))); + } + } + + let properties = PlanProperties::new( + datafusion::physical_expr::EquivalenceProperties::new(plan.schema()), + shuffle_output_partitioning.clone(), + datafusion::physical_plan::execution_plan::EmissionType::Incremental, + datafusion::physical_plan::execution_plan::Boundedness::Bounded, + ); + + Ok(Self { + job_id, + stage_id, + plan, + work_dir, + shuffle_output_partitioning, + config, + metrics: ExecutionPlanMetricsSet::new(), + properties, + }) + } + + /// Get the Job ID for this query stage + pub fn job_id(&self) -> &str { + &self.job_id + } + + /// Get the Stage ID for this query stage + pub fn stage_id(&self) -> usize { + self.stage_id + } + + /// Get the shuffle output partitioning + pub fn shuffle_output_partitioning(&self) -> &Partitioning { + &self.shuffle_output_partitioning + } + + /// Get the sort shuffle configuration + pub fn config(&self) -> &SortShuffleConfig { + &self.config + } + + /// Get the input partition count + pub fn input_partition_count(&self) -> usize { + self.plan + .properties() + .output_partitioning() + .partition_count() + } + + /// Execute the sort-based shuffle write for a single input partition. + pub fn execute_shuffle_write( + self, + input_partition: usize, + context: Arc, + ) -> impl Future>> { + let metrics = SortShuffleWriteMetrics::new(input_partition, &self.metrics); + let config = self.config.clone(); + let plan = self.plan.clone(); + let work_dir = self.work_dir.clone(); + let job_id = self.job_id.clone(); + let stage_id = self.stage_id; + let partitioning = self.shuffle_output_partitioning.clone(); + + async move { + let now = Instant::now(); + let mut stream = plan.execute(input_partition, context)?; + let schema = stream.schema(); + + let Partitioning::Hash(exprs, num_output_partitions) = partitioning else { + return Err(DataFusionError::Internal( + "Expected hash partitioning".to_string(), + )); + }; + + // Create partition buffers + let mut buffers: Vec = (0..num_output_partitions) + .map(|i| PartitionBuffer::new(i, schema.clone())) + .collect(); + + // Create spill manager + let mut spill_manager = SpillManager::new( + &work_dir, + &job_id, + stage_id, + input_partition, + config.compression, + ) + .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?; + + // Create batch partitioner + let mut partitioner = BatchPartitioner::try_new( + Partitioning::Hash(exprs, num_output_partitions), + metrics.repart_time.clone(), + )?; + + // Process input stream + while let Some(result) = stream.next().await { + let input_batch = result?; + metrics.input_rows.add(input_batch.num_rows()); + + // Partition the batch + partitioner.partition( + input_batch, + |output_partition, output_batch| { + buffers[output_partition].append(output_batch); + Ok(()) + }, + )?; + + // Check if we need to spill + let total_memory: usize = buffers.iter().map(|b| b.memory_used()).sum(); + if total_memory > config.spill_memory_threshold() { + let timer = metrics.spill_time.timer(); + spill_largest_buffers( + &mut buffers, + &mut spill_manager, + &schema, + config.spill_memory_threshold() / 2, + )?; + timer.done(); + } + } + + // Finalize: write consolidated output file + let timer = metrics.write_time.timer(); + let (data_path, index_path, partition_stats) = finalize_output( + &work_dir, + &job_id, + stage_id, + input_partition, + &mut buffers, + &mut spill_manager, + &schema, + &config, + )?; + timer.done(); + + // Update metrics + metrics.spill_count.add(spill_manager.total_spills()); + metrics + .spill_bytes + .add(spill_manager.total_bytes_spilled() as usize); + + let total_rows: u64 = partition_stats.iter().map(|(_, _, r, _)| *r).sum(); + metrics.output_rows.add(total_rows as usize); + + // Cleanup spill files + spill_manager + .cleanup() + .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?; + + info!( + "Sort shuffle write for partition {} completed in {} seconds. \ + Output: {:?}, Index: {:?}, Spills: {}, Spill bytes: {}", + input_partition, + now.elapsed().as_secs(), + data_path, + index_path, + spill_manager.total_spills(), + spill_manager.total_bytes_spilled() + ); + + // Build result - one entry per output partition that has data + let mut results = Vec::new(); + for (part_id, num_batches, num_rows, num_bytes) in partition_stats { + if num_rows > 0 { + results.push(ShuffleWritePartition { + partition_id: part_id as u64, + path: data_path.to_string_lossy().to_string(), + num_batches, + num_rows, + num_bytes, + }); + } + } + + Ok(results) + } + } +} + +/// Spills the largest buffers until total memory is below the target. +fn spill_largest_buffers( + buffers: &mut [PartitionBuffer], + spill_manager: &mut SpillManager, + schema: &SchemaRef, + target_memory: usize, +) -> Result<()> { + loop { + let total_memory: usize = buffers.iter().map(|b| b.memory_used()).sum(); + if total_memory <= target_memory { + break; + } + + // Find the largest buffer + let largest_idx = buffers + .iter() + .enumerate() + .max_by_key(|(_, b)| b.memory_used()) + .map(|(i, _)| i); + + match largest_idx { + Some(idx) if buffers[idx].memory_used() > 0 => { + let partition_id = buffers[idx].partition_id(); + let batches = buffers[idx].drain(); + spill_manager + .spill(partition_id, batches, schema) + .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?; + } + _ => break, // No more buffers to spill + } + } + Ok(()) +} + +/// Finalizes the output by writing the consolidated data file and index file. +/// +/// Returns (data_path, index_path, partition_stats) where partition_stats is +/// a vector of (partition_id, num_batches, num_rows, num_bytes) tuples. +#[allow(clippy::too_many_arguments)] +fn finalize_output( + work_dir: &str, + job_id: &str, + stage_id: usize, + input_partition: usize, + buffers: &mut [PartitionBuffer], + spill_manager: &mut SpillManager, + schema: &SchemaRef, + config: &SortShuffleConfig, +) -> Result { + let num_partitions = buffers.len(); + let mut index = ShuffleIndex::new(num_partitions); + let mut partition_stats = Vec::with_capacity(num_partitions); + + // Create output directory + let mut output_dir = PathBuf::from(work_dir); + output_dir.push(job_id); + output_dir.push(format!("{stage_id}")); + output_dir.push(format!("{input_partition}")); + std::fs::create_dir_all(&output_dir)?; + + let data_path = output_dir.join("data.arrow"); + let index_path = output_dir.join("data.arrow.index"); + + debug!("Writing consolidated shuffle output to {:?}", data_path); + + // Use FileWriter for random access support via FileReader + let file = File::create(&data_path)?; + let mut buffered = BufWriter::new(file); + + let options = + IpcWriteOptions::default().try_with_compression(Some(config.compression))?; + let mut writer = FileWriter::try_new_with_options(&mut buffered, schema, options)?; + + // Track cumulative batch counts - index stores the starting batch index for each partition + // FileReader supports random access to batches by index + let mut cumulative_batch_count: i64 = 0; + + // Write partitions in order + for (partition_id, buffer) in buffers.iter_mut().enumerate() { + // Set the starting batch index for this partition + index.set_offset(partition_id, cumulative_batch_count); + + let mut partition_rows: u64 = 0; + let mut partition_batches: u64 = 0; + let mut partition_bytes: u64 = 0; + + // First, write any spill files for this partition + if spill_manager.has_spill_files(partition_id) { + let spill_batches = spill_manager + .read_spill_files(partition_id) + .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?; + + for batch in spill_batches { + partition_rows += batch.num_rows() as u64; + partition_bytes += batch.get_array_memory_size() as u64; + partition_batches += 1; + writer.write(&batch)?; + } + } + + // Then write remaining buffered data + let buffered_batches = buffer.take_batches(); + for batch in buffered_batches { + partition_rows += batch.num_rows() as u64; + partition_bytes += batch.get_array_memory_size() as u64; + partition_batches += 1; + writer.write(&batch)?; + } + + partition_stats.push(( + partition_id, + partition_batches, + partition_rows, + partition_bytes, + )); + + cumulative_batch_count += partition_batches as i64; + } + + // Finish writing (this writes the IPC footer for random access) + writer.finish()?; + + // Store total batch count + index.set_total_length(cumulative_batch_count); + + // Write index file + index + .write_to_file(&index_path) + .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?; + + Ok((data_path, index_path, partition_stats)) +} + +impl DisplayAs for SortShuffleWriterExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!( + f, + "SortShuffleWriterExec: partitioning={}", + self.shuffle_output_partitioning + ) + } + DisplayFormatType::TreeRender => { + write!(f, "partitioning={}", self.shuffle_output_partitioning) + } + } + } +} + +impl ExecutionPlan for SortShuffleWriterExec { + fn name(&self) -> &str { + "SortShuffleWriterExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.plan.schema() + } + + fn properties(&self) -> &PlanProperties { + &self.properties + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.plan] + } + + fn with_new_children( + self: Arc, + mut children: Vec>, + ) -> Result> { + if children.len() == 1 { + let input = children.pop().ok_or_else(|| { + DataFusionError::Plan( + "SortShuffleWriterExec expects single child".to_owned(), + ) + })?; + + Ok(Arc::new(SortShuffleWriterExec::try_new( + self.job_id.clone(), + self.stage_id, + input, + self.work_dir.clone(), + self.shuffle_output_partitioning.clone(), + self.config.clone(), + )?)) + } else { + Err(DataFusionError::Plan( + "SortShuffleWriterExec expects single child".to_owned(), + )) + } + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let schema = result_schema(); + + let schema_captured = schema.clone(); + let fut_stream = self + .clone() + .execute_shuffle_write(partition, context) + .and_then(|part_loc| async move { + // Build metadata result batch + let num_writers = part_loc.len(); + let mut partition_builder = UInt32Builder::with_capacity(num_writers); + let mut path_builder = + StringBuilder::with_capacity(num_writers, num_writers * 100); + let mut num_rows_builder = UInt64Builder::with_capacity(num_writers); + let mut num_batches_builder = UInt64Builder::with_capacity(num_writers); + let mut num_bytes_builder = UInt64Builder::with_capacity(num_writers); + + for loc in &part_loc { + path_builder.append_value(loc.path.clone()); + partition_builder.append_value(loc.partition_id as u32); + num_rows_builder.append_value(loc.num_rows); + num_batches_builder.append_value(loc.num_batches); + num_bytes_builder.append_value(loc.num_bytes); + } + + // Build arrays + let partition_num: ArrayRef = Arc::new(partition_builder.finish()); + let path: ArrayRef = Arc::new(path_builder.finish()); + let field_builders: Vec> = vec![ + Box::new(num_rows_builder), + Box::new(num_batches_builder), + Box::new(num_bytes_builder), + ]; + let mut stats_builder = StructBuilder::new( + PartitionStats::default().arrow_struct_fields(), + field_builders, + ); + for _ in 0..num_writers { + stats_builder.append(true); + } + let stats = Arc::new(stats_builder.finish()); + + // Build result batch containing metadata + let batch = RecordBatch::try_new( + schema_captured.clone(), + vec![partition_num, path, stats], + )?; + + debug!("SORT SHUFFLE RESULTS METADATA:\n{batch:?}"); + + MemoryStream::try_new(vec![batch], schema_captured, None) + }) + .map_err(|e| ArrowError::ExternalError(Box::new(e))); + + Ok(Box::pin(RecordBatchStreamAdapter::new( + schema, + futures::stream::once(fut_stream).try_flatten(), + ))) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn partition_statistics(&self, partition: Option) -> Result { + self.plan.partition_statistics(partition) + } +} + +impl ShuffleWriter for SortShuffleWriterExec { + fn job_id(&self) -> &str { + &self.job_id + } + + fn stage_id(&self) -> usize { + self.stage_id + } + + fn shuffle_output_partitioning(&self) -> Option<&Partitioning> { + Some(&self.shuffle_output_partitioning) + } + + fn input_partition_count(&self) -> usize { + self.plan + .properties() + .output_partitioning() + .partition_count() + } + + fn clone_box(&self) -> Arc { + Arc::new(self.clone()) + } +} + +fn result_schema() -> SchemaRef { + let stats = PartitionStats::default(); + Arc::new(Schema::new(vec![ + Field::new("partition", DataType::UInt32, false), + Field::new("path", DataType::Utf8, false), + stats.arrow_struct_repr(), + ])) +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion::arrow::array::{StringArray, UInt32Array}; + use datafusion::datasource::memory::MemorySourceConfig; + use datafusion::datasource::source::DataSourceExec; + use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; + use datafusion::physical_plan::expressions::Column; + use datafusion::prelude::SessionContext; + use tempfile::TempDir; + + fn create_test_input() -> Result> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::UInt32, true), + Field::new("b", DataType::Utf8, true), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt32Array::from(vec![Some(1), Some(3)])), + Arc::new(StringArray::from(vec![Some("hello"), Some("world")])), + ], + )?; + let partition = vec![batch.clone(), batch]; + let partitions = vec![partition.clone(), partition]; + + let memory_data_source = + Arc::new(MemorySourceConfig::try_new(&partitions, schema, None)?); + + Ok(Arc::new(DataSourceExec::new(memory_data_source))) + } + + #[tokio::test] + async fn test_sort_shuffle_writer() -> Result<()> { + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); + + let input_plan = Arc::new(CoalescePartitionsExec::new(create_test_input()?)); + let work_dir = TempDir::new()?; + + let config = SortShuffleConfig::default(); + + let writer = SortShuffleWriterExec::try_new( + "job1".to_string(), + 1, + input_plan, + work_dir.path().to_str().unwrap().to_string(), + Partitioning::Hash(vec![Arc::new(Column::new("a", 0))], 2), + config, + )?; + + let mut stream = writer.execute(0, task_ctx)?; + let batches: Vec = stream + .by_ref() + .try_collect() + .await + .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?; + + assert_eq!(batches.len(), 1); + let batch = &batches[0]; + assert_eq!(batch.num_columns(), 3); + + // Verify output files exist + let output_dir = work_dir.path().join("job1").join("1").join("0"); + assert!(output_dir.join("data.arrow").exists()); + assert!(output_dir.join("data.arrow.index").exists()); + + Ok(()) + } +} diff --git a/ballista/core/src/serde/generated/ballista.rs b/ballista/core/src/serde/generated/ballista.rs index aeccb7d3d..910b58f2d 100644 --- a/ballista/core/src/serde/generated/ballista.rs +++ b/ballista/core/src/serde/generated/ballista.rs @@ -4,7 +4,10 @@ /// ///////////////////////////////////////////////////////////////////////////////////////////////// #[derive(Clone, PartialEq, ::prost::Message)] pub struct BallistaPhysicalPlanNode { - #[prost(oneof = "ballista_physical_plan_node::PhysicalPlanType", tags = "1, 2, 3")] + #[prost( + oneof = "ballista_physical_plan_node::PhysicalPlanType", + tags = "1, 2, 3, 4" + )] pub physical_plan_type: ::core::option::Option< ballista_physical_plan_node::PhysicalPlanType, >, @@ -19,6 +22,8 @@ pub mod ballista_physical_plan_node { ShuffleReader(super::ShuffleReaderExecNode), #[prost(message, tag = "3")] UnresolvedShuffle(super::UnresolvedShuffleExecNode), + #[prost(message, tag = "4")] + SortShuffleWriter(super::SortShuffleWriterExecNode), } } #[derive(Clone, PartialEq, ::prost::Message)] @@ -36,6 +41,27 @@ pub struct ShuffleWriterExecNode { ::datafusion_proto::protobuf::PhysicalHashRepartition, >, } +/// Sort-based shuffle writer that produces consolidated files with index +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct SortShuffleWriterExecNode { + #[prost(string, tag = "1")] + pub job_id: ::prost::alloc::string::String, + #[prost(uint32, tag = "2")] + pub stage_id: u32, + #[prost(message, optional, tag = "3")] + pub input: ::core::option::Option<::datafusion_proto::protobuf::PhysicalPlanNode>, + #[prost(message, optional, tag = "4")] + pub output_partitioning: ::core::option::Option< + ::datafusion_proto::protobuf::PhysicalHashRepartition, + >, + /// Configuration for sort shuffle + #[prost(uint64, tag = "5")] + pub buffer_size: u64, + #[prost(uint64, tag = "6")] + pub memory_limit: u64, + #[prost(double, tag = "7")] + pub spill_threshold: f64, +} #[derive(Clone, PartialEq, ::prost::Message)] pub struct UnresolvedShuffleExecNode { #[prost(uint32, tag = "1")] diff --git a/ballista/core/src/serde/mod.rs b/ballista/core/src/serde/mod.rs index 235603e41..13c4a53f4 100644 --- a/ballista/core/src/serde/mod.rs +++ b/ballista/core/src/serde/mod.rs @@ -47,8 +47,9 @@ use std::marker::PhantomData; use std::sync::Arc; use std::{convert::TryInto, io::Cursor}; +use crate::execution_plans::sort_shuffle::SortShuffleConfig; use crate::execution_plans::{ - ShuffleReaderExec, ShuffleWriterExec, UnresolvedShuffleExec, + ShuffleReaderExec, ShuffleWriterExec, SortShuffleWriterExec, UnresolvedShuffleExec, }; use crate::serde::protobuf::ballista_physical_plan_node::PhysicalPlanType; use crate::serde::scheduler::PartitionLocation; @@ -310,6 +311,39 @@ impl PhysicalExtensionCodec for BallistaPhysicalExtensionCodec { shuffle_output_partitioning, )?)) } + PhysicalPlanType::SortShuffleWriter(sort_shuffle_writer) => { + let input = inputs[0].clone(); + + let shuffle_output_partitioning = parse_protobuf_hash_partitioning( + sort_shuffle_writer.output_partitioning.as_ref(), + ctx, + input.schema().as_ref(), + self.default_codec.as_ref(), + )?; + + let partitioning = shuffle_output_partitioning.ok_or_else(|| { + DataFusionError::Internal( + "SortShuffleWriterExec requires hash partitioning".to_string(), + ) + })?; + + let config = SortShuffleConfig::new( + true, + sort_shuffle_writer.buffer_size as usize, + sort_shuffle_writer.memory_limit as usize, + sort_shuffle_writer.spill_threshold, + datafusion::arrow::ipc::CompressionType::LZ4_FRAME, + ); + + Ok(Arc::new(SortShuffleWriterExec::try_new( + sort_shuffle_writer.job_id.clone(), + sort_shuffle_writer.stage_id as usize, + input, + "".to_string(), // executor will fill this in + partitioning, + config, + )?)) + } PhysicalPlanType::ShuffleReader(shuffle_reader) => { let stage_id = shuffle_reader.stage_id as usize; let schema: SchemaRef = @@ -409,6 +443,51 @@ impl PhysicalExtensionCodec for BallistaPhysicalExtensionCodec { )) })?; + Ok(()) + } else if let Some(exec) = node.as_any().downcast_ref::() { + let output_partitioning = match exec.shuffle_output_partitioning() { + Partitioning::Hash(exprs, partition_count) => { + Some(datafusion_proto::protobuf::PhysicalHashRepartition { + hash_expr: exprs + .iter() + .map(|expr| { + datafusion_proto::physical_plan::to_proto::serialize_physical_expr( + &expr.clone(), + self.default_codec.as_ref(), + ) + }) + .collect::, DataFusionError>>()?, + partition_count: *partition_count as u64, + }) + } + other => { + return Err(DataFusionError::Internal(format!( + "SortShuffleWriterExec requires Hash partitioning, got: {other:?}" + ))); + } + }; + + let config = exec.config(); + let proto = protobuf::BallistaPhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::SortShuffleWriter( + protobuf::SortShuffleWriterExecNode { + job_id: exec.job_id().to_string(), + stage_id: exec.stage_id() as u32, + input: None, + output_partitioning, + buffer_size: config.buffer_size as u64, + memory_limit: config.memory_limit as u64, + spill_threshold: config.spill_threshold, + }, + )), + }; + + proto.encode(buf).map_err(|e| { + DataFusionError::Internal(format!( + "failed to encode sort shuffle writer execution plan: {e:?}" + )) + })?; + Ok(()) } else if let Some(exec) = node.as_any().downcast_ref::() { let stage_id = exec.stage_id as u32; diff --git a/ballista/executor/src/execution_engine.rs b/ballista/executor/src/execution_engine.rs index 45732828d..2417e1d8a 100644 --- a/ballista/executor/src/execution_engine.rs +++ b/ballista/executor/src/execution_engine.rs @@ -23,6 +23,7 @@ use async_trait::async_trait; use ballista_core::execution_plans::ShuffleWriterExec; +use ballista_core::execution_plans::sort_shuffle::SortShuffleWriterExec; use ballista_core::serde::protobuf::ShuffleWritePartition; use ballista_core::utils; use datafusion::error::{DataFusionError, Result}; @@ -88,61 +89,102 @@ impl ExecutionEngine for DefaultExecutionEngine { plan: Arc, work_dir: &str, ) -> Result> { - // the query plan created by the scheduler always starts with a ShuffleWriterExec - let exec = if let Some(shuffle_writer) = - plan.as_any().downcast_ref::() - { + // the query plan created by the scheduler always starts with a shuffle writer + // (either ShuffleWriterExec or SortShuffleWriterExec) + if let Some(shuffle_writer) = plan.as_any().downcast_ref::() { // recreate the shuffle writer with the correct working directory - ShuffleWriterExec::try_new( + let exec = ShuffleWriterExec::try_new( job_id, stage_id, plan.children()[0].clone(), work_dir.to_string(), shuffle_writer.shuffle_output_partitioning().cloned(), - ) + )?; + Ok(Arc::new(DefaultQueryStageExec::new( + ShuffleWriterVariant::Hash(exec), + ))) + } else if let Some(sort_shuffle_writer) = + plan.as_any().downcast_ref::() + { + // recreate the sort shuffle writer with the correct working directory + let exec = SortShuffleWriterExec::try_new( + job_id, + stage_id, + plan.children()[0].clone(), + work_dir.to_string(), + sort_shuffle_writer.shuffle_output_partitioning().clone(), + sort_shuffle_writer.config().clone(), + )?; + Ok(Arc::new(DefaultQueryStageExec::new( + ShuffleWriterVariant::Sort(exec), + ))) } else { Err(DataFusionError::Internal( - "Plan passed to new_query_stage_exec is not a ShuffleWriterExec" + "Plan passed to new_query_stage_exec is not a ShuffleWriterExec or SortShuffleWriterExec" .to_string(), )) - }?; - Ok(Arc::new(DefaultQueryStageExec::new(exec))) + } } } -/// Default query stage executor that wraps a ShuffleWriterExec. +/// Enum representing the different shuffle writer implementations. +#[derive(Debug, Clone)] +pub enum ShuffleWriterVariant { + /// Hash-based shuffle writer (original implementation). + Hash(ShuffleWriterExec), + /// Sort-based shuffle writer. + Sort(SortShuffleWriterExec), +} + +/// Default query stage executor that wraps a shuffle writer. /// -/// This executor delegates to the ShuffleWriterExec to perform the actual +/// This executor delegates to the underlying shuffle writer to perform the actual /// shuffle write operation, which partitions the data and writes it to disk. #[derive(Debug)] pub struct DefaultQueryStageExec { /// The underlying shuffle writer execution plan. - shuffle_writer: ShuffleWriterExec, + shuffle_writer: ShuffleWriterVariant, } impl DefaultQueryStageExec { /// Creates a new DefaultQueryStageExec wrapping the given shuffle writer. - pub fn new(shuffle_writer: ShuffleWriterExec) -> Self { + pub fn new(shuffle_writer: ShuffleWriterVariant) -> Self { Self { shuffle_writer } } } impl Display for DefaultQueryStageExec { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let stage_metrics: Vec = self - .shuffle_writer - .metrics() - .unwrap_or_default() - .iter() - .map(|m| m.to_string()) - .collect(); - - write!( - f, - "DefaultQueryStageExec: ({})\n{}", - stage_metrics.join(", "), - self.shuffle_writer - ) + match &self.shuffle_writer { + ShuffleWriterVariant::Hash(writer) => { + let stage_metrics: Vec = writer + .metrics() + .unwrap_or_default() + .iter() + .map(|m| m.to_string()) + .collect(); + write!( + f, + "DefaultQueryStageExec(Hash): ({})\n{}", + stage_metrics.join(", "), + writer + ) + } + ShuffleWriterVariant::Sort(writer) => { + let stage_metrics: Vec = writer + .metrics() + .unwrap_or_default() + .iter() + .map(|m| m.to_string()) + .collect(); + write!( + f, + "DefaultQueryStageExec(Sort): ({})\n{:?}", + stage_metrics.join(", "), + writer + ) + } + } } } @@ -153,13 +195,26 @@ impl QueryStageExecutor for DefaultQueryStageExec { input_partition: usize, context: Arc, ) -> Result> { - self.shuffle_writer - .clone() - .execute_shuffle_write(input_partition, context) - .await + match &self.shuffle_writer { + ShuffleWriterVariant::Hash(writer) => { + writer + .clone() + .execute_shuffle_write(input_partition, context) + .await + } + ShuffleWriterVariant::Sort(writer) => { + writer + .clone() + .execute_shuffle_write(input_partition, context) + .await + } + } } fn collect_plan_metrics(&self) -> Vec { - utils::collect_plan_metrics(&self.shuffle_writer) + match &self.shuffle_writer { + ShuffleWriterVariant::Hash(writer) => utils::collect_plan_metrics(writer), + ShuffleWriterVariant::Sort(writer) => utils::collect_plan_metrics(writer), + } } } diff --git a/ballista/executor/src/executor.rs b/ballista/executor/src/executor.rs index d37613691..c0bcdfe95 100644 --- a/ballista/executor/src/executor.rs +++ b/ballista/executor/src/executor.rs @@ -227,7 +227,7 @@ impl Executor { #[cfg(test)] mod test { - use crate::execution_engine::DefaultQueryStageExec; + use crate::execution_engine::{DefaultQueryStageExec, ShuffleWriterVariant}; use crate::executor::Executor; use ballista_core::RuntimeProducer; use ballista_core::execution_plans::ShuffleWriterExec; @@ -363,7 +363,8 @@ mod test { ) .expect("creating shuffle writer"); - let query_stage_exec = DefaultQueryStageExec::new(shuffle_write); + let query_stage_exec = + DefaultQueryStageExec::new(ShuffleWriterVariant::Hash(shuffle_write)); let executor_registration = ExecutorRegistration { id: "executor".to_string(), diff --git a/ballista/executor/src/flight_service.rs b/ballista/executor/src/flight_service.rs index a961ee6d4..0321af2c3 100644 --- a/ballista/executor/src/flight_service.rs +++ b/ballista/executor/src/flight_service.rs @@ -20,12 +20,16 @@ use datafusion::arrow::ipc::reader::StreamReader; use std::convert::TryFrom; use std::fs::File; +use std::path::Path; use std::pin::Pin; use tokio_util::io::ReaderStream; use arrow_flight::encode::FlightDataEncoderBuilder; use arrow_flight::error::FlightError; use ballista_core::error::BallistaError; +use ballista_core::execution_plans::sort_shuffle::{ + get_index_path, is_sort_shuffle_output, stream_sort_shuffle_partition, +}; use ballista_core::serde::decode_protobuf; use ballista_core::serde::scheduler::Action as BallistaAction; use datafusion::arrow::ipc::CompressionType; @@ -95,8 +99,43 @@ impl FlightService for BallistaFlightService { decode_protobuf(&ticket.ticket).map_err(|e| from_ballista_err(&e))?; match &action { - BallistaAction::FetchPartition { path, .. } => { - debug!("FetchPartition reading {path}"); + BallistaAction::FetchPartition { + path, partition_id, .. + } => { + debug!("FetchPartition reading partition {partition_id} from {path}"); + let data_path = Path::new(path); + + // Check if this is a sort-based shuffle output + if is_sort_shuffle_output(data_path) { + debug!("Detected sort-based shuffle format for {path}"); + let index_path = get_index_path(data_path); + let stream = stream_sort_shuffle_partition( + data_path, + &index_path, + *partition_id, + ) + .map_err(|e| from_ballista_err(&e))?; + + let schema = stream.schema(); + // Map DataFusionError to FlightError + let stream = + stream.map_err(|e| FlightError::from(ArrowError::from(e))); + + let write_options: IpcWriteOptions = IpcWriteOptions::default() + .try_with_compression(Some(CompressionType::LZ4_FRAME)) + .map_err(|e| from_arrow_err(&e))?; + let flight_data_stream = FlightDataEncoderBuilder::new() + .with_schema(schema) + .with_options(write_options) + .build(stream) + .map_err(|err| Status::from_error(Box::new(err))); + + return Ok(Response::new( + Box::pin(flight_data_stream) as Self::DoGetStream + )); + } + + // Standard hash-based shuffle - read the entire file let file = File::open(path) .map_err(|e| { BallistaError::General(format!( @@ -214,6 +253,19 @@ impl FlightService for BallistaFlightService { match &action { BallistaAction::FetchPartition { path, .. } => { debug!("FetchPartition reading {path}"); + let data_path = Path::new(path); + + // Block transport doesn't support sort-based shuffle because it + // transfers the entire file, which contains all partitions. + // Use flight transport (do_get) for sort-based shuffle. + if is_sort_shuffle_output(data_path) { + return Err(Status::unimplemented( + "IO_BLOCK_TRANSPORT does not support sort-based shuffle. \ + Set ballista.shuffle.remote_read_prefer_flight=true to use \ + flight transport instead.", + )); + } + let file = tokio::fs::File::open(&path).await.map_err(|e| { Status::internal(format!("Failed to open file: {e}")) })?; diff --git a/ballista/scheduler/src/planner.rs b/ballista/scheduler/src/planner.rs index d508e6b5f..63fe3d910 100644 --- a/ballista/scheduler/src/planner.rs +++ b/ballista/scheduler/src/planner.rs @@ -20,9 +20,15 @@ use std::collections::HashMap; use std::sync::Arc; +use ballista_core::config::BallistaConfig; use ballista_core::error::{BallistaError, Result}; +use ballista_core::execution_plans::ShuffleWriter; +use ballista_core::execution_plans::sort_shuffle::SortShuffleConfig; use ballista_core::{ - execution_plans::{ShuffleReaderExec, ShuffleWriterExec, UnresolvedShuffleExec}, + execution_plans::{ + ShuffleReaderExec, ShuffleWriterExec, SortShuffleWriterExec, + UnresolvedShuffleExec, + }, serde::scheduler::PartitionLocation, }; use datafusion::config::ConfigOptions; @@ -37,24 +43,24 @@ use datafusion::physical_plan::{ use log::{debug, info}; -type PartialQueryStageResult = (Arc, Vec>); +type PartialQueryStageResult = (Arc, Vec>); /// Trait for breaking an execution plan into distributed query stages. /// /// The planner creates a DAG of stages where each stage can be executed /// independently once its input stages are complete. pub trait DistributedPlanner { - /// Returns a vector of ExecutionPlans, where the root node is a [`ShuffleWriterExec`]. + /// Returns a vector of ExecutionPlans, where the root node is a [`ShuffleWriter`]. /// /// Plans that depend on the input of other plans will have leaf nodes of type - /// [`UnresolvedShuffleExec`]. A [`ShuffleWriterExec`] is created whenever the + /// [`UnresolvedShuffleExec`]. A shuffle writer is created whenever the /// partitioning changes. fn plan_query_stages<'a>( &'a mut self, job_id: &'a str, execution_plan: Arc, config: &ConfigOptions, - ) -> Result>>; + ) -> Result>>; } /// Default implementation of [`DistributedPlanner`]. @@ -87,23 +93,24 @@ impl Default for DefaultDistributedPlanner { } impl DistributedPlanner for DefaultDistributedPlanner { - /// Returns a vector of ExecutionPlans, where the root node is a [ShuffleWriterExec]. + /// Returns a vector of ExecutionPlans, where the root node is a shuffle writer. /// Plans that depend on the input of other plans will have leaf nodes of type [UnresolvedShuffleExec]. - /// A [ShuffleWriterExec] is created whenever the partitioning changes. + /// A shuffle writer is created whenever the partitioning changes. fn plan_query_stages<'a>( &'a mut self, job_id: &'a str, execution_plan: Arc, config: &ConfigOptions, - ) -> Result>> { + ) -> Result>> { info!("planning query stages for job {job_id}"); let (new_plan, mut stages) = self.plan_query_stages_internal(job_id, execution_plan, config)?; - stages.push(create_shuffle_writer( + stages.push(create_shuffle_writer_with_config( job_id, self.next_stage_id(), new_plan, None, + config, )?); Ok(stages) } @@ -139,9 +146,14 @@ impl DefaultDistributedPlanner { { let input = children[0].clone(); let input = self.optimizer_enforce_sorting.optimize(input, config)?; - let shuffle_writer = - create_shuffle_writer(job_id, self.next_stage_id(), input, None)?; - let unresolved_shuffle = create_unresolved_shuffle(&shuffle_writer); + let shuffle_writer = create_shuffle_writer_with_config( + job_id, + self.next_stage_id(), + input, + None, + config, + )?; + let unresolved_shuffle = create_unresolved_shuffle(shuffle_writer.as_ref()); stages.push(shuffle_writer); Ok(( @@ -152,13 +164,14 @@ impl DefaultDistributedPlanner { .as_any() .downcast_ref::( ) { - let shuffle_writer = create_shuffle_writer( + let shuffle_writer = create_shuffle_writer_with_config( job_id, self.next_stage_id(), children[0].clone(), None, + config, )?; - let unresolved_shuffle = create_unresolved_shuffle(&shuffle_writer); + let unresolved_shuffle = create_unresolved_shuffle(shuffle_writer.as_ref()); stages.push(shuffle_writer); Ok(( with_new_children_if_necessary(execution_plan, vec![unresolved_shuffle])?, @@ -172,13 +185,15 @@ impl DefaultDistributedPlanner { let input = children[0].clone(); let input = self.optimizer_enforce_sorting.optimize(input, config)?; - let shuffle_writer = create_shuffle_writer( + let shuffle_writer = create_shuffle_writer_with_config( job_id, self.next_stage_id(), input, Some(repart.partitioning().to_owned()), + config, )?; - let unresolved_shuffle = create_unresolved_shuffle(&shuffle_writer); + let unresolved_shuffle = + create_unresolved_shuffle(shuffle_writer.as_ref()); stages.push(shuffle_writer); Ok((unresolved_shuffle, stages)) @@ -204,7 +219,7 @@ impl DefaultDistributedPlanner { } fn create_unresolved_shuffle( - shuffle_writer: &ShuffleWriterExec, + shuffle_writer: &dyn ShuffleWriter, ) -> Arc { Arc::new(UnresolvedShuffleExec::new( shuffle_writer.stage_id(), @@ -321,17 +336,48 @@ pub fn rollback_resolved_shuffles( Ok(with_new_children_if_necessary(stage, new_children)?) } -fn create_shuffle_writer( +fn create_shuffle_writer_with_config( job_id: &str, stage_id: usize, plan: Arc, partitioning: Option, -) -> Result> { + config: &ConfigOptions, +) -> Result> { + // Check if sort-based shuffle is enabled + let ballista_config = config + .extensions + .get::() + .cloned() + .unwrap_or_default(); + + if ballista_config.shuffle_sort_based_enabled() { + // Sort shuffle requires hash partitioning + if let Some(Partitioning::Hash(exprs, partition_count)) = partitioning { + let sort_config = SortShuffleConfig::new( + true, + ballista_config.shuffle_sort_based_buffer_size(), + ballista_config.shuffle_sort_based_memory_limit(), + ballista_config.shuffle_sort_based_spill_threshold(), + datafusion::arrow::ipc::CompressionType::LZ4_FRAME, + ); + + return Ok(Arc::new(SortShuffleWriterExec::try_new( + job_id.to_owned(), + stage_id, + plan, + "".to_owned(), + Partitioning::Hash(exprs, partition_count), + sort_config, + )?)); + } + } + + // Fall back to standard shuffle writer Ok(Arc::new(ShuffleWriterExec::try_new( job_id.to_owned(), stage_id, plan, - "".to_owned(), // executor will decide on the work_dir path + "".to_owned(), partitioning, )?)) } diff --git a/ballista/scheduler/src/state/execution_graph.rs b/ballista/scheduler/src/state/execution_graph.rs index 959272bc2..8d07ae022 100644 --- a/ballista/scheduler/src/state/execution_graph.rs +++ b/ballista/scheduler/src/state/execution_graph.rs @@ -28,7 +28,9 @@ use datafusion::prelude::SessionConfig; use log::{debug, error, info, warn}; use ballista_core::error::{BallistaError, Result}; -use ballista_core::execution_plans::{ShuffleWriterExec, UnresolvedShuffleExec}; +use ballista_core::execution_plans::{ + ShuffleWriter, ShuffleWriterExec, SortShuffleWriterExec, UnresolvedShuffleExec, +}; use ballista_core::serde::protobuf::failed_task::FailedReason; use ballista_core::serde::protobuf::job_status::Status; use ballista_core::serde::protobuf::{FailedJob, ShuffleWritePartition, job_status}; @@ -1435,7 +1437,7 @@ impl ExecutionStageBuilder { pub fn build( mut self, - stages: Vec>, + stages: Vec>, ) -> Result> { let mut execution_stages: HashMap = HashMap::new(); // First, build the dependency graph @@ -1486,8 +1488,13 @@ impl ExecutionPlanVisitor for ExecutionStageBuilder { &mut self, plan: &dyn ExecutionPlan, ) -> std::result::Result { + // Handle both ShuffleWriterExec and SortShuffleWriterExec if let Some(shuffle_write) = plan.as_any().downcast_ref::() { self.current_stage_id = shuffle_write.stage_id(); + } else if let Some(shuffle_write) = + plan.as_any().downcast_ref::() + { + self.current_stage_id = shuffle_write.stage_id(); } else if let Some(unresolved_shuffle) = plan.as_any().downcast_ref::() { @@ -1558,15 +1565,25 @@ impl Debug for TaskDescription { impl TaskDescription { /// Returns the number of output partitions this task will produce. pub fn get_output_partition_number(&self) -> usize { - let shuffle_writer = self - .plan - .as_any() - .downcast_ref::() - .unwrap(); - shuffle_writer - .shuffle_output_partitioning() - .map(|partitioning| partitioning.partition_count()) - .unwrap_or_else(|| 1) + // Try ShuffleWriterExec first + if let Some(shuffle_writer) = + self.plan.as_any().downcast_ref::() + { + return shuffle_writer + .shuffle_output_partitioning() + .map(|partitioning| partitioning.partition_count()) + .unwrap_or(1); + } + // Try SortShuffleWriterExec + if let Some(shuffle_writer) = + self.plan.as_any().downcast_ref::() + { + return shuffle_writer + .shuffle_output_partitioning() + .partition_count(); + } + // Default fallback + 1 } } diff --git a/ballista/scheduler/src/state/execution_graph_dot.rs b/ballista/scheduler/src/state/execution_graph_dot.rs index 49ef7a0ab..002af18c5 100644 --- a/ballista/scheduler/src/state/execution_graph_dot.rs +++ b/ballista/scheduler/src/state/execution_graph_dot.rs @@ -19,7 +19,7 @@ use crate::state::execution_graph::ExecutionGraph; use ballista_core::execution_plans::{ - ShuffleReaderExec, ShuffleWriterExec, UnresolvedShuffleExec, + ShuffleReaderExec, ShuffleWriterExec, SortShuffleWriterExec, UnresolvedShuffleExec, }; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::memory::MemorySourceConfig; @@ -320,6 +320,11 @@ filter_expr={}", "ShuffleWriter [{} partitions]", exec.input_partition_count() ) + } else if let Some(exec) = plan.as_any().downcast_ref::() { + format!( + "SortShuffleWriter [{} partitions]", + exec.input_partition_count() + ) } else if let Some(exec) = plan.as_any().downcast_ref::() { let config = if let Some(config) = exec.data_source().as_any().downcast_ref::() diff --git a/ballista/scheduler/src/state/execution_stage.rs b/ballista/scheduler/src/state/execution_stage.rs index 915361775..2a4d21a95 100644 --- a/ballista/scheduler/src/state/execution_stage.rs +++ b/ballista/scheduler/src/state/execution_stage.rs @@ -32,7 +32,7 @@ use datafusion::prelude::SessionConfig; use log::{debug, warn}; use ballista_core::error::{BallistaError, Result}; -use ballista_core::execution_plans::ShuffleWriterExec; +use ballista_core::execution_plans::{ShuffleWriterExec, SortShuffleWriterExec}; use ballista_core::serde::protobuf::failed_task::FailedReason; use ballista_core::serde::protobuf::{ FailedTask, OperatorMetricsSet, ResultLost, SuccessfulTask, TaskStatus, @@ -979,13 +979,19 @@ impl Debug for FailedStage { } /// Get the total number of partitions for a stage with plan. -/// Only for [`ShuffleWriterExec`], the input partition count and the output partition count +/// Only for shuffle writers, the input partition count and the output partition count /// will be different. Here, we should use the input partition count. fn get_stage_partitions(plan: Arc) -> usize { - plan.as_any() - .downcast_ref::() - .map(|shuffle_writer| shuffle_writer.input_partition_count()) - .unwrap_or_else(|| plan.properties().output_partitioning().partition_count()) + // Try ShuffleWriterExec first + if let Some(shuffle_writer) = plan.as_any().downcast_ref::() { + return shuffle_writer.input_partition_count(); + } + // Try SortShuffleWriterExec + if let Some(shuffle_writer) = plan.as_any().downcast_ref::() { + return shuffle_writer.input_partition_count(); + } + // Fallback to output partitioning + plan.properties().output_partitioning().partition_count() } /// This data structure collects the partition locations for an `ExecutionStage`. diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml index fc2b4f2c6..be6dcb182 100644 --- a/benchmarks/Cargo.toml +++ b/benchmarks/Cargo.toml @@ -33,6 +33,7 @@ snmalloc = ["snmalloc-rs"] [dependencies] ballista = { path = "../ballista/client", version = "51.0.0" } +ballista-core = { path = "../ballista/core", version = "51.0.0" } datafusion = { workspace = true } datafusion-proto = { workspace = true } env_logger = { workspace = true } @@ -43,6 +44,7 @@ serde = { workspace = true } serde_json = "1.0.78" snmalloc-rs = { version = "0.3", optional = true } structopt = { version = "0.3", default-features = false } +tempfile = { workspace = true } tokio = { version = "^1.44", features = [ "macros", "rt", @@ -51,4 +53,3 @@ tokio = { version = "^1.44", features = [ ] } [dev-dependencies] -ballista-core = { path = "../ballista/core", version = "51.0.0" } diff --git a/benchmarks/src/bin/shuffle_bench.rs b/benchmarks/src/bin/shuffle_bench.rs new file mode 100644 index 000000000..202c123c8 --- /dev/null +++ b/benchmarks/src/bin/shuffle_bench.rs @@ -0,0 +1,429 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Benchmark comparing hash-based and sort-based shuffle implementations. +//! +//! This benchmark generates synthetic data and measures the performance of +//! both shuffle implementations across various configurations: +//! - Different input sizes (number of rows) +//! - Different partition counts +//! - Different batch sizes +//! +//! Usage: +//! cargo run --release --bin shuffle_bench -- --help +//! cargo run --release --bin shuffle_bench -- --rows 1000000 --partitions 16 + +use datafusion::arrow::array::{Int64Array, StringArray}; +use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion::arrow::ipc::CompressionType; +use datafusion::arrow::record_batch::RecordBatch; +use datafusion::datasource::memory::MemorySourceConfig; +use datafusion::datasource::source::DataSourceExec; +use datafusion::physical_expr::expressions::Column; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::physical_plan::Partitioning; +use datafusion::prelude::SessionContext; +use std::fs; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use structopt::StructOpt; +use tempfile::TempDir; + +#[cfg(feature = "mimalloc")] +#[global_allocator] +static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc; + +#[derive(Debug, StructOpt, Clone)] +#[structopt( + name = "shuffle_bench", + about = "Benchmark comparing hash-based and sort-based shuffle implementations" +)] +struct ShuffleBenchOpt { + /// Number of rows to generate + #[structopt(short = "r", long = "rows", default_value = "1000000")] + rows: usize, + + /// Number of output partitions + #[structopt(short = "p", long = "partitions", default_value = "16")] + partitions: usize, + + /// Number of input partitions + #[structopt(short = "i", long = "input-partitions", default_value = "4")] + input_partitions: usize, + + /// Batch size + #[structopt(short = "b", long = "batch-size", default_value = "8192")] + batch_size: usize, + + /// Number of iterations + #[structopt(short = "n", long = "iterations", default_value = "3")] + iterations: usize, + + /// Memory limit for sort shuffle (in MB) + #[structopt(short = "m", long = "memory-limit", default_value = "256")] + memory_limit_mb: usize, + + /// Buffer size for sort shuffle (in MB) + #[structopt(long = "buffer-size", default_value = "1")] + buffer_size_mb: usize, + + /// Only run hash shuffle + #[structopt(long = "hash-only")] + hash_only: bool, + + /// Only run sort shuffle + #[structopt(long = "sort-only")] + sort_only: bool, +} + +fn create_test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("partition_key", DataType::Int64, false), + Field::new("value", DataType::Utf8, true), + ])) +} + +fn generate_test_batch( + schema: &SchemaRef, + batch_size: usize, + partition_count: usize, + offset: usize, +) -> RecordBatch { + let ids: Vec = (offset..offset + batch_size).map(|i| i as i64).collect(); + let partition_keys: Vec = + ids.iter().map(|id| *id % partition_count as i64).collect(); + let values: Vec = ids.iter().map(|id| format!("value_{}", id)).collect(); + + RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int64Array::from(ids)), + Arc::new(Int64Array::from(partition_keys)), + Arc::new(StringArray::from(values)), + ], + ) + .unwrap() +} + +fn create_test_data( + schema: &SchemaRef, + rows: usize, + batch_size: usize, + input_partitions: usize, + partition_count: usize, +) -> Vec> { + let rows_per_partition = rows / input_partitions; + let batches_per_partition = rows_per_partition.div_ceil(batch_size); + + let mut partitions = Vec::with_capacity(input_partitions); + for p in 0..input_partitions { + let mut batches = Vec::with_capacity(batches_per_partition); + for b in 0..batches_per_partition { + let offset = p * rows_per_partition + b * batch_size; + let current_batch_size = + std::cmp::min(batch_size, rows_per_partition - b * batch_size); + if current_batch_size > 0 { + batches.push(generate_test_batch( + schema, + current_batch_size, + partition_count, + offset, + )); + } + } + partitions.push(batches); + } + partitions +} + +async fn benchmark_hash_shuffle( + data: &[Vec], + schema: SchemaRef, + output_partitions: usize, + work_dir: &str, +) -> Result<(Duration, usize), Box> { + use ballista_core::execution_plans::ShuffleWriterExec; + use ballista_core::utils; + + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); + + // Create input plan from data + let memory_source = + Arc::new(MemorySourceConfig::try_new(data, schema.clone(), None)?); + let input = Arc::new(DataSourceExec::new(memory_source)); + + // Create shuffle writer + let shuffle_writer = ShuffleWriterExec::try_new( + "bench_job".to_owned(), + 1, + input, + work_dir.to_owned(), + Some(Partitioning::Hash( + vec![Arc::new(Column::new("partition_key", 1))], + output_partitions, + )), + )?; + + let start = Instant::now(); + + // Execute all input partitions (not output partitions) + let input_partition_count = data.len(); + let mut total_files = 0; + for partition in 0..input_partition_count { + let mut stream = shuffle_writer.execute(partition, task_ctx.clone())?; + let batches = utils::collect_stream(&mut stream).await?; + // Count output files from the result + if let Some(batch) = batches.first() { + total_files += batch.num_rows(); + } + } + + let elapsed = start.elapsed(); + Ok((elapsed, total_files)) +} + +async fn benchmark_sort_shuffle( + data: &[Vec], + schema: SchemaRef, + output_partitions: usize, + work_dir: &str, + buffer_size: usize, + memory_limit: usize, +) -> Result<(Duration, usize), Box> { + use ballista_core::execution_plans::sort_shuffle::{ + SortShuffleConfig, SortShuffleWriterExec, + }; + use ballista_core::utils; + + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); + + // Create input plan from data + let memory_source = + Arc::new(MemorySourceConfig::try_new(data, schema.clone(), None)?); + let input = Arc::new(DataSourceExec::new(memory_source)); + + // Create sort shuffle config + let config = SortShuffleConfig::new( + true, + buffer_size, + memory_limit, + 0.8, + CompressionType::LZ4_FRAME, + ); + + // Create sort shuffle writer + let shuffle_writer = SortShuffleWriterExec::try_new( + "bench_job".to_owned(), + 1, + input, + work_dir.to_owned(), + Partitioning::Hash( + vec![Arc::new(Column::new("partition_key", 1))], + output_partitions, + ), + config, + )?; + + let start = Instant::now(); + + // Execute all input partitions (not output partitions) + let input_partition_count = data.len(); + let mut total_files = 0; + for partition in 0..input_partition_count { + let mut stream = shuffle_writer.execute(partition, task_ctx.clone())?; + let batches = utils::collect_stream(&mut stream).await?; + // Count output files from the result + if let Some(batch) = batches.first() { + total_files += batch.num_rows(); + } + } + + let elapsed = start.elapsed(); + Ok((elapsed, total_files)) +} + +fn count_files_in_dir(dir: &str) -> usize { + let mut count = 0; + if let Ok(entries) = fs::read_dir(dir) { + for entry in entries.flatten() { + if entry.path().is_file() { + count += 1; + } else if entry.path().is_dir() { + count += count_files_in_dir(entry.path().to_str().unwrap()); + } + } + } + count +} + +fn dir_size(dir: &str) -> u64 { + let mut size = 0; + if let Ok(entries) = fs::read_dir(dir) { + for entry in entries.flatten() { + if entry.path().is_file() { + if let Ok(meta) = entry.metadata() { + size += meta.len(); + } + } else if entry.path().is_dir() { + size += dir_size(entry.path().to_str().unwrap()); + } + } + } + size +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + env_logger::init(); + let opt = ShuffleBenchOpt::from_args(); + + println!("Shuffle Benchmark Configuration:"); + println!(" Rows: {}", opt.rows); + println!(" Input partitions: {}", opt.input_partitions); + println!(" Output partitions: {}", opt.partitions); + println!(" Batch size: {}", opt.batch_size); + println!(" Iterations: {}", opt.iterations); + println!(" Sort shuffle memory limit: {} MB", opt.memory_limit_mb); + println!(" Sort shuffle buffer size: {} MB", opt.buffer_size_mb); + println!(); + + let schema = create_test_schema(); + + // Generate test data once + println!("Generating test data..."); + let data = create_test_data( + &schema, + opt.rows, + opt.batch_size, + opt.input_partitions, + opt.partitions, + ); + let total_batches: usize = data.iter().map(|p| p.len()).sum(); + println!( + "Generated {} batches across {} input partitions", + total_batches, opt.input_partitions + ); + println!(); + + let buffer_size = opt.buffer_size_mb * 1024 * 1024; + let memory_limit = opt.memory_limit_mb * 1024 * 1024; + + // Benchmark hash shuffle + if !opt.sort_only { + println!("=== Hash-Based Shuffle ==="); + let mut hash_times: Vec = Vec::new(); + let mut hash_file_count = 0; + let mut hash_total_size = 0u64; + + for i in 0..opt.iterations { + let temp_dir = TempDir::new()?; + let work_dir = temp_dir.path().to_str().unwrap(); + + let (elapsed, _files) = + benchmark_hash_shuffle(&data, schema.clone(), opt.partitions, work_dir) + .await?; + + hash_file_count = count_files_in_dir(work_dir); + hash_total_size = dir_size(work_dir); + hash_times.push(elapsed); + + println!( + " Iteration {}: {:?} ({} files, {} KB)", + i + 1, + elapsed, + hash_file_count, + hash_total_size / 1024 + ); + } + + let avg_time: Duration = + hash_times.iter().sum::() / hash_times.len() as u32; + let min_time = hash_times.iter().min().unwrap(); + let max_time = hash_times.iter().max().unwrap(); + + println!(); + println!("Hash Shuffle Results:"); + println!(" Average time: {:?}", avg_time); + println!(" Min time: {:?}", min_time); + println!(" Max time: {:?}", max_time); + println!(" Files created: {}", hash_file_count); + println!(" Total size: {} KB", hash_total_size / 1024); + println!( + " Throughput: {:.2} MB/s", + (opt.rows * 30) as f64 / avg_time.as_secs_f64() / 1024.0 / 1024.0 + ); + println!(); + } + + // Benchmark sort shuffle + if !opt.hash_only { + println!("=== Sort-Based Shuffle ==="); + let mut sort_times: Vec = Vec::new(); + let mut sort_file_count = 0; + let mut sort_total_size = 0u64; + + for i in 0..opt.iterations { + let temp_dir = TempDir::new()?; + let work_dir = temp_dir.path().to_str().unwrap(); + + let (elapsed, _files) = benchmark_sort_shuffle( + &data, + schema.clone(), + opt.partitions, + work_dir, + buffer_size, + memory_limit, + ) + .await?; + + sort_file_count = count_files_in_dir(work_dir); + sort_total_size = dir_size(work_dir); + sort_times.push(elapsed); + + println!( + " Iteration {}: {:?} ({} files, {} KB)", + i + 1, + elapsed, + sort_file_count, + sort_total_size / 1024 + ); + } + + let avg_time: Duration = + sort_times.iter().sum::() / sort_times.len() as u32; + let min_time = sort_times.iter().min().unwrap(); + let max_time = sort_times.iter().max().unwrap(); + + println!(); + println!("Sort Shuffle Results:"); + println!(" Average time: {:?}", avg_time); + println!(" Min time: {:?}", min_time); + println!(" Max time: {:?}", max_time); + println!(" Files created: {}", sort_file_count); + println!(" Total size: {} KB", sort_total_size / 1024); + println!( + " Throughput: {:.2} MB/s", + (opt.rows * 30) as f64 / avg_time.as_secs_f64() / 1024.0 / 1024.0 + ); + println!(); + } + + Ok(()) +} diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs index 4bfc133ef..edc922f1f 100644 --- a/benchmarks/src/bin/tpch.rs +++ b/benchmarks/src/bin/tpch.rs @@ -19,6 +19,7 @@ use ballista::extension::SessionConfigExt; use ballista::prelude::SessionContextExt; +use ballista_core::config::BALLISTA_SHUFFLE_SORT_BASED_ENABLED; use datafusion::arrow::array::*; use datafusion::arrow::datatypes::SchemaBuilder; use datafusion::arrow::util::display::array_value_to_string; @@ -122,6 +123,10 @@ struct BallistaBenchmarkOpt { /// Path to output directory where JSON summary file should be written to #[structopt(parse(from_os_str), short = "o", long = "output")] output_path: Option, + + /// Enable sort-based shuffle instead of hash-based shuffle + #[structopt(long = "sort-shuffle")] + sort_shuffle: bool, } #[derive(Debug, StructOpt, Clone)] @@ -406,12 +411,16 @@ async fn benchmark_ballista(opt: BallistaBenchmarkOpt) -> Result<()> { for query in query_numbers { let mut query_run = QueryRun::new(query); - let config = SessionConfig::new_with_ballista() + let mut config = SessionConfig::new_with_ballista() .with_target_partitions(opt.partitions) .with_ballista_job_name(&format!("Query derived from TPC-H q{}", query)) .with_batch_size(opt.batch_size) .with_collect_statistics(true); + if opt.sort_shuffle { + config = config.set_str(BALLISTA_SHUFFLE_SORT_BASED_ENABLED, "true"); + } + let state = SessionStateBuilder::new() .with_default_features() .with_config(config)