
Commit d4de913

gaogaotiantian authored and HyukjinKwon committed
[SPARK-54615][PYTHON] Always pass runner_conf to python worker
### What changes were proposed in this pull request?

Always pass runnerConf to the Python worker, even if it's not used.

### Why are the changes needed?

This is part of the effort to consolidate our protocol from the JVM to the worker. We currently have different ways to pass the runner conf, and sometimes we don't pass it at all. That makes the worker-side code a bit messy: we have to decide whether to read the conf based on the eval type. However, reading an empty conf is very cheap, so we can simply do it unconditionally. With this infrastructure, vanilla Python UDFs can also receive runner conf in the future. We can do further refactoring of our JVM worker code in the future.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

`pyspark-sql` passed locally. Running the rest on CI.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #53353 from gaogaotiantian/always-pass-runnerconf.

Authored-by: Tian Gao <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
1 parent 44efa48 commit d4de913
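To make the "reading an empty conf is cheap" argument concrete, here is a minimal, self-contained sketch of a reader for the framing the JVM side uses (see `writeRunnerConf` in the diff below): a 4-byte count followed by that many length-prefixed UTF-8 key/value pairs. This is only an illustration, not PySpark's actual `RunnerConf` implementation; `read_runner_conf`, `_read_int`, and `_read_utf` are hypothetical helpers.

```python
import struct
from io import BytesIO


def _read_int(stream):
    # 4-byte big-endian signed int, as written by DataOutputStream.writeInt.
    return struct.unpack(">i", stream.read(4))[0]


def _read_utf(stream):
    # Length-prefixed UTF-8 string: a 4-byte length, then that many bytes.
    return stream.read(_read_int(stream)).decode("utf-8")


def read_runner_conf(stream):
    """Read the runner conf as a dict; an empty conf costs a single int read."""
    conf = {}
    for _ in range(_read_int(stream)):
        key = _read_utf(stream)
        conf[key] = _read_utf(stream)
    return conf


# An empty conf on the wire is just a zero count: four zero bytes.
assert read_runner_conf(BytesIO(b"\x00\x00\x00\x00")) == {}
```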

File tree

8 files changed: +43 lines added, -51 lines removed


core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala

Lines changed: 14 additions & 0 deletions
@@ -212,6 +212,8 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
   protected val hideTraceback: Boolean = false
   protected val simplifiedTraceback: Boolean = false
 
+  protected val runnerConf: Map[String, String] = Map.empty
+
   // All the Python functions should have the same exec, version and envvars.
   protected val envVars: java.util.Map[String, String] = funcs.head.funcs.head.envVars
   protected val pythonExec: String = funcs.head.funcs.head.pythonExec
@@ -403,6 +405,17 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
    */
   protected def writeCommand(dataOut: DataOutputStream): Unit
 
+  /**
+   * Writes worker configuration to the stream connected to the Python worker.
+   */
+  protected def writeRunnerConf(dataOut: DataOutputStream): Unit = {
+    dataOut.writeInt(runnerConf.size)
+    for ((k, v) <- runnerConf) {
+      PythonWorkerUtils.writeUTF(k, dataOut)
+      PythonWorkerUtils.writeUTF(v, dataOut)
+    }
+  }
+
   /**
    * Writes input data to the stream connected to the Python worker.
   * Returns true if any data was written to the stream, false if the input is exhausted.
@@ -532,6 +545,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
       PythonWorkerUtils.writeBroadcasts(broadcastVars, worker, env, dataOut)
 
       dataOut.writeInt(evalType)
+      writeRunnerConf(dataOut)
       writeCommand(dataOut)
 
       dataOut.flush()

python/pyspark/worker.py

Lines changed: 9 additions & 12 deletions
@@ -1514,10 +1514,8 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, profil
 # It expects the UDTF to be in a specific format and performs various checks to
 # ensure the UDTF is valid. This function also prepares a mapper function for applying
 # the UDTF logic to input rows.
-def read_udtf(pickleSer, infile, eval_type):
+def read_udtf(pickleSer, infile, eval_type, runner_conf):
     if eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF:
-        # Load conf used for arrow evaluation.
-        runner_conf = RunnerConf(infile)
         input_types = [
             field.dataType for field in _parse_datatype_json_string(utf8_deserializer.loads(infile))
         ]
@@ -1532,15 +1530,13 @@ def read_udtf(pickleSer, infile, eval_type):
         else:
             ser = ArrowStreamUDTFSerializer()
     elif eval_type == PythonEvalType.SQL_ARROW_UDTF:
-        runner_conf = RunnerConf(infile)
         # Read the table argument offsets
         num_table_arg_offsets = read_int(infile)
         table_arg_offsets = [read_int(infile) for _ in range(num_table_arg_offsets)]
         # Use PyArrow-native serializer for Arrow UDTFs with potential UDT support
         ser = ArrowStreamArrowUDTFSerializer(table_arg_offsets=table_arg_offsets)
     else:
         # Each row is a group so do not batch but send one by one.
-        runner_conf = RunnerConf()
         ser = BatchedSerializer(CPickleSerializer(), 1)
 
     # See 'PythonUDTFRunner.PythonUDFWriterThread.writeCommand'
@@ -2688,7 +2684,7 @@ def mapper(_, it):
     return mapper, None, ser, ser
 
 
-def read_udfs(pickleSer, infile, eval_type):
+def read_udfs(pickleSer, infile, eval_type, runner_conf):
     state_server_port = None
     key_schema = None
     if eval_type in (
@@ -2716,9 +2712,6 @@ def read_udfs(pickleSer, infile, eval_type):
         PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF,
         PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF,
     ):
-        # Load conf used for pandas_udf evaluation
-        runner_conf = RunnerConf(infile)
-
        state_object_schema = None
        if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
            state_object_schema = StructType.fromJson(json.loads(utf8_deserializer.loads(infile)))
@@ -2870,7 +2863,6 @@ def read_udfs(pickleSer, infile, eval_type):
             int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
         )
     else:
-        runner_conf = RunnerConf()
         batch_size = int(os.environ.get("PYTHON_UDF_BATCH_SIZE", "100"))
         ser = BatchedSerializer(CPickleSerializer(), batch_size)
 
@@ -3353,16 +3345,21 @@ def main(infile, outfile):
 
         _accumulatorRegistry.clear()
         eval_type = read_int(infile)
+        runner_conf = RunnerConf(infile)
         if eval_type == PythonEvalType.NON_UDF:
             func, profiler, deserializer, serializer = read_command(pickleSer, infile)
         elif eval_type in (
             PythonEvalType.SQL_TABLE_UDF,
             PythonEvalType.SQL_ARROW_TABLE_UDF,
             PythonEvalType.SQL_ARROW_UDTF,
         ):
-            func, profiler, deserializer, serializer = read_udtf(pickleSer, infile, eval_type)
+            func, profiler, deserializer, serializer = read_udtf(
+                pickleSer, infile, eval_type, runner_conf
+            )
         else:
-            func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type)
+            func, profiler, deserializer, serializer = read_udfs(
+                pickleSer, infile, eval_type, runner_conf
+            )
 
         init_time = time.time()
 
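Because the conf is now read unconditionally in `main`, downstream code can treat it as an always-present mapping with defaults instead of branching on whether a conf was sent. The class below is a hypothetical, minimal stand-in (not PySpark's real `RunnerConf`) that only illustrates why an empty conf is harmless; the conf key is made up.

```python
class ToyRunnerConf:
    """Toy stand-in for a runner conf: a string map with defaulted accessors."""

    def __init__(self, conf=None):
        self._conf = dict(conf or {})

    def get_bool(self, key, default=False):
        value = self._conf.get(key)
        return default if value is None else value.lower() == "true"


# With an empty conf, every lookup simply falls back to its default.
empty = ToyRunnerConf()
assert empty.get_bool("example.runner.flag") is False

# With a populated conf, the same accessor reflects what the JVM sent.
populated = ToyRunnerConf({"example.runner.flag": "true"})
assert populated.get_bool("example.runner.flag") is True
```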

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala

Lines changed: 5 additions & 7 deletions
@@ -35,7 +35,6 @@ abstract class BaseArrowPythonRunner[IN, OUT <: AnyRef](
     _schema: StructType,
     _timeZoneId: String,
     protected override val largeVarTypes: Boolean,
-    protected override val workerConf: Map[String, String],
     override val pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
     sessionUUID: Option[String])
@@ -86,12 +85,11 @@ abstract class RowInputArrowPythonRunner(
     _schema: StructType,
     _timeZoneId: String,
     largeVarTypes: Boolean,
-    workerConf: Map[String, String],
     pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
     sessionUUID: Option[String])
   extends BaseArrowPythonRunner[Iterator[InternalRow], ColumnarBatch](
-    funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes, workerConf,
+    funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes,
     pythonMetrics, jobArtifactUUID, sessionUUID)
   with BasicPythonArrowInput
   with BasicPythonArrowOutput
@@ -106,13 +104,13 @@ class ArrowPythonRunner(
     _schema: StructType,
     _timeZoneId: String,
     largeVarTypes: Boolean,
-    workerConf: Map[String, String],
+    protected override val runnerConf: Map[String, String],
     pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
     sessionUUID: Option[String],
     profiler: Option[String])
   extends RowInputArrowPythonRunner(
-    funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes, workerConf,
+    funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes,
     pythonMetrics, jobArtifactUUID, sessionUUID) {
 
   override protected def writeUDF(dataOut: DataOutputStream): Unit =
@@ -130,13 +128,13 @@ class ArrowPythonWithNamedArgumentRunner(
     _schema: StructType,
     _timeZoneId: String,
     largeVarTypes: Boolean,
-    workerConf: Map[String, String],
+    protected override val runnerConf: Map[String, String],
     pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
     sessionUUID: Option[String],
     profiler: Option[String])
   extends RowInputArrowPythonRunner(
-    funcs, evalType, argMetas.map(_.map(_.offset)), _schema, _timeZoneId, largeVarTypes, workerConf,
+    funcs, evalType, argMetas.map(_.map(_.offset)), _schema, _timeZoneId, largeVarTypes,
     pythonMetrics, jobArtifactUUID, sessionUUID) {
 
   override protected def writeUDF(dataOut: DataOutputStream): Unit = {

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ class ArrowPythonUDTFRunner(
     protected override val schema: StructType,
     protected override val timeZoneId: String,
     protected override val largeVarTypes: Boolean,
-    protected override val workerConf: Map[String, String],
+    protected override val runnerConf: Map[String, String],
     override val pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
     sessionUUID: Option[String])

sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala

Lines changed: 2 additions & 10 deletions
@@ -25,7 +25,7 @@ import org.apache.arrow.vector.{VectorSchemaRoot, VectorUnloader}
 import org.apache.arrow.vector.compression.{CompressionCodec, NoCompressionCodec}
 
 import org.apache.spark.{SparkEnv, SparkException, TaskContext}
-import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonRDD, PythonWorker}
+import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonWorker}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.arrow.ArrowWriterWrapper
 import org.apache.spark.sql.execution.metric.SQLMetric
@@ -45,7 +45,7 @@ class CoGroupedArrowPythonRunner(
     rightSchema: StructType,
     timeZoneId: String,
     largeVarTypes: Boolean,
-    conf: Map[String, String],
+    protected override val runnerConf: Map[String, String],
     override val pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
     sessionUUID: Option[String],
@@ -119,14 +119,6 @@ class CoGroupedArrowPythonRunner(
     private var rightGroupArrowWriter: ArrowWriterWrapper = null
 
     protected override def writeCommand(dataOut: DataOutputStream): Unit = {
-
-      // Write config for the worker as a number of key -> value pairs of strings
-      dataOut.writeInt(conf.size)
-      for ((k, v) <- conf) {
-        PythonRDD.writeUTF(k, dataOut)
-        PythonRDD.writeUTF(v, dataOut)
-      }
-
       PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets, profiler)
     }
 

sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala

Lines changed: 3 additions & 12 deletions
@@ -27,7 +27,7 @@ import org.apache.arrow.vector.ipc.WriteChannel
 import org.apache.arrow.vector.ipc.message.MessageSerializer
 
 import org.apache.spark.{SparkEnv, SparkException, TaskContext}
-import org.apache.spark.api.python.{BasePythonRunner, PythonRDD, PythonWorker}
+import org.apache.spark.api.python.{BasePythonRunner, PythonWorker}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.arrow
 import org.apache.spark.sql.execution.arrow.{ArrowWriter, ArrowWriterWrapper}
@@ -42,8 +42,6 @@ import org.apache.spark.util.Utils
  * JVM (an iterator of internal rows + additional data if required) to Python (Arrow).
  */
 private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
-  protected val workerConf: Map[String, String]
-
   protected val schema: StructType
 
   protected val timeZoneId: String
@@ -62,14 +60,8 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
 
   protected def writeUDF(dataOut: DataOutputStream): Unit
 
-  protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {
-    // Write config for the worker as a number of key -> value pairs of strings
-    stream.writeInt(workerConf.size)
-    for ((k, v) <- workerConf) {
-      PythonRDD.writeUTF(k, stream)
-      PythonRDD.writeUTF(v, stream)
-    }
-  }
+  protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {}
+
   private val arrowSchema = ArrowUtils.toArrowSchema(
     schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
   protected val allocator =
@@ -301,7 +293,6 @@ private[python] trait GroupedPythonArrowInput { self: RowInputArrowPythonRunner
       context: TaskContext): Writer = {
     new Writer(env, worker, inputIterator, partitionIndex, context) {
       protected override def writeCommand(dataOut: DataOutputStream): Unit = {
-        handleMetadataBeforeExec(dataOut)
         writeUDF(dataOut)
       }
 

sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala

Lines changed: 2 additions & 2 deletions
@@ -58,7 +58,7 @@ class ApplyInPandasWithStatePythonRunner(
     argOffsets: Array[Array[Int]],
     inputSchema: StructType,
     _timeZoneId: String,
-    initialWorkerConf: Map[String, String],
+    initialRunnerConf: Map[String, String],
     stateEncoder: ExpressionEncoder[Row],
     keySchema: StructType,
     outputSchema: StructType,
@@ -113,7 +113,7 @@ class ApplyInPandasWithStatePythonRunner(
   // applyInPandasWithState has its own mechanism to construct the Arrow RecordBatch instance.
   // Configurations are both applied to executor and Python worker, set them to the worker conf
   // to let Python worker read the config properly.
-  override protected val workerConf: Map[String, String] = initialWorkerConf +
+  override protected val runnerConf: Map[String, String] = initialRunnerConf +
     (SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString) +
     (SQLConf.ARROW_EXECUTION_MAX_BYTES_PER_BATCH.key -> arrowMaxBytesPerBatch.toString)
 
sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala

Lines changed: 7 additions & 7 deletions
@@ -52,15 +52,15 @@ class TransformWithStateInPySparkPythonRunner(
     _schema: StructType,
     processorHandle: StatefulProcessorHandleImpl,
     _timeZoneId: String,
-    initialWorkerConf: Map[String, String],
+    initialRunnerConf: Map[String, String],
     override val pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
     groupingKeySchema: StructType,
     batchTimestampMs: Option[Long],
     eventTimeWatermarkForEviction: Option[Long])
   extends TransformWithStateInPySparkPythonBaseRunner[InType](
     funcs, evalType, argOffsets, _schema, processorHandle, _timeZoneId,
-    initialWorkerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
+    initialRunnerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
     batchTimestampMs, eventTimeWatermarkForEviction)
   with PythonArrowInput[InType] {
 
@@ -126,15 +126,15 @@ class TransformWithStateInPySparkPythonInitialStateRunner(
     initStateSchema: StructType,
     processorHandle: StatefulProcessorHandleImpl,
     _timeZoneId: String,
-    initialWorkerConf: Map[String, String],
+    initialRunnerConf: Map[String, String],
     override val pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
     groupingKeySchema: StructType,
     batchTimestampMs: Option[Long],
     eventTimeWatermarkForEviction: Option[Long])
   extends TransformWithStateInPySparkPythonBaseRunner[GroupedInType](
     funcs, evalType, argOffsets, dataSchema, processorHandle, _timeZoneId,
-    initialWorkerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
+    initialRunnerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
     batchTimestampMs, eventTimeWatermarkForEviction)
   with PythonArrowInput[GroupedInType] {
 
@@ -221,7 +221,7 @@ abstract class TransformWithStateInPySparkPythonBaseRunner[I](
     _schema: StructType,
     processorHandle: StatefulProcessorHandleImpl,
     _timeZoneId: String,
-    initialWorkerConf: Map[String, String],
+    initialRunnerConf: Map[String, String],
     override val pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
     groupingKeySchema: StructType,
@@ -238,7 +238,7 @@ abstract class TransformWithStateInPySparkPythonBaseRunner[I](
   protected val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch
   protected val arrowMaxBytesPerBatch = sqlConf.arrowMaxBytesPerBatch
 
-  override protected val workerConf: Map[String, String] = initialWorkerConf +
+  override protected val runnerConf: Map[String, String] = initialRunnerConf +
     (SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString) +
     (SQLConf.ARROW_EXECUTION_MAX_BYTES_PER_BATCH.key -> arrowMaxBytesPerBatch.toString)
 
@@ -251,7 +251,7 @@ abstract class TransformWithStateInPySparkPythonBaseRunner[I](
 
   override protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {
     super.handleMetadataBeforeExec(stream)
-    // Also write the port/path number for state server
+    // Write the port/path number for state server
     if (isUnixDomainSock) {
       stream.writeInt(-1)
       PythonWorkerUtils.writeUTF(stateServerSocketPath, stream)
