Skip to content

Commit 5866561

Browse files
authored
feat: Set/cancel with job tag and make max broadcast table size configurable (apache#1693)
1 parent e26d8d1 commit 5866561

File tree

5 files changed

+168
-12
lines changed

5 files changed

+168
-12
lines changed

spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ import org.apache.spark.util.io.ChunkedByteBuffer
4646
import com.google.common.base.Objects
4747

4848
import org.apache.comet.CometRuntimeException
49+
import org.apache.comet.shims.ShimCometBroadcastExchangeExec
4950

5051
/**
5152
* A [[CometBroadcastExchangeExec]] collects, transforms and finally broadcasts the result of a
@@ -64,8 +65,8 @@ case class CometBroadcastExchangeExec(
6465
mode: BroadcastMode,
6566
override val child: SparkPlan)
6667
extends BroadcastExchangeLike
68+
with ShimCometBroadcastExchangeExec
6769
with CometPlan {
68-
import CometBroadcastExchangeExec._
6970

7071
override val runId: UUID = UUID.randomUUID
7172

@@ -117,11 +118,7 @@ case class CometBroadcastExchangeExec(
117118
session,
118119
CometBroadcastExchangeExec.executionContext) {
119120
try {
120-
// Setup a job group here so later it may get cancelled by groupId if necessary.
121-
sparkContext.setJobGroup(
122-
runId.toString,
123-
s"broadcast exchange (runId $runId)",
124-
interruptOnCancel = true)
121+
setJobGroupOrTag(sparkContext, this)
125122
val beforeCollect = System.nanoTime()
126123

127124
val countsAndBytes = child match {
@@ -167,9 +164,10 @@ case class CometBroadcastExchangeExec(
167164
val dataSize = batches.map(_.size).sum
168165

169166
longMetric("dataSize") += dataSize
170-
if (dataSize >= MAX_BROADCAST_TABLE_BYTES) {
167+
val maxBytes = maxBroadcastTableBytes(conf)
168+
if (dataSize >= maxBytes) {
171169
throw QueryExecutionErrors.cannotBroadcastTableOverMaxTableBytesError(
172-
MAX_BROADCAST_TABLE_BYTES,
170+
maxBytes,
173171
dataSize)
174172
}
175173

@@ -233,7 +231,7 @@ case class CometBroadcastExchangeExec(
233231
case ex: TimeoutException =>
234232
logError(s"Could not execute broadcast in $timeout secs.", ex)
235233
if (!relationFuture.isDone) {
236-
sparkContext.cancelJobGroup(runId.toString)
234+
cancelJobGroup(sparkContext, this)
237235
relationFuture.cancel(true)
238236
}
239237
throw QueryExecutionErrors.executeBroadcastTimeoutError(timeout, Some(ex))
@@ -259,8 +257,6 @@ case class CometBroadcastExchangeExec(
259257
}
260258

261259
object CometBroadcastExchangeExec {
262-
val MAX_BROADCAST_TABLE_BYTES: Long = 8L << 30
263-
264260
private[comet] val executionContext = ExecutionContext.fromExecutorService(
265261
ThreadUtils.newDaemonCachedThreadPool(
266262
"comet-broadcast-exchange",
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.apache.comet.shims
20+
21+
import org.apache.comet.shims.ShimCometBroadcastExchangeExec.SPARK_MAX_BROADCAST_TABLE_SIZE
22+
import org.apache.spark.SparkContext
23+
import org.apache.spark.network.util.JavaUtils
24+
import org.apache.spark.sql.execution.exchange.BroadcastExchangeLike
25+
import org.apache.spark.sql.internal.SQLConf
26+
27+
/**
 * Shim hooks mixed into the Comet broadcast exchange. This variant tracks the Spark jobs of a
 * broadcast through a job group keyed by the exchange's `runId`.
 */
trait ShimCometBroadcastExchangeExec {

  /**
   * Registers a job group for the given exchange so that its jobs may later be cancelled by
   * group id if necessary.
   *
   * @param sc the Spark context on which the job group is set
   * @param broadcastExchange the exchange whose `runId` identifies the group
   */
  def setJobGroupOrTag(sc: SparkContext, broadcastExchange: BroadcastExchangeLike): Unit = {
    val exchangeRunId = broadcastExchange.runId
    val description = s"broadcast exchange (runId $exchangeRunId)"
    sc.setJobGroup(exchangeRunId.toString, description, interruptOnCancel = true)
  }

  /**
   * Cancels the job group whose id is the `runId` of the given exchange.
   *
   * @param sc the Spark context on which the cancellation is issued
   * @param broadcastExchange the exchange whose `runId` identifies the group
   */
  def cancelJobGroup(sc: SparkContext, broadcastExchange: BroadcastExchangeLike): Unit = {
    val groupId = broadcastExchange.runId.toString
    sc.cancelJobGroup(groupId)
  }

  /**
   * Resolves the maximum broadcast table size in bytes.
   *
   * Reads the `SPARK_MAX_BROADCAST_TABLE_SIZE` key from `conf`, falling back to "8GB" when it is
   * not set, and converts the byte string to a number of bytes.
   *
   * @param conf the SQL configuration to read from
   * @return the configured limit in bytes
   */
  def maxBroadcastTableBytes(conf: SQLConf): Long = {
    val configuredSize = conf.getConfString(SPARK_MAX_BROADCAST_TABLE_SIZE, "8GB")
    JavaUtils.byteStringAsBytes(configuredSize)
  }

}
46+
47+
/** Constants shared with the [[ShimCometBroadcastExchangeExec]] trait. */
object ShimCometBroadcastExchangeExec {

  /** Configuration key holding the maximum size of a table that may be broadcast. */
  val SPARK_MAX_BROADCAST_TABLE_SIZE: String = "spark.sql.maxBroadcastTableSize"
}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.apache.comet.shims
20+
21+
import org.apache.comet.shims.ShimCometBroadcastExchangeExec.SPARK_MAX_BROADCAST_TABLE_SIZE
22+
import org.apache.spark.SparkContext
23+
import org.apache.spark.network.util.JavaUtils
24+
import org.apache.spark.sql.execution.exchange.BroadcastExchangeLike
25+
import org.apache.spark.sql.internal.SQLConf
26+
27+
/**
 * Shim hooks mixed into the Comet broadcast exchange. This variant tracks the Spark jobs of a
 * broadcast through the exchange's job tag.
 */
trait ShimCometBroadcastExchangeExec {

  /**
   * Adds the exchange's job tag to the context so that its jobs may later be cancelled by tag
   * if necessary, then turns on interrupt-on-cancel.
   *
   * @param sc the Spark context on which the tag is added
   * @param broadcastExchange the exchange providing the `jobTag`
   */
  def setJobGroupOrTag(sc: SparkContext, broadcastExchange: BroadcastExchangeLike): Unit = {
    val tag = broadcastExchange.jobTag
    sc.addJobTag(tag)
    sc.setInterruptOnCancel(true)
  }

  /**
   * Cancels the jobs carrying the given exchange's job tag.
   *
   * @param sc the Spark context on which the cancellation is issued
   * @param broadcastExchange the exchange providing the `jobTag`
   */
  def cancelJobGroup(sc: SparkContext, broadcastExchange: BroadcastExchangeLike): Unit = {
    val tag = broadcastExchange.jobTag
    sc.cancelJobsWithTag(tag)
  }

  /**
   * Resolves the maximum broadcast table size in bytes.
   *
   * Reads the `SPARK_MAX_BROADCAST_TABLE_SIZE` key from `conf`, falling back to "8GB" when it is
   * not set, and converts the byte string to a number of bytes.
   *
   * @param conf the SQL configuration to read from
   * @return the configured limit in bytes
   */
  def maxBroadcastTableBytes(conf: SQLConf): Long = {
    val configuredSize = conf.getConfString(SPARK_MAX_BROADCAST_TABLE_SIZE, "8GB")
    JavaUtils.byteStringAsBytes(configuredSize)
  }

}
44+
45+
/** Constants shared with the [[ShimCometBroadcastExchangeExec]] trait. */
object ShimCometBroadcastExchangeExec {

  /** Configuration key holding the maximum size of a table that may be broadcast. */
  val SPARK_MAX_BROADCAST_TABLE_SIZE: String = "spark.sql.maxBroadcastTableSize"
}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.apache.comet.shims
20+
21+
import org.apache.comet.shims.ShimCometBroadcastExchangeExec.SPARK_MAX_BROADCAST_TABLE_SIZE
22+
import org.apache.spark.SparkContext
23+
import org.apache.spark.network.util.JavaUtils
24+
import org.apache.spark.sql.execution.exchange.BroadcastExchangeLike
25+
import org.apache.spark.sql.internal.SQLConf
26+
27+
/**
 * Shim hooks mixed into the Comet broadcast exchange. This variant identifies the Spark jobs of
 * a broadcast by the exchange's job tag.
 */
trait ShimCometBroadcastExchangeExec {

  /**
   * Tags subsequent jobs with the exchange's job tag, so they may later be cancelled by tag if
   * necessary, and turns on interrupt-on-cancel.
   *
   * @param sc the Spark context on which the tag is added
   * @param broadcastExchange the exchange providing the `jobTag`
   */
  def setJobGroupOrTag(sc: SparkContext, broadcastExchange: BroadcastExchangeLike): Unit = {
    val exchangeTag = broadcastExchange.jobTag
    sc.addJobTag(exchangeTag)
    sc.setInterruptOnCancel(true)
  }

  /**
   * Cancels every job carrying the given exchange's job tag.
   *
   * @param sc the Spark context on which the cancellation is issued
   * @param broadcastExchange the exchange providing the `jobTag`
   */
  def cancelJobGroup(sc: SparkContext, broadcastExchange: BroadcastExchangeLike): Unit = {
    val exchangeTag = broadcastExchange.jobTag
    sc.cancelJobsWithTag(exchangeTag)
  }

  /**
   * Resolves the maximum broadcast table size in bytes.
   *
   * Looks up the `SPARK_MAX_BROADCAST_TABLE_SIZE` key in `conf`, using "8GB" when it is not set,
   * and converts the resulting byte string to a number of bytes.
   *
   * @param conf the SQL configuration to read from
   * @return the configured limit in bytes
   */
  def maxBroadcastTableBytes(conf: SQLConf): Long = {
    val sizeString = conf.getConfString(SPARK_MAX_BROADCAST_TABLE_SIZE, "8GB")
    JavaUtils.byteStringAsBytes(sizeString)
  }

}
44+
45+
/** Constants shared with the [[ShimCometBroadcastExchangeExec]] trait. */
object ShimCometBroadcastExchangeExec {

  /** Configuration key holding the maximum size of a table that may be broadcast. */
  val SPARK_MAX_BROADCAST_TABLE_SIZE: String = "spark.sql.maxBroadcastTableSize"
}

spark/src/test/scala/org/apache/comet/CometNativeSuite.scala

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ package org.apache.comet
2222
import org.apache.spark.{SparkEnv, SparkException}
2323
import org.apache.spark.sql.CometTestBase
2424
import org.apache.spark.sql.catalyst.expressions.PrettyAttribute
25-
import org.apache.spark.sql.comet.{CometExec, CometExecUtils}
25+
import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometExec, CometExecUtils}
2626
import org.apache.spark.sql.types.LongType
2727
import org.apache.spark.sql.vectorized.ColumnarBatch
2828

@@ -97,4 +97,21 @@ class CometNativeSuite extends CometTestBase {
9797
}
9898
}
9999
}
100+
101+
// Checks that the broadcast size limit is read from `spark.sql.maxBroadcastTableSize`:
// with a 10-byte limit the broadcast join must fail with the "larger than 10.0 B" error,
// and the executed plan must contain exactly one CometBroadcastExchangeExec.
test("test maxBroadcastTableSize") {
  withSQLConf("spark.sql.maxBroadcastTableSize" -> "10B") {
    // Scope the temp views to this test so they are dropped afterwards instead of
    // leaking into the session shared with the other tests of the suite.
    withTempView("t1", "t2") {
      spark.range(0, 1000).createOrReplaceTempView("t1")
      spark.range(0, 100).createOrReplaceTempView("t2")
      val df = spark.sql("select /*+ BROADCAST(t2) */ * from t1 join t2 on t1.id = t2.id")
      val exception = intercept[SparkException] {
        df.collect()
      }
      assert(
        exception.getMessage.contains("Cannot broadcast the table that is larger than 10.0 B"))
      val broadcasts = collect(df.queryExecution.executedPlan) {
        case p: CometBroadcastExchangeExec => p
      }
      assert(broadcasts.size == 1)
    }
  }
}
100117
}

0 commit comments

Comments
 (0)