Commit c277afb
[SPARK-27992][PYTHON] Allow Python to join with connection thread to propagate errors
## What changes were proposed in this pull request?

Currently with `toLocalIterator()` and `toPandas()` with Arrow enabled, if the Spark job running in the background serving thread errors, the error is caught and sent to Python through the PySpark serializer. This is not ideal: it only catches a SparkException, it won't handle an error that occurs in the serializer, and each method has to have its own special handling to propagate the error. This PR instead returns the Python server object along with the serving port and authentication info, which allows the Python caller to join with the serving thread. During the call to join, the serving thread's Future is completed either successfully or with an exception; in the latter case, the exception is propagated to Python through the Py4J call.

## How was this patch tested?

Existing tests.

Closes apache#24834 from BryanCutler/pyspark-propagate-server-error-SPARK-27992.

Authored-by: Bryan Cutler <[email protected]>
Signed-off-by: Bryan Cutler <[email protected]>
1 parent 7eeca02 commit c277afb
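A minimal Scala sketch of the join mechanism the commit message describes. The `getResult` body is not part of this diff, so treating it as a blocking wait on the promise is an assumption; the names `JoinSketch`, `handleConnection`, and `start` are illustrative only:

```scala
import scala.concurrent.{Await, Promise}
import scala.concurrent.duration.Duration
import scala.util.Try

object JoinSketch {
  // The serving thread completes this promise, either successfully or with
  // the exception thrown while handling the connection.
  private val promise = Promise[Unit]()

  def handleConnection(): Unit = { /* serve data to the client; may throw */ }

  def start(): Unit = {
    val t = new Thread("serve sketch") {
      override def run(): Unit = promise.complete(Try(handleConnection()))
    }
    t.setDaemon(true)
    t.start()
  }

  // Exposed to Python over Py4J: joins the serving thread by waiting on the
  // promise; Await.result re-throws any recorded exception in the caller.
  def getResult(): Unit = Await.result(promise.future, Duration.Inf)
}
```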

9 files changed: +161 -126 lines changed

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 49 additions & 41 deletions
```diff
@@ -38,7 +38,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.BUFFER_SIZE
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.rdd.RDD
-import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer}
+import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer, SocketFuncServer}
 import org.apache.spark.util._

@@ -137,8 +137,9 @@ private[spark] object PythonRDD extends Logging {
    * (effectively a collect()), but allows you to run on a certain subset of partitions,
    * or to enable local execution.
    *
-   * @return 2-tuple (as a Java array) with the port number of a local socket which serves the
-   *         data collected from this job, and the secret for authentication.
+   * @return 3-tuple (as a Java array) with the port number of a local socket which serves the
+   *         data collected from this job, the secret for authentication, and a socket auth
+   *         server object that can be used to join the JVM serving thread in Python.
    */
   def runJob(
       sc: SparkContext,
@@ -156,8 +157,9 @@ private[spark] object PythonRDD extends Logging {
   /**
    * A helper function to collect an RDD as an iterator, then serve it via socket.
    *
-   * @return 2-tuple (as a Java array) with the port number of a local socket which serves the
-   *         data collected from this job, and the secret for authentication.
+   * @return 3-tuple (as a Java array) with the port number of a local socket which serves the
+   *         data collected from this job, the secret for authentication, and a socket auth
+   *         server object that can be used to join the JVM serving thread in Python.
    */
   def collectAndServe[T](rdd: RDD[T]): Array[Any] = {
     serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
@@ -168,58 +170,59 @@ private[spark] object PythonRDD extends Logging {
    * are collected as separate jobs, by order of index. Partition data is first requested by a
    * non-zero integer to start a collection job. The response is prefaced by an integer with 1
    * meaning partition data will be served, 0 meaning the local iterator has been consumed,
-   * and -1 meaining an error occurred during collection. This function is used by
+   * and -1 meaning an error occurred during collection. This function is used by
    * pyspark.rdd._local_iterator_from_socket().
    *
-   * @return 2-tuple (as a Java array) with the port number of a local socket which serves the
-   *         data collected from these jobs, and the secret for authentication.
+   * @return 3-tuple (as a Java array) with the port number of a local socket which serves the
+   *         data collected from this job, the secret for authentication, and a socket auth
+   *         server object that can be used to join the JVM serving thread in Python.
    */
   def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = {
-    val (port, secret) = SocketAuthServer.setupOneConnectionServer(
-      authHelper, "serve toLocalIterator") { s =>
-      val out = new DataOutputStream(s.getOutputStream)
-      val in = new DataInputStream(s.getInputStream)
-      Utils.tryWithSafeFinally {
+    val handleFunc = (sock: Socket) => {
+      val out = new DataOutputStream(sock.getOutputStream)
+      val in = new DataInputStream(sock.getInputStream)
+      Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
         // Collects a partition on each iteration
         val collectPartitionIter = rdd.partitions.indices.iterator.map { i =>
           rdd.sparkContext.runJob(rdd, (iter: Iterator[Any]) => iter.toArray, Seq(i)).head
         }

-        // Read request for data and send next partition if nonzero
+        // Write data until iteration is complete, client stops iteration, or error occurs
         var complete = false
-        while (!complete && in.readInt() != 0) {
-          if (collectPartitionIter.hasNext) {
-            try {
-              // Attempt to collect the next partition
-              val partitionArray = collectPartitionIter.next()
-
-              // Send response there is a partition to read
-              out.writeInt(1)
-
-              // Write the next object and signal end of data for this iteration
-              writeIteratorToStream(partitionArray.toIterator, out)
-              out.writeInt(SpecialLengths.END_OF_DATA_SECTION)
-              out.flush()
-            } catch {
-              case e: SparkException =>
-                // Send response that an error occurred followed by error message
-                out.writeInt(-1)
-                writeUTF(e.getMessage, out)
-                complete = true
-            }
+        while (!complete) {
+
+          // Read request for data, value of zero will stop iteration or non-zero to continue
+          if (in.readInt() == 0) {
+            complete = true
+          } else if (collectPartitionIter.hasNext) {
+
+            // Client requested more data, attempt to collect the next partition
+            val partitionArray = collectPartitionIter.next()
+
+            // Send response there is a partition to read
+            out.writeInt(1)
+
+            // Write the next object and signal end of data for this iteration
+            writeIteratorToStream(partitionArray.toIterator, out)
+            out.writeInt(SpecialLengths.END_OF_DATA_SECTION)
+            out.flush()
           } else {
            // Send response there are no more partitions to read and close
            out.writeInt(0)
            complete = true
          }
        }
-      } {
+      })(catchBlock = {
+        // Send response that an error occurred, original exception is re-thrown
+        out.writeInt(-1)
+      }, finallyBlock = {
        out.close()
        in.close()
-      }
+      })
     }
-    Array(port, secret)
+
+    val server = new SocketFuncServer(authHelper, "serve toLocalIterator", handleFunc)
+    Array(server.port, server.secret, server)
   }

   def readRDDFromFile(
@@ -443,8 +446,9 @@ private[spark] object PythonRDD extends Logging {
    *
    * The thread will terminate after all the data are sent or any exceptions happen.
    *
-   * @return 2-tuple (as a Java array) with the port number of a local socket which serves the
-   *         data collected from this job, and the secret for authentication.
+   * @return 3-tuple (as a Java array) with the port number of a local socket which serves the
+   *         data collected from this job, the secret for authentication, and a socket auth
+   *         server object that can be used to join the JVM serving thread in Python.
    */
   def serveIterator(items: Iterator[_], threadName: String): Array[Any] = {
     serveToStream(threadName) { out =>
@@ -464,10 +468,14 @@ private[spark] object PythonRDD extends Logging {
    *
    * The thread will terminate after the block of code is executed or any
    * exceptions happen.
+   *
+   * @return 3-tuple (as a Java array) with the port number of a local socket which serves the
+   *         data collected from this job, the secret for authentication, and a socket auth
+   *         server object that can be used to join the JVM serving thread in Python.
    */
   private[spark] def serveToStream(
       threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
-    SocketAuthHelper.serveToStream(threadName, authHelper)(writeFunc)
+    SocketAuthServer.serveToStream(threadName, authHelper)(writeFunc)
   }

   private def getMergedConf(confAsMap: java.util.HashMap[String, String],
```
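The scaladoc above defines the request/response protocol between the client and the serving thread. A hedged client-side sketch, written in Scala for illustration (the real client lives in `pyspark.rdd._local_iterator_from_socket`, and `iterateOverSocket` is a hypothetical name):

```scala
import java.io.{DataInputStream, DataOutputStream}
import java.net.Socket

// Illustrative client loop for the protocol documented above: send a non-zero
// int to request the next partition (0 to stop early), then read the response
// code: 1 = partition data follows, 0 = iterator exhausted, -1 = a collection
// error occurred (details come from joining the serving thread afterwards).
def iterateOverSocket(sock: Socket, readPartition: DataInputStream => Unit): Unit = {
  val out = new DataOutputStream(sock.getOutputStream)
  val in = new DataInputStream(sock.getInputStream)
  var done = false
  while (!done) {
    out.writeInt(1) // request more data
    out.flush()
    in.readInt() match {
      case 1  => readPartition(in) // partition data, terminated by END_OF_DATA_SECTION
      case 0  => done = true       // all partitions consumed
      case -1 => done = true       // error; join the server thread for the cause
    }
  }
}
```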

core/src/main/scala/org/apache/spark/api/r/RRDD.scala

Lines changed: 2 additions & 2 deletions
```diff
@@ -29,7 +29,7 @@ import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
-import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer}
+import org.apache.spark.security.SocketAuthServer

 private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
     parent: RDD[T],
@@ -166,7 +166,7 @@ private[spark] object RRDD {

   private[spark] def serveToStream(
       threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
-    SocketAuthHelper.serveToStream(threadName, new RAuthHelper(SparkEnv.get.conf))(writeFunc)
+    SocketAuthServer.serveToStream(threadName, new RAuthHelper(SparkEnv.get.conf))(writeFunc)
   }
 }
```

core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala

Lines changed: 1 addition & 18 deletions
```diff
@@ -17,7 +17,7 @@

 package org.apache.spark.security

-import java.io.{BufferedOutputStream, DataInputStream, DataOutputStream, OutputStream}
+import java.io.{DataInputStream, DataOutputStream}
 import java.net.Socket
 import java.nio.charset.StandardCharsets.UTF_8

@@ -113,21 +113,4 @@ private[spark] class SocketAuthHelper(conf: SparkConf) {
     dout.write(bytes, 0, bytes.length)
     dout.flush()
   }
-
-}
-
-private[spark] object SocketAuthHelper {
-  def serveToStream(
-      threadName: String,
-      authHelper: SocketAuthHelper)(writeFunc: OutputStream => Unit): Array[Any] = {
-    val (port, secret) = SocketAuthServer.setupOneConnectionServer(authHelper, threadName) { s =>
-      val out = new BufferedOutputStream(s.getOutputStream())
-      Utils.tryWithSafeFinally {
-        writeFunc(out)
-      } {
-        out.close()
-      }
-    }
-    Array(port, secret)
-  }
 }
```

core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala

Lines changed: 63 additions & 31 deletions
```diff
@@ -17,6 +17,7 @@

 package org.apache.spark.security

+import java.io.{BufferedOutputStream, OutputStream}
 import java.net.{InetAddress, ServerSocket, Socket}

 import scala.concurrent.Promise
@@ -25,12 +26,15 @@ import scala.util.Try

 import org.apache.spark.SparkEnv
 import org.apache.spark.network.util.JavaUtils
-import org.apache.spark.util.ThreadUtils
+import org.apache.spark.util.{ThreadUtils, Utils}


 /**
  * Creates a server in the JVM to communicate with external processes (e.g., Python and R) for
  * handling one batch of data, with authentication and error handling.
+ *
+ * The socket server can only accept one connection, or close if no connection
+ * in 15 seconds.
  */
 private[spark] abstract class SocketAuthServer[T](
     authHelper: SocketAuthHelper,
@@ -41,10 +45,30 @@ private[spark] abstract class SocketAuthServer[T](

   private val promise = Promise[T]()

-  val (port, secret) = SocketAuthServer.setupOneConnectionServer(authHelper, threadName) { sock =>
-    promise.complete(Try(handleConnection(sock)))
+  private def startServer(): (Int, String) = {
+    val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
+    // Close the socket if no connection in 15 seconds
+    serverSocket.setSoTimeout(15000)
+
+    new Thread(threadName) {
+      setDaemon(true)
+      override def run(): Unit = {
+        var sock: Socket = null
+        try {
+          sock = serverSocket.accept()
+          authHelper.authClient(sock)
+          promise.complete(Try(handleConnection(sock)))
+        } finally {
+          JavaUtils.closeQuietly(serverSocket)
+          JavaUtils.closeQuietly(sock)
+        }
+      }
+    }.start()
+    (serverSocket.getLocalPort, authHelper.secret)
   }

+  val (port, secret) = startServer()
+
   /**
    * Handle a connection which has already been authenticated. Any error from this function
    * will clean up this connection and the entire server, and get propagated to [[getResult]].
@@ -66,42 +90,50 @@ private[spark] abstract class SocketAuthServer[T](

 }

+/**
+ * Create a socket server class and run user function on the socket in a background thread
+ * that can read and write to the socket input/output streams. The function is passed in a
+ * socket that has been connected and authenticated.
+ */
+private[spark] class SocketFuncServer(
+    authHelper: SocketAuthHelper,
+    threadName: String,
+    func: Socket => Unit) extends SocketAuthServer[Unit](authHelper, threadName) {
+
+  override def handleConnection(sock: Socket): Unit = {
+    func(sock)
+  }
+}
+
 private[spark] object SocketAuthServer {

   /**
-   * Create a socket server and run user function on the socket in a background thread.
+   * Convenience function to create a socket server and run a user function in a background
+   * thread to write to an output stream.
    *
    * The socket server can only accept one connection, or close if no connection
    * in 15 seconds.
    *
-   * The thread will terminate after the supplied user function, or if there are any exceptions.
-   *
-   * If you need to get a result of the supplied function, create a subclass of [[SocketAuthServer]]
-   *
-   * @return The port number of a local socket and the secret for authentication.
+   * @param threadName Name for the background serving thread.
+   * @param authHelper SocketAuthHelper for authentication
+   * @param writeFunc User function to write to a given OutputStream
+   * @return 3-tuple (as a Java array) with the port number of a local socket which serves the
+   *         data collected from this job, the secret for authentication, and a socket auth
+   *         server object that can be used to join the JVM serving thread in Python.
    */
-  def setupOneConnectionServer(
-      authHelper: SocketAuthHelper,
-      threadName: String)
-      (func: Socket => Unit): (Int, String) = {
-    val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
-    // Close the socket if no connection in 15 seconds
-    serverSocket.setSoTimeout(15000)
-
-    new Thread(threadName) {
-      setDaemon(true)
-      override def run(): Unit = {
-        var sock: Socket = null
-        try {
-          sock = serverSocket.accept()
-          authHelper.authClient(sock)
-          func(sock)
-        } finally {
-          JavaUtils.closeQuietly(serverSocket)
-          JavaUtils.closeQuietly(sock)
-        }
+  def serveToStream(
+      threadName: String,
+      authHelper: SocketAuthHelper)(writeFunc: OutputStream => Unit): Array[Any] = {
+    val handleFunc = (sock: Socket) => {
+      val out = new BufferedOutputStream(sock.getOutputStream())
+      Utils.tryWithSafeFinally {
+        writeFunc(out)
+      } {
+        out.close()
       }
-    }.start()
-    (serverSocket.getLocalPort, authHelper.secret)
+    }
+
+    val server = new SocketFuncServer(authHelper, threadName, handleFunc)
+    Array(server.port, server.secret, server)
   }
 }
```
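A hedged usage sketch of the new `SocketFuncServer`. The names are taken from the diff above; `getResult` is referenced by the scaladoc but its body is outside this diff, so treating it as a blocking join is an assumption, and the classes are `private[spark]`, so this only compiles inside a Spark package:

```scala
import java.io.DataOutputStream
import java.net.Socket

import org.apache.spark.SparkConf
import org.apache.spark.security.{SocketAuthHelper, SocketFuncServer}

val authHelper = new SocketAuthHelper(new SparkConf())

// The handler runs on the background serving thread once a client connects
// and authenticates; any exception it throws completes the promise as failed.
val server = new SocketFuncServer(authHelper, "serve sketch", (sock: Socket) => {
  val out = new DataOutputStream(sock.getOutputStream)
  out.writeInt(42)
  out.flush()
})

// The shape handed back to Python over Py4J: connection info plus the server
// itself, so the caller can later join the serving thread via getResult().
val sockInfo: Array[Any] = Array(server.port, server.secret, server)
```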

core/src/main/scala/org/apache/spark/util/Utils.scala

Lines changed: 3 additions & 1 deletion
```diff
@@ -1389,7 +1389,9 @@ private[spark] object Utils extends Logging {
       originalThrowable = cause
       try {
         logError("Aborting task", originalThrowable)
-        TaskContext.get().markTaskFailed(originalThrowable)
+        if (TaskContext.get() != null) {
+          TaskContext.get().markTaskFailed(originalThrowable)
+        }
         catchBlock
       } catch {
         case t: Throwable =>
```

This guard is needed because `tryWithSafeFinallyAndFailureCallbacks` now also runs on the driver-side serving thread, where no TaskContext is set; without it, `TaskContext.get()` would return null and the resulting NullPointerException would mask the original error.
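To make the control flow concrete, here is a simplified model of `tryWithSafeFinallyAndFailureCallbacks`. This is an illustrative assumption, not the real `Utils` implementation, which additionally suppresses secondary exceptions and marks task failure as patched above:

```scala
// Sketch: the catch block runs a callback, then re-throws the original error
// so it reaches the SocketAuthServer promise and, from there, the Python
// caller that joins the serving thread.
def tryWithCallbacksSketch[T](block: => T)(catchBlock: => Unit, finallyBlock: => Unit): T = {
  try {
    block
  } catch {
    case t: Throwable =>
      catchBlock // e.g. out.writeInt(-1) to signal the client
      throw t    // re-thrown, so the serving thread's Future fails with it
  } finally {
    finallyBlock // e.g. closing the socket streams
  }
}
```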
