Skip to content

Commit 262cc96

Browse files
committed
fix
1 parent 5a7f3e5 commit 262cc96

File tree

3 files changed

+98
-24
lines changed

3 files changed

+98
-24
lines changed

auron-spark-tests/spark33/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.apache.auron.utils
1818

1919
import org.apache.spark.sql._
20+
import org.apache.spark.sql.execution.joins.AuronExistenceJoinSuite
2021

2122
class AuronSparkTestSettings extends SparkTestSettings {
2223
{
@@ -42,6 +43,8 @@ class AuronSparkTestSettings extends SparkTestSettings {
4243

4344
enableSuite[AuronTypedImperativeAggregateSuite]
4445

46+
enableSuite[AuronExistenceJoinSuite]
47+
4548
// Will be implemented in the future.
4649
override def getSQLQueryTestSettings = new SQLQueryTestSettings {
4750
override def getResourceFilePath: String = ???
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.spark.sql.execution.joins
18+
19+
import org.apache.spark.sql.SparkTestsSharedSessionBase
20+
21+
class AuronExistenceJoinSuite extends ExistenceJoinSuite with SparkTestsSharedSessionBase {}

spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastJoinBase.scala

Lines changed: 74 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,7 @@ import org.apache.spark.sql.auron.NativeRDD
2727
import org.apache.spark.sql.auron.NativeSupports
2828
import org.apache.spark.sql.auron.Shims
2929
import org.apache.spark.sql.catalyst.expressions.Expression
30-
import org.apache.spark.sql.catalyst.plans.FullOuter
31-
import org.apache.spark.sql.catalyst.plans.JoinType
32-
import org.apache.spark.sql.catalyst.plans.LeftAnti
33-
import org.apache.spark.sql.catalyst.plans.LeftOuter
34-
import org.apache.spark.sql.catalyst.plans.LeftSemi
35-
import org.apache.spark.sql.catalyst.plans.RightOuter
30+
import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter}
3631
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
3732
import org.apache.spark.sql.execution.BinaryExecNode
3833
import org.apache.spark.sql.execution.SparkPlan
@@ -43,7 +38,7 @@ import org.apache.spark.sql.types.LongType
4338

4439
import org.apache.auron.{protobuf => pb}
4540
import org.apache.auron.metric.SparkMetricNode
46-
import org.apache.auron.protobuf.JoinOn
41+
import org.apache.auron.protobuf.{EmptyPartitionsExecNode, JoinOn, PhysicalPlanNode}
4742

4843
abstract class NativeBroadcastJoinBase(
4944
override val left: SparkPlan,
@@ -127,44 +122,99 @@ abstract class NativeBroadcastJoinBase(
127122
override def doExecuteNative(): NativeRDD = {
128123
val leftRDD = NativeHelper.executeNative(left)
129124
val rightRDD = NativeHelper.executeNative(right)
130-
val nativeMetrics = SparkMetricNode(metrics, leftRDD.metrics :: rightRDD.metrics :: Nil)
131-
val nativeSchema = this.nativeSchema
132-
val nativeJoinType = this.nativeJoinType
133-
val nativeJoinOn = this.nativeJoinOn
134125

135126
val (probedRDD, builtRDD) = broadcastSide match {
136127
case BroadcastLeft => (rightRDD, leftRDD)
137128
case BroadcastRight => (leftRDD, rightRDD)
138129
}
139130

131+
// Handle the edge case when probed side is empty (no partitions)
132+
// This matches Spark's BroadcastNestedLoopJoinExec behavior for condition.isEmpty case:
133+
// val streamExists = !streamed.executeTake(1).isEmpty
134+
// if (streamExists == exists) sparkContext.makeRDD(relation.value)
135+
// else sparkContext.emptyRDD
136+
// where exists = true for Semi, false for Anti
137+
//
138+
// Note: This optimization only applies to Semi/Anti joins.
139+
if (probedRDD.partitions.isEmpty) {
140+
joinType match {
141+
case LeftAnti =>
142+
return builtRDD
143+
case LeftSemi =>
144+
return probedRDD
145+
case _ =>
146+
}
147+
}
148+
149+
val nativeMetrics = SparkMetricNode(metrics, leftRDD.metrics :: rightRDD.metrics :: Nil)
150+
val nativeSchema = this.nativeSchema
151+
val nativeJoinType = this.nativeJoinType
152+
val nativeJoinOn = this.nativeJoinOn
153+
140154
val probedShuffleReadFull = probedRDD.isShuffleReadFull && (broadcastSide match {
141155
case BroadcastLeft =>
142156
Seq(FullOuter, RightOuter).contains(joinType)
143157
case BroadcastRight =>
144158
Seq(FullOuter, LeftOuter, LeftSemi, LeftAnti).contains(joinType)
145159
})
146160

161+
// For ExistenceJoin with empty probed side, use builtRDD.partitions to ensure
162+
// native join can execute and finish() will output all build rows with exists=false
163+
val (rddPartitions, rddPartitioner, rddDependencies) =
164+
if (probedRDD.partitions.isEmpty && joinType.isInstanceOf[ExistenceJoin]) {
165+
(builtRDD.partitions, builtRDD.partitioner, new OneToOneDependency(builtRDD) :: Nil)
166+
} else {
167+
(probedRDD.partitions, probedRDD.partitioner, new OneToOneDependency(probedRDD) :: Nil)
168+
}
169+
147170
new NativeRDD(
148171
sparkContext,
149172
nativeMetrics,
150-
probedRDD.partitions,
151-
rddPartitioner = probedRDD.partitioner,
152-
rddDependencies = new OneToOneDependency(probedRDD) :: Nil,
173+
rddPartitions,
174+
rddPartitioner = rddPartitioner,
175+
rddDependencies = rddDependencies,
153176
probedShuffleReadFull,
154177
(partition, context) => {
155178
val partition0 = new Partition() {
156179
override def index: Int = 0
157180
}
158-
val (leftChild, rightChild) = broadcastSide match {
159-
case BroadcastLeft =>
160-
(
161-
leftRDD.nativePlan(partition0, context),
162-
rightRDD.nativePlan(rightRDD.partitions(partition.index), context))
163-
case BroadcastRight =>
164-
(
165-
leftRDD.nativePlan(leftRDD.partitions(partition.index), context),
166-
rightRDD.nativePlan(partition0, context))
167-
}
181+
val (leftChild, rightChild) =
182+
if (probedRDD.partitions.isEmpty && joinType.isInstanceOf[ExistenceJoin]) {
183+
val probedSchema = broadcastSide match {
184+
case BroadcastLeft => Util.getNativeSchema(right.output)
185+
case BroadcastRight => Util.getNativeSchema(left.output)
186+
}
187+
val emptyProbedPlan = PhysicalPlanNode
188+
.newBuilder()
189+
.setEmptyPartitions(
190+
EmptyPartitionsExecNode
191+
.newBuilder()
192+
.setNumPartitions(1)
193+
.setSchema(probedSchema)
194+
.build())
195+
.build()
196+
broadcastSide match {
197+
case BroadcastLeft =>
198+
(
199+
leftRDD.nativePlan(leftRDD.partitions(partition.index), context),
200+
emptyProbedPlan)
201+
case BroadcastRight =>
202+
(
203+
emptyProbedPlan,
204+
rightRDD.nativePlan(rightRDD.partitions(partition.index), context))
205+
}
206+
} else {
207+
broadcastSide match {
208+
case BroadcastLeft =>
209+
(
210+
leftRDD.nativePlan(partition0, context),
211+
rightRDD.nativePlan(rightRDD.partitions(partition.index), context))
212+
case BroadcastRight =>
213+
(
214+
leftRDD.nativePlan(leftRDD.partitions(partition.index), context),
215+
rightRDD.nativePlan(partition0, context))
216+
}
217+
}
168218
val cachedBuildHashMapId = s"bhm_stage${context.stageId}_rdd${builtRDD.id}"
169219

170220
val broadcastJoinExec = pb.BroadcastJoinExecNode

0 commit comments

Comments
 (0)