diff --git a/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs b/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs
index e84257ea67..f5ed086d27 100644
--- a/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs
+++ b/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs
@@ -160,11 +160,27 @@ impl SparkBloomFilter {
     }
 
     pub fn merge_filter(&mut self, other: &[u8]) {
+        // Extract bits data if other is in Spark's full serialization format
+        // We need to compute the expected size and extract data before borrowing self.bits mutably
+        let expected_bits_size = self.bits.byte_size();
+        const SPARK_HEADER_SIZE: usize = 12; // version (4) + num_hash_functions (4) + num_words (4)
+
+        let bits_data = if other.len() == SPARK_HEADER_SIZE + expected_bits_size {
+            // This is Spark's full format, extract bits data (skip header)
+            &other[SPARK_HEADER_SIZE..]
+        } else {
+            // This is already just bits data (Comet format)
+            other
+        };
+
         assert_eq!(
-            other.len(),
-            self.bits.byte_size(),
-            "Cannot merge SparkBloomFilters with different lengths."
+            bits_data.len(),
+            expected_bits_size,
+            "Cannot merge SparkBloomFilters with different lengths. Expected {} bytes, got {} bytes (full buffer size: {} bytes)",
+            expected_bits_size,
+            bits_data.len(),
+            other.len()
         );
-        self.bits.merge_bits(other);
+        self.bits.merge_bits(bits_data);
     }
 }
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 0a435e5b7a..7860b8380f 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -1069,7 +1069,13 @@ trait CometBaseAggregate {
     val multiMode = modes.size > 1
     // For a final mode HashAggregate, we only need to transform the HashAggregate
    // if there is Comet partial aggregation.
-    val sparkFinalMode = modes.contains(Final) && findCometPartialAgg(aggregate.child).isEmpty
+    // Exception: BloomFilterAggregate supports Spark partial / Comet final because
+    // merge_filter() handles Spark's serialization format (12-byte header + bits).
+    val hasBloomFilterAgg = aggregate.aggregateExpressions.exists(expr =>
+      expr.aggregateFunction.getClass.getSimpleName == "BloomFilterAggregate")
+    val sparkFinalMode = modes.contains(Final) &&
+      findCometPartialAgg(aggregate.child).isEmpty &&
+      !hasBloomFilterAgg
 
     if (multiMode || sparkFinalMode) {
       return None
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 1b2373ad71..b40ce6c2bb 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -36,7 +36,8 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, Bloom
 import org.apache.spark.sql.comet._
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
 import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec}
-import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec}
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, QueryStageExec}
+import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec}
@@ -1149,6 +1150,68 @@ class CometExecSuite extends CometTestBase {
     spark.sessionState.functionRegistry.dropFunction(funcId_bloom_filter_agg)
   }
 
+  test("bloom_filter_agg - Spark partial / Comet final merge") {
+    // This test exercises the merge_filter() fix that handles Spark's full serialization
+    // format (12-byte header + bits) when merging from Spark partial to Comet final aggregates.
+    val funcId_bloom_filter_agg = new FunctionIdentifier("bloom_filter_agg")
+    spark.sessionState.functionRegistry.registerFunction(
+      funcId_bloom_filter_agg,
+      new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"),
+      (children: Seq[Expression]) =>
+        children.size match {
+          case 1 => new BloomFilterAggregate(children.head)
+          case 2 => new BloomFilterAggregate(children.head, children(1))
+          case 3 => new BloomFilterAggregate(children.head, children(1), children(2))
+        })
+
+    // Helper to count operators in plan
+    def countOperators(plan: SparkPlan, opClass: Class[_]): Int = {
+      stripAQEPlan(plan).collect {
+        case stage: QueryStageExec =>
+          countOperators(stage.plan, opClass)
+        case op if opClass.isAssignableFrom(op.getClass) => 1
+      }.sum
+    }
+
+    withParquetTable(
+      (0 until 1000)
+        .map(_ => (Random.nextInt(1000), Random.nextInt(100))),
+      "tbl") {
+
+      withSQLConf(
+        // Disable Comet partial aggregates to force Spark partial / Comet final scenario
+        CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.key -> "false",
+        CometConf.COMET_EXEC_AGGREGATE_ENABLED.key -> "true",
+        CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
+
+        val df = sql(
+          "SELECT bloom_filter_agg(cast(_2 as long), cast(1000 as long)) FROM tbl GROUP BY _1")
+
+        // Verify the query executes successfully (tests merge_filter compatibility)
+        checkSparkAnswer(df)
+
+        // Verify we have Spark partial aggregates and Comet final aggregates
+        val plan = stripAQEPlan(df.queryExecution.executedPlan)
+        val sparkPartialAggs = plan.collect {
+          case agg: HashAggregateExec if agg.aggregateExpressions.exists(_.mode == Partial) => agg
+        }
+        val cometFinalAggs = plan.collect {
+          case agg: CometHashAggregateExec if agg.aggregateExpressions.exists(_.mode == Final) =>
+            agg
+        }
+
+        assert(
+          sparkPartialAggs.nonEmpty,
+          s"Expected Spark partial aggregates but found none. Plan: $plan")
+        assert(
+          cometFinalAggs.nonEmpty,
+          s"Expected Comet final aggregates but found none. Plan: $plan")
+      }
+    }
+
+    spark.sessionState.functionRegistry.dropFunction(funcId_bloom_filter_agg)
+  }
+
   test("sort (non-global)") {
     withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl") {
       val df = sql("SELECT * FROM tbl").sortWithinPartitions($"_1".desc)
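
For reference, a minimal sketch (not part of the patch) of the buffer layout that merge_filter now accepts: the 12-byte header described in the Rust comment above (version, num_hash_functions, num_words), each assumed here to be a 4-byte big-endian int as produced by Spark's Java-side serialization, followed by the raw bit words. The helper name split_spark_bloom_filter is illustrative only; the patch itself only skips the header rather than decoding it.

    const SPARK_HEADER_SIZE: usize = 12;

    /// Split a Spark-serialized bloom filter buffer into its decoded header and the
    /// trailing bit data, mirroring the length check that merge_filter performs.
    /// Returns None for the header when the buffer is already bare bits (Comet format).
    fn split_spark_bloom_filter(
        buf: &[u8],
        expected_bits_size: usize,
    ) -> (Option<(i32, i32, i32)>, &[u8]) {
        if buf.len() == SPARK_HEADER_SIZE + expected_bits_size {
            // Full Spark format: decode version, num_hash_functions, num_words, then skip to the bits.
            let version = i32::from_be_bytes(buf[0..4].try_into().unwrap());
            let num_hash_functions = i32::from_be_bytes(buf[4..8].try_into().unwrap());
            let num_words = i32::from_be_bytes(buf[8..12].try_into().unwrap());
            (Some((version, num_hash_functions, num_words)), &buf[SPARK_HEADER_SIZE..])
        } else {
            // Comet format: the buffer is already just the bit data.
            (None, buf)
        }
    }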