From 1abdb192e94d69c996eb8153f0679d65b13d590f Mon Sep 17 00:00:00 2001
From: shekharrajak
Date: Sun, 28 Dec 2025 15:12:58 +0530
Subject: [PATCH 1/4] Fix BloomFilter buffer incompatibility between Spark and
 Comet

Handle Spark's full serialization format (12-byte header + bits) in
merge_filter() to support Spark partial / Comet final execution.
The fix automatically detects the format and extracts bits data
accordingly.

Fixes #2889
---
 .../src/bloom_filter/spark_bloom_filter.rs    | 29 +++++++++++++++++--
 1 file changed, 26 insertions(+), 3 deletions(-)

diff --git a/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs b/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs
index e84257ea67..656171c947 100644
--- a/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs
+++ b/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs
@@ -159,12 +159,35 @@ impl SparkBloomFilter {
         self.bits.to_bytes()
     }
 
+    /// Extracts bits data from Spark's full serialization format.
+    /// Spark's format includes a 12-byte header (version + num_hash_functions + num_words)
+    /// followed by the bits data. This function extracts just the bits data.
+    fn extract_bits_from_spark_format(&self, buf: &[u8]) -> &[u8] {
+        const SPARK_HEADER_SIZE: usize = 12; // version (4) + num_hash_functions (4) + num_words (4)
+
+        // Check if this is Spark's full serialization format
+        let expected_bits_size = self.bits.byte_size();
+        if buf.len() == SPARK_HEADER_SIZE + expected_bits_size {
+            // This is Spark's full format, extract bits data (skip header)
+            &buf[SPARK_HEADER_SIZE..]
+        } else {
+            // This is already just bits data (Comet format)
+            buf
+        }
+    }
+
     pub fn merge_filter(&mut self, other: &[u8]) {
+        // Extract bits data if other is in Spark's full serialization format
+        let bits_data = self.extract_bits_from_spark_format(other);
+
         assert_eq!(
-            other.len(),
+            bits_data.len(),
+            self.bits.byte_size(),
+            "Cannot merge SparkBloomFilters with different lengths. Expected {} bytes, got {} bytes (full buffer size: {} bytes)",
             self.bits.byte_size(),
-            "Cannot merge SparkBloomFilters with different lengths."
+            bits_data.len(),
+            other.len()
         );
-        self.bits.merge_bits(other);
+        self.bits.merge_bits(bits_data);
     }
 }
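
For context, the "12-byte header + bits" layout referenced above is Spark's
BloomFilterImpl V1 wire format. The following is a minimal sketch of a buffer
in that shape, assuming big-endian integer fields (Spark writes them through
java.io.DataOutputStream); spark_style_buffer is an illustrative name, not
part of Spark or Comet:

    // Sketch of the assumed layout:
    // version (4 bytes) | num_hash_functions (4) | num_words (4) | bit words (num_words * 8 bytes)
    fn spark_style_buffer(num_hash_functions: u32, bit_words: &[u64]) -> Vec<u8> {
        let mut buf = Vec::with_capacity(12 + bit_words.len() * 8);
        buf.extend_from_slice(&1u32.to_be_bytes()); // version (V1)
        buf.extend_from_slice(&num_hash_functions.to_be_bytes());
        buf.extend_from_slice(&(bit_words.len() as u32).to_be_bytes()); // num_words
        for w in bit_words {
            buf.extend_from_slice(&w.to_be_bytes());
        }
        buf
    }

    fn main() {
        // Two 8-byte words of bits -> 12-byte header + 16 bytes of bits,
        // exactly the `SPARK_HEADER_SIZE + expected_bits_size` case that
        // extract_bits_from_spark_format() detects by length.
        let buf = spark_style_buffer(3, &[0u64; 2]);
        assert_eq!(buf.len(), 12 + 16);
    }
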
From 5994c3fd3400c23e3cf7f8d08ae3ee0dfe76d09b Mon Sep 17 00:00:00 2001
From: shekharrajak
Date: Sat, 3 Jan 2026 19:54:50 +0530
Subject: [PATCH 2/4] Remove trailing whitespace

---
 native/spark-expr/src/bloom_filter/spark_bloom_filter.rs | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs b/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs
index 656171c947..1e315e9f88 100644
--- a/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs
+++ b/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs
@@ -164,7 +164,7 @@ impl SparkBloomFilter {
     /// followed by the bits data. This function extracts just the bits data.
     fn extract_bits_from_spark_format(&self, buf: &[u8]) -> &[u8] {
         const SPARK_HEADER_SIZE: usize = 12; // version (4) + num_hash_functions (4) + num_words (4)
-        
+
         // Check if this is Spark's full serialization format
         let expected_bits_size = self.bits.byte_size();
         if buf.len() == SPARK_HEADER_SIZE + expected_bits_size {
@@ -179,7 +179,7 @@ impl SparkBloomFilter {
     pub fn merge_filter(&mut self, other: &[u8]) {
         // Extract bits data if other is in Spark's full serialization format
        let bits_data = self.extract_bits_from_spark_format(other);
-        
+
         assert_eq!(
             bits_data.len(),
             self.bits.byte_size(),

From 030c67b6174f5e5ca05f376a15216041f2cc3538 Mon Sep 17 00:00:00 2001
From: shekharrajak
Date: Tue, 6 Jan 2026 10:05:29 +0530
Subject: [PATCH 3/4] Fix Rust lifetime and borrow checker errors in
 merge_filter

---
 .../src/bloom_filter/spark_bloom_filter.rs    | 27 ++++++++++-----------------
 1 file changed, 10 insertions(+), 17 deletions(-)

diff --git a/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs b/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs
index 1e315e9f88..f5ed086d27 100644
--- a/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs
+++ b/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs
@@ -159,32 +159,25 @@ impl SparkBloomFilter {
         self.bits.to_bytes()
     }
 
-    /// Extracts bits data from Spark's full serialization format.
-    /// Spark's format includes a 12-byte header (version + num_hash_functions + num_words)
-    /// followed by the bits data. This function extracts just the bits data.
-    fn extract_bits_from_spark_format(&self, buf: &[u8]) -> &[u8] {
+    pub fn merge_filter(&mut self, other: &[u8]) {
+        // Extract bits data if other is in Spark's full serialization format.
+        // Compute the expected size and slice `other` up front, before self.bits is borrowed mutably.
+        let expected_bits_size = self.bits.byte_size();
         const SPARK_HEADER_SIZE: usize = 12; // version (4) + num_hash_functions (4) + num_words (4)
 
-        // Check if this is Spark's full serialization format
-        let expected_bits_size = self.bits.byte_size();
-        if buf.len() == SPARK_HEADER_SIZE + expected_bits_size {
+        let bits_data = if other.len() == SPARK_HEADER_SIZE + expected_bits_size {
             // This is Spark's full format, extract bits data (skip header)
-            &buf[SPARK_HEADER_SIZE..]
+            &other[SPARK_HEADER_SIZE..]
         } else {
             // This is already just bits data (Comet format)
-            buf
-        }
-    }
-
-    pub fn merge_filter(&mut self, other: &[u8]) {
-        // Extract bits data if other is in Spark's full serialization format
-        let bits_data = self.extract_bits_from_spark_format(other);
+            other
+        };
 
         assert_eq!(
             bits_data.len(),
-            self.bits.byte_size(),
+            expected_bits_size,
             "Cannot merge SparkBloomFilters with different lengths. Expected {} bytes, got {} bytes (full buffer size: {} bytes)",
-            self.bits.byte_size(),
+            expected_bits_size,
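
The borrow pattern this patch settles on can be shown in isolation: slicing
the input buffer borrows `other`, never `self`, so the later mutable use of
`self.bits` compiles, whereas the removed helper's elided signature tied its
returned slice to `&self` and tripped the borrow checker. Below is a minimal,
self-contained sketch; `Filter`, `merge`, and the bitwise-OR body are
simplified stand-ins for the Comet types, not the real API:

    struct Filter {
        bits: Vec<u8>,
    }

    impl Filter {
        fn merge(&mut self, other: &[u8]) {
            const SPARK_HEADER_SIZE: usize = 12;
            // The immutable borrow of self ends with this statement.
            let expected = self.bits.len();
            // `data` borrows `other`, not `self`, leaving `self` free for mutation.
            let data = if other.len() == SPARK_HEADER_SIZE + expected {
                &other[SPARK_HEADER_SIZE..]
            } else {
                other
            };
            assert_eq!(data.len(), expected);
            // Merging Bloom filters is a bitwise OR of their bit arrays.
            for (dst, src) in self.bits.iter_mut().zip(data) {
                *dst |= src;
            }
        }
    }

    fn main() {
        let mut f = Filter { bits: vec![0b0001; 8] };
        f.merge(&[0b0010u8; 8]); // raw bits, Comet-style buffer
        assert!(f.bits.iter().all(|&b| b == 0b0011));
    }
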
             bits_data.len(),
             other.len()
         );

From 49169a6aee875146961de287227f3a3e05de9ec6 Mon Sep 17 00:00:00 2001
From: shekharrajak
Date: Thu, 8 Jan 2026 20:41:58 +0530
Subject: [PATCH 4/4] Remove fallback and add test for Spark partial / Comet
 final BloomFilterAggregate merge

---
 .../apache/spark/sql/comet/operators.scala    |  8 ++-
 .../apache/comet/exec/CometExecSuite.scala    | 65 ++++++++++++++++++-
 2 files changed, 71 insertions(+), 2 deletions(-)

diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 0a435e5b7a..7860b8380f 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -1069,7 +1069,13 @@ trait CometBaseAggregate {
     val multiMode = modes.size > 1
     // For a final mode HashAggregate, we only need to transform the HashAggregate
     // if there is Comet partial aggregation.
-    val sparkFinalMode = modes.contains(Final) && findCometPartialAgg(aggregate.child).isEmpty
+    // Exception: BloomFilterAggregate supports Spark partial / Comet final because
+    // merge_filter() handles Spark's serialization format (12-byte header + bits).
+    val hasBloomFilterAgg = aggregate.aggregateExpressions.exists(expr =>
+      expr.aggregateFunction.getClass.getSimpleName == "BloomFilterAggregate")
+    val sparkFinalMode = modes.contains(Final) &&
+      findCometPartialAgg(aggregate.child).isEmpty &&
+      !hasBloomFilterAgg
 
     if (multiMode || sparkFinalMode) {
       return None
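
The planner can keep the final aggregate in Comet only because the native
merge accepts both buffer shapes for one and the same filter. Below is a
short round-trip sketch of that property, under the same layout assumption as
the earlier sketches and with the length-based detection of merge_filter();
all names are illustrative:

    const SPARK_HEADER_SIZE: usize = 12;

    // Length-based detection, mirroring extract_bits_from_spark_format().
    fn bits_of(buf: &[u8], expected_bits_size: usize) -> &[u8] {
        if buf.len() == SPARK_HEADER_SIZE + expected_bits_size {
            &buf[SPARK_HEADER_SIZE..] // Spark full format: skip the header
        } else {
            buf // already raw bits
        }
    }

    fn main() {
        let expected = 16;
        // Comet-style partial output: raw bits only.
        let comet_buf = vec![0x0Fu8; expected];
        // Spark-style partial output: 12 header bytes (values irrelevant here) + bits.
        let mut spark_buf = vec![0u8; SPARK_HEADER_SIZE];
        spark_buf.extend_from_slice(&[0xF0u8; 16]);

        // Merge both into one accumulator by bitwise OR (the standard Bloom merge).
        let mut acc = vec![0u8; expected];
        for src in [bits_of(&comet_buf, expected), bits_of(&spark_buf, expected)] {
            for (a, s) in acc.iter_mut().zip(src) {
                *a |= s;
            }
        }
        assert!(acc.iter().all(|&b| b == 0xFF));
    }
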
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 1b2373ad71..b40ce6c2bb 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -36,7 +36,8 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, Bloom
 import org.apache.spark.sql.comet._
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
 import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec}
-import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec}
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, QueryStageExec}
+import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec}
@@ -1149,6 +1150,68 @@ class CometExecSuite extends CometTestBase {
     spark.sessionState.functionRegistry.dropFunction(funcId_bloom_filter_agg)
   }
 
+  test("bloom_filter_agg - Spark partial / Comet final merge") {
+    // This test exercises the merge_filter() fix that handles Spark's full serialization
+    // format (12-byte header + bits) when merging from Spark partial to Comet final aggregates.
+    val funcId_bloom_filter_agg = new FunctionIdentifier("bloom_filter_agg")
+    spark.sessionState.functionRegistry.registerFunction(
+      funcId_bloom_filter_agg,
+      new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"),
+      (children: Seq[Expression]) =>
+        children.size match {
+          case 1 => new BloomFilterAggregate(children.head)
+          case 2 => new BloomFilterAggregate(children.head, children(1))
+          case 3 => new BloomFilterAggregate(children.head, children(1), children(2))
+        })
+
+    // Helper to collect operators, recursing into AQE query stages: QueryStageExec
+    // is a leaf node, so a plain collect would miss the partial aggregate that runs
+    // below the shuffle inside a materialized stage.
+    def collectOperators(plan: SparkPlan)(
+        pf: PartialFunction[SparkPlan, SparkPlan]): Seq[SparkPlan] = {
+      stripAQEPlan(plan).collect {
+        case stage: QueryStageExec => collectOperators(stage.plan)(pf)
+        case op if pf.isDefinedAt(op) => Seq(pf(op))
+      }.flatten
+    }
+
+    withParquetTable(
+      (0 until 1000)
+        .map(_ => (Random.nextInt(1000), Random.nextInt(100))),
+      "tbl") {
+
+      withSQLConf(
+        // Disable Comet partial aggregates to force the Spark partial / Comet final scenario
+        CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.key -> "false",
+        CometConf.COMET_EXEC_AGGREGATE_ENABLED.key -> "true",
+        CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
+
+        val df = sql(
+          "SELECT bloom_filter_agg(cast(_2 as long), cast(1000 as long)) FROM tbl GROUP BY _1")
+
+        // Verify the query executes and matches Spark (exercises merge_filter compatibility)
+        checkSparkAnswer(df)
+
+        // Verify the plan has Spark partial aggregates feeding Comet final aggregates
+        val plan = df.queryExecution.executedPlan
+        val sparkPartialAggs = collectOperators(plan) {
+          case agg: HashAggregateExec if agg.aggregateExpressions.exists(_.mode == Partial) =>
+            agg
+        }
+        val cometFinalAggs = collectOperators(plan) {
+          case agg: CometHashAggregateExec if agg.aggregateExpressions.exists(_.mode == Final) =>
+            agg
+        }
+
+        assert(
+          sparkPartialAggs.nonEmpty,
+          s"Expected Spark partial aggregates but found none. Plan: $plan")
+        assert(
+          cometFinalAggs.nonEmpty,
+          s"Expected Comet final aggregates but found none. Plan: $plan")
+      }
+    }
+
+    spark.sessionState.functionRegistry.dropFunction(funcId_bloom_filter_agg)
+  }
+
   test("sort (non-global)") {
     withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl") {
       val df = sql("SELECT * FROM tbl").sortWithinPartitions($"_1".desc)