@@ -38,7 +38,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.BUFFER_SIZE
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.rdd.RDD
-import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer}
+import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer, SocketFuncServer}
 import org.apache.spark.util._
 
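For readers unfamiliar with the new import: `SocketFuncServer` is not shown in this diff, but from its usage in `toLocalIteratorAndServe` below it is evidently a `SocketAuthServer` that delegates each authenticated connection to a supplied function. A rough sketch of that shape, offered as an assumption rather than the actual Spark source:

```scala
import java.net.Socket

import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer}

// Sketch (assumption): an auth server whose connection handler is a plain
// Socket => Unit closure, so callers can avoid subclassing SocketAuthServer.
private[spark] class SocketFuncServer(
    authHelper: SocketAuthHelper,
    threadName: String,
    func: Socket => Unit) extends SocketAuthServer[Unit](authHelper, threadName) {

  override def handleConnection(sock: Socket): Unit = func(sock)
}
```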
@@ -137,8 +137,9 @@ private[spark] object PythonRDD extends Logging {
    * (effectively a collect()), but allows you to run on a certain subset of partitions,
    * or to enable local execution.
    *
-   * @return 2-tuple (as a Java array) with the port number of a local socket which serves the
-   *         data collected from this job, and the secret for authentication.
+   * @return 3-tuple (as a Java array) with the port number of a local socket which serves the
+   *         data collected from this job, the secret for authentication, and a socket auth
+   *         server object that can be used to join the JVM serving thread in Python.
    */
   def runJob(
       sc: SparkContext,
@@ -156,8 +157,9 @@ private[spark] object PythonRDD extends Logging {
   /**
    * A helper function to collect an RDD as an iterator, then serve it via socket.
    *
-   * @return 2-tuple (as a Java array) with the port number of a local socket which serves the
-   *         data collected from this job, and the secret for authentication.
+   * @return 3-tuple (as a Java array) with the port number of a local socket which serves the
+   *         data collected from this job, the secret for authentication, and a socket auth
+   *         server object that can be used to join the JVM serving thread in Python.
    */
   def collectAndServe[T](rdd: RDD[T]): Array[Any] = {
     serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
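Note that the return type is still `Array[Any]`, so the extra element is purely positional. Illustratively, a JVM-side caller would unpack the triple by index, roughly as follows (the real consumer is PySpark, which receives the array over Py4J; `rdd` is a stand-in for an existing RDD):

```scala
// Hypothetical unpacking of the (port, secret, server) triple returned above.
val info: Array[Any] = PythonRDD.collectAndServe(rdd)
val port = info(0).asInstanceOf[Int]       // local socket to read the data from
val secret = info(1).asInstanceOf[String]  // secret for authenticating the connection
val server = info(2)                       // server object; Python joins its serving thread
```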
@@ -168,58 +170,59 @@ private[spark] object PythonRDD extends Logging {
    * are collected as separate jobs, by order of index. Partition data is first requested by a
    * non-zero integer to start a collection job. The response is prefaced by an integer with 1
    * meaning partition data will be served, 0 meaning the local iterator has been consumed,
-   * and -1 meaining an error occurred during collection. This function is used by
+   * and -1 meaning an error occurred during collection. This function is used by
    * pyspark.rdd._local_iterator_from_socket().
    *
-   * @return 2-tuple (as a Java array) with the port number of a local socket which serves the
-   *         data collected from these jobs, and the secret for authentication.
+   * @return 3-tuple (as a Java array) with the port number of a local socket which serves the
+   *         data collected from this job, the secret for authentication, and a socket auth
+   *         server object that can be used to join the JVM serving thread in Python.
    */
   def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = {
-    val (port, secret) = SocketAuthServer.setupOneConnectionServer(
-      authHelper, "serve toLocalIterator") { s =>
-      val out = new DataOutputStream(s.getOutputStream)
-      val in = new DataInputStream(s.getInputStream)
-      Utils.tryWithSafeFinally {
-
+    val handleFunc = (sock: Socket) => {
+      val out = new DataOutputStream(sock.getOutputStream)
+      val in = new DataInputStream(sock.getInputStream)
+      Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
         // Collects a partition on each iteration
         val collectPartitionIter = rdd.partitions.indices.iterator.map { i =>
           rdd.sparkContext.runJob(rdd, (iter: Iterator[Any]) => iter.toArray, Seq(i)).head
         }
 
-        // Read request for data and send next partition if nonzero
+        // Write data until iteration is complete, client stops iteration, or error occurs
         var complete = false
-        while (!complete && in.readInt() != 0) {
-          if (collectPartitionIter.hasNext) {
-            try {
-              // Attempt to collect the next partition
-              val partitionArray = collectPartitionIter.next()
-
-              // Send response there is a partition to read
-              out.writeInt(1)
-
-              // Write the next object and signal end of data for this iteration
-              writeIteratorToStream(partitionArray.toIterator, out)
-              out.writeInt(SpecialLengths.END_OF_DATA_SECTION)
-              out.flush()
-            } catch {
-              case e: SparkException =>
-                // Send response that an error occurred followed by error message
-                out.writeInt(-1)
-                writeUTF(e.getMessage, out)
-                complete = true
-            }
+        while (!complete) {
+
+          // Read request for data, value of zero will stop iteration or non-zero to continue
+          if (in.readInt() == 0) {
+            complete = true
+          } else if (collectPartitionIter.hasNext) {
+
+            // Client requested more data, attempt to collect the next partition
+            val partitionArray = collectPartitionIter.next()
+
+            // Send response there is a partition to read
+            out.writeInt(1)
+
+            // Write the next object and signal end of data for this iteration
+            writeIteratorToStream(partitionArray.toIterator, out)
+            out.writeInt(SpecialLengths.END_OF_DATA_SECTION)
+            out.flush()
           } else {
             // Send response there are no more partitions to read and close
             out.writeInt(0)
             complete = true
           }
         }
-      } {
+      })(catchBlock = {
+        // Send response that an error occurred, original exception is re-thrown
+        out.writeInt(-1)
+      }, finallyBlock = {
         out.close()
         in.close()
-      }
+      })
     }
-    Array(port, secret)
+
+    val server = new SocketFuncServer(authHelper, "serve toLocalIterator", handleFunc)
+    Array(server.port, server.secret, server)
   }
 
   def readRDDFromFile(
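To make the server loop above concrete, here is a minimal client-side sketch of the protocol the doc comment describes: write a non-zero int to request the next partition, then read 1 (partition data follows, framed by `END_OF_DATA_SECTION`), 0 (iterator exhausted), or -1 (collection failed; with this change the original exception is re-thrown in the JVM rather than serialized to the client). This is in Scala for illustration only; the real client is `pyspark.rdd._local_iterator_from_socket()`, and `readFramedPartition` is a hypothetical stand-in for deserializing one partition:

```scala
import java.io.{DataInputStream, DataOutputStream}
import java.net.Socket

// Minimal sketch of the request/response loop, assuming `sock` is already
// connected and authenticated with the secret.
def consumeLocalIterator(sock: Socket)(readFramedPartition: DataInputStream => Unit): Unit = {
  val out = new DataOutputStream(sock.getOutputStream)
  val in = new DataInputStream(sock.getInputStream)
  var done = false
  while (!done) {
    out.writeInt(1)  // any non-zero value requests the next partition; 0 stops early
    out.flush()
    in.readInt() match {
      case 1 => readFramedPartition(in)  // partition data follows
      case 0 => done = true              // server-side iterator is exhausted
      case _ => done = true              // -1: an error occurred during collection
    }
  }
}
```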
@@ -443,8 +446,9 @@ private[spark] object PythonRDD extends Logging {
    *
    * The thread will terminate after all the data are sent or any exceptions happen.
    *
-   * @return 2-tuple (as a Java array) with the port number of a local socket which serves the
-   *         data collected from this job, and the secret for authentication.
+   * @return 3-tuple (as a Java array) with the port number of a local socket which serves the
+   *         data collected from this job, the secret for authentication, and a socket auth
+   *         server object that can be used to join the JVM serving thread in Python.
    */
   def serveIterator(items: Iterator[_], threadName: String): Array[Any] = {
     serveToStream(threadName) { out =>
@@ -464,10 +468,14 @@ private[spark] object PythonRDD extends Logging {
    *
    * The thread will terminate after the block of code is executed or any
    * exceptions happen.
+   *
+   * @return 3-tuple (as a Java array) with the port number of a local socket which serves the
+   *         data collected from this job, the secret for authentication, and a socket auth
+   *         server object that can be used to join the JVM serving thread in Python.
    */
   private[spark] def serveToStream(
       threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
-    SocketAuthHelper.serveToStream(threadName, authHelper)(writeFunc)
+    SocketAuthServer.serveToStream(threadName, authHelper)(writeFunc)
   }
 
   private def getMergedConf(confAsMap: java.util.HashMap[String, String],
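The last hunk fixes the delegation target: `serveToStream` lives on the `SocketAuthServer` companion object, not on `SocketAuthHelper`. Presumably it follows the same pattern `toLocalIteratorAndServe` now uses, wrapping the caller's write function in a socket handler and returning the same triple; a sketch under that assumption, not the actual Spark source:

```scala
import java.io.{BufferedOutputStream, OutputStream}
import java.net.Socket

import org.apache.spark.security.SocketAuthHelper
import org.apache.spark.util.Utils

// Sketch (assumption): serve a single authenticated connection by running
// writeFunc against its output stream, returning (port, secret, server).
def serveToStream(
    threadName: String,
    authHelper: SocketAuthHelper)(writeFunc: OutputStream => Unit): Array[Any] = {
  val handleFunc = (sock: Socket) => {
    val out = new BufferedOutputStream(sock.getOutputStream)
    Utils.tryWithSafeFinally {
      writeFunc(out)
    } {
      out.close()
    }
  }

  val server = new SocketFuncServer(authHelper, threadName, handleFunc)
  Array(server.port, server.secret, server)
}
```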