Commit ab7b961

[SPARK-23942][PYTHON][SQL] Makes collect in PySpark as action for a query executor listener
## What changes were proposed in this pull request?

This PR proposes to make `collect` in PySpark run as an action from the point of view of a query execution listener. Currently, `collect` (and `collect` with Arrow enabled) is not recognised as an action by `QueryExecutionListener`. For example, suppose we have a custom listener as below:

```scala
package org.apache.spark.sql

import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener


class TestQueryExecutionListener extends QueryExecutionListener with Logging {
  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
    logError("Look at me! I'm 'onSuccess'")
  }

  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { }
}
```

and set `spark.sql.queryExecutionListeners` to `org.apache.spark.sql.TestQueryExecutionListener`.

Other operations, on both the PySpark and Scala sides, work fine:

```python
>>> sql("SELECT * FROM range(1)").show()
```
```
18/04/09 17:02:04 ERROR TestQueryExecutionListener: Look at me! I'm 'onSuccess'
+---+
| id|
+---+
|  0|
+---+
```

```scala
scala> sql("SELECT * FROM range(1)").collect()
```
```
18/04/09 16:58:41 ERROR TestQueryExecutionListener: Look at me! I'm 'onSuccess'
res1: Array[org.apache.spark.sql.Row] = Array([0])
```

but the PySpark `collect` and Arrow-based `toPandas` paths do not notify the listener:

**Before**

```python
>>> sql("SELECT * FROM range(1)").collect()
```
```
[Row(id=0)]
```

```python
>>> spark.conf.set("spark.sql.execution.arrow.enabled", "true")
>>> sql("SELECT * FROM range(1)").toPandas()
```
```
   id
0   0
```

**After**

```python
>>> sql("SELECT * FROM range(1)").collect()
```
```
18/04/09 16:57:58 ERROR TestQueryExecutionListener: Look at me! I'm 'onSuccess'
[Row(id=0)]
```

```python
>>> spark.conf.set("spark.sql.execution.arrow.enabled", "true")
>>> sql("SELECT * FROM range(1)").toPandas()
```
```
18/04/09 17:53:26 ERROR TestQueryExecutionListener: Look at me! I'm 'onSuccess'
   id
0   0
```

## How was this patch tested?

Manually tested as described above, and a unit test was added.

Author: hyukjinkwon <[email protected]>

Closes apache#21007 from HyukjinKwon/SPARK-23942.
1 parent: 14291b0 · commit: ab7b961
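For readers who want to see the listener behaviour end to end from the Scala side, here is a minimal sketch (the `ListenerDemo` object name and `local[2]` master are illustrative, not part of this PR) that registers a `QueryExecutionListener` programmatically via `spark.listenerManager` instead of the static `spark.sql.queryExecutionListeners` configuration, then runs a `collect()`:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

// Illustrative only: register a listener programmatically rather than through the
// static spark.sql.queryExecutionListeners configuration used in this PR's test.
object ListenerDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("ListenerDemo")
      .getOrCreate()

    spark.listenerManager.register(new QueryExecutionListener {
      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit =
        // funcName is the action name, e.g. "collect" on the Scala side or, after this
        // change, "collectToPython" / "collectAsArrowToPython" when driven from PySpark.
        println(s"onSuccess: $funcName took ${durationNs / 1e6} ms")

      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit =
        println(s"onFailure: $funcName failed with $exception")
    })

    // A Scala-side collect() already triggered onSuccess before this PR; the fix makes
    // the PySpark collect()/toPandas() code paths report to the listener as well.
    spark.sql("SELECT * FROM range(1)").collect()

    spark.stop()
  }
}
```

The PR's test drives everything from Python, where this registration API is not exposed, which is why it instead sets the static `spark.sql.queryExecutionListeners` configuration when building its test session.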

File tree: 3 files changed, +134 −17 lines

python/pyspark/sql/tests.py

Lines changed: 77 additions & 10 deletions
```diff
@@ -186,16 +186,12 @@ def __init__(self, key, value):
         self.value = value
 
 
-class ReusedSQLTestCase(ReusedPySparkTestCase):
-    @classmethod
-    def setUpClass(cls):
-        ReusedPySparkTestCase.setUpClass()
-        cls.spark = SparkSession(cls.sc)
-
-    @classmethod
-    def tearDownClass(cls):
-        ReusedPySparkTestCase.tearDownClass()
-        cls.spark.stop()
+class SQLTestUtils(object):
+    """
+    This util assumes the instance of this to have 'spark' attribute, having a spark session.
+    It is usually used with 'ReusedSQLTestCase' class but can be used if you feel sure the
+    the implementation of this class has 'spark' attribute.
+    """
 
     @contextmanager
     def sql_conf(self, pairs):
@@ -204,6 +200,7 @@ def sql_conf(self, pairs):
         `value` to the configuration `key` and then restores it back when it exits.
         """
         assert isinstance(pairs, dict), "pairs should be a dictionary."
+        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."
 
         keys = pairs.keys()
         new_values = pairs.values()
@@ -219,6 +216,18 @@ def sql_conf(self, pairs):
             else:
                 self.spark.conf.set(key, old_value)
 
+
+class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils):
+    @classmethod
+    def setUpClass(cls):
+        ReusedPySparkTestCase.setUpClass()
+        cls.spark = SparkSession(cls.sc)
+
+    @classmethod
+    def tearDownClass(cls):
+        ReusedPySparkTestCase.tearDownClass()
+        cls.spark.stop()
+
     def assertPandasEqual(self, expected, result):
         msg = ("DataFrames are not equal: " +
                "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) +
@@ -3066,6 +3075,64 @@ def test_sparksession_with_stopped_sparkcontext(self):
             sc.stop()
 
 
+class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils):
+    # These tests are separate because it uses 'spark.sql.queryExecutionListeners' which is
+    # static and immutable. This can't be set or unset, for example, via `spark.conf`.
+
+    @classmethod
+    def setUpClass(cls):
+        import glob
+        from pyspark.find_spark_home import _find_spark_home
+
+        SPARK_HOME = _find_spark_home()
+        filename_pattern = (
+            "sql/core/target/scala-*/test-classes/org/apache/spark/sql/"
+            "TestQueryExecutionListener.class")
+        if not glob.glob(os.path.join(SPARK_HOME, filename_pattern)):
+            raise unittest.SkipTest(
+                "'org.apache.spark.sql.TestQueryExecutionListener' is not "
+                "available. Will skip the related tests.")
+
+        # Note that 'spark.sql.queryExecutionListeners' is a static immutable configuration.
+        cls.spark = SparkSession.builder \
+            .master("local[4]") \
+            .appName(cls.__name__) \
+            .config(
+                "spark.sql.queryExecutionListeners",
+                "org.apache.spark.sql.TestQueryExecutionListener") \
+            .getOrCreate()
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.spark.stop()
+
+    def tearDown(self):
+        self.spark._jvm.OnSuccessCall.clear()
+
+    def test_query_execution_listener_on_collect(self):
+        self.assertFalse(
+            self.spark._jvm.OnSuccessCall.isCalled(),
+            "The callback from the query execution listener should not be called before 'collect'")
+        self.spark.sql("SELECT * FROM range(1)").collect()
+        self.assertTrue(
+            self.spark._jvm.OnSuccessCall.isCalled(),
+            "The callback from the query execution listener should be called after 'collect'")
+
+    @unittest.skipIf(
+        not _have_pandas or not _have_pyarrow,
+        _pandas_requirement_message or _pyarrow_requirement_message)
+    def test_query_execution_listener_on_collect_with_arrow(self):
+        with self.sql_conf({"spark.sql.execution.arrow.enabled": True}):
+            self.assertFalse(
+                self.spark._jvm.OnSuccessCall.isCalled(),
+                "The callback from the query execution listener should not be "
+                "called before 'toPandas'")
+            self.spark.sql("SELECT * FROM range(1)").toPandas()
+            self.assertTrue(
+                self.spark._jvm.OnSuccessCall.isCalled(),
+                "The callback from the query execution listener should be called after 'toPandas'")
+
+
 class SparkSessionTests(PySparkTestCase):
 
     # This test is separate because it's closely related with session's start and stop.
```

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 13 additions & 7 deletions
```diff
@@ -3189,10 +3189,10 @@ class Dataset[T] private[sql](
 
   private[sql] def collectToPython(): Int = {
     EvaluatePython.registerPicklers()
-    withNewExecutionId {
+    withAction("collectToPython", queryExecution) { plan =>
       val toJava: (Any) => Any = EvaluatePython.toJava(_, schema)
-      val iter = new SerDeUtil.AutoBatchedPickler(
-        queryExecution.executedPlan.executeCollect().iterator.map(toJava))
+      val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler(
+        plan.executeCollect().iterator.map(toJava))
       PythonRDD.serveIterator(iter, "serve-DataFrame")
     }
   }
@@ -3201,8 +3201,9 @@ class Dataset[T] private[sql](
    * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark.
    */
   private[sql] def collectAsArrowToPython(): Int = {
-    withNewExecutionId {
-      val iter = toArrowPayload.collect().iterator.map(_.asPythonSerializable)
+    withAction("collectAsArrowToPython", queryExecution) { plan =>
+      val iter: Iterator[Array[Byte]] =
+        toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable)
       PythonRDD.serveIterator(iter, "serve-Arrow")
     }
   }
@@ -3311,14 +3312,19 @@ class Dataset[T] private[sql](
   }
 
   /** Convert to an RDD of ArrowPayload byte arrays */
-  private[sql] def toArrowPayload: RDD[ArrowPayload] = {
+  private[sql] def toArrowPayload(plan: SparkPlan): RDD[ArrowPayload] = {
     val schemaCaptured = this.schema
     val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
     val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
-    queryExecution.toRdd.mapPartitionsInternal { iter =>
+    plan.execute().mapPartitionsInternal { iter =>
       val context = TaskContext.get()
       ArrowConverters.toPayloadIterator(
         iter, schemaCaptured, maxRecordsPerBatch, timeZoneId, context)
     }
   }
+
+  // This is only used in tests, for now.
+  private[sql] def toArrowPayload: RDD[ArrowPayload] = {
+    toArrowPayload(queryExecution.executedPlan)
+  }
 }
```
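The key change above is routing `collectToPython` and `collectAsArrowToPython` through `withAction` instead of a bare `withNewExecutionId`. Roughly, `withAction` runs the physical plan under a new execution id, measures the duration, and notifies the session's `ExecutionListenerManager`, which is what makes `QueryExecutionListener.onSuccess` / `onFailure` fire. A simplified sketch of that wrapper is below; it is a paraphrase, not the exact Spark source, and it is not runnable on its own since it assumes the surrounding `Dataset.scala` context (`sparkSession`, `SQLExecution`, `QueryExecution`, `SparkPlan`):

```scala
// Simplified sketch of a withAction-style wrapper (paraphrased, not the exact Spark source).
private def withAction[U](name: String, qe: QueryExecution)(action: SparkPlan => U): U = {
  try {
    val start = System.nanoTime()
    val result = SQLExecution.withNewExecutionId(sparkSession, qe) {
      action(qe.executedPlan)  // run the physical plan, e.g. executeCollect()
    }
    val end = System.nanoTime()
    // This callback is what QueryExecutionListener.onSuccess ultimately receives.
    sparkSession.listenerManager.onSuccess(name, qe, end - start)
    result
  } catch {
    case e: Exception =>
      sparkSession.listenerManager.onFailure(name, qe, e)
      throw e
  }
}
```

Because both PySpark entry points now go through this wrapper, the listener observes `collect()` and Arrow-based `toPandas()` just as it already observed a Scala-side `collect()`.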
sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala (new file)

Lines changed: 44 additions & 0 deletions

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql

import java.util.concurrent.atomic.AtomicBoolean

import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener


class TestQueryExecutionListener extends QueryExecutionListener {
  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
    OnSuccessCall.isOnSuccessCalled.set(true)
  }

  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { }
}

/**
 * This has a variable to check if `onSuccess` is actually called or not. Currently, this is for
 * the test case in PySpark. See SPARK-23942.
 */
object OnSuccessCall {
  val isOnSuccessCalled = new AtomicBoolean(false)

  def isCalled(): Boolean = isOnSuccessCalled.get()

  def clear(): Unit = isOnSuccessCalled.set(false)
}
```
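For reference, a minimal sketch of how this helper could be exercised from a Scala test; the PR's actual test drives the equivalent steps from Python through `spark._jvm.OnSuccessCall`. It assumes a `SparkSession` named `spark` that was built with `spark.sql.queryExecutionListeners` set to `org.apache.spark.sql.TestQueryExecutionListener`:

```scala
// Minimal sketch, assuming `spark` was created with
// .config("spark.sql.queryExecutionListeners", "org.apache.spark.sql.TestQueryExecutionListener").
assert(!OnSuccessCall.isCalled())              // nothing has run yet
spark.sql("SELECT * FROM range(1)").collect()  // collect is now reported as an action
assert(OnSuccessCall.isCalled())               // the listener's onSuccess flipped the flag
OnSuccessCall.clear()                          // reset the flag between tests
```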
