
Commit 6695a69

gaogaotiantian authored and HyukjinKwon committed
[SPARK-54701][PYTHON] Improve the runnerConf chain for Python workers
### What changes were proposed in this pull request?

`runnerConf` now honors the parent's `runnerConf`: it inherits the parent's entries instead of overwriting them.

### Why are the changes needed?

To make it flexible for any class in the chain to add extra configs for the runner.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

CI.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #53462 from gaogaotiantian/improve-runnerConf-inherot.

Authored-by: Tian Gao <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
1 parent b45f071 commit 6695a69
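The shape of the change is easier to see outside the hunks. Below is a minimal sketch of the inheritance pattern this commit adopts, using hypothetical class names rather than the actual Spark runners: each level composes its parent's conf with its own additions instead of replacing it.

```scala
// Minimal sketch of the runnerConf chain (hypothetical names).
abstract class BaseRunner {
  // A `def` (not a `val`) so subclasses can compose via `super.runnerConf`.
  protected def runnerConf: Map[String, String] = Map.empty
}

class MidRunner extends BaseRunner {
  override protected def runnerConf: Map[String, String] =
    super.runnerConf ++ Map("mid.conf" -> "1")
}

class LeafRunner extends MidRunner {
  // Inherits "mid.conf" from MidRunner and adds its own entry, yielding
  // Map("mid.conf" -> "1", "leaf.conf" -> "2").
  override protected def runnerConf: Map[String, String] =
    super.runnerConf ++ Map("leaf.conf" -> "2")
}
```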

File tree

6 files changed (+23, -11 lines)


core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala

Lines changed: 1 addition & 1 deletion
@@ -214,7 +214,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
   protected val hideTraceback: Boolean = false
   protected val simplifiedTraceback: Boolean = false

-  protected val runnerConf: Map[String, String] = Map.empty
+  protected def runnerConf: Map[String, String] = Map.empty

   // All the Python functions should have the same exec, version and envvars.
   protected val envVars: java.util.Map[String, String] = funcs.head.funcs.head.envVars
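The `val` → `def` switch above is what makes chaining possible at all: Scala forbids `super` on a `val` member, so overriding a `val` could only replace the inherited map wholesale. Declared as a `def`, an override can call `super.runnerConf` and layer extra entries on top, which is exactly what the subclasses in the remaining files do.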

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala

Lines changed: 6 additions & 2 deletions
@@ -104,7 +104,7 @@ class ArrowPythonRunner(
     _schema: StructType,
     _timeZoneId: String,
     largeVarTypes: Boolean,
-    protected override val runnerConf: Map[String, String],
+    pythonRunnerConf: Map[String, String],
     pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
     sessionUUID: Option[String],
@@ -113,6 +113,8 @@ class ArrowPythonRunner(
     funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes,
     pythonMetrics, jobArtifactUUID, sessionUUID) {

+  override protected def runnerConf: Map[String, String] = super.runnerConf ++ pythonRunnerConf
+
   override protected def writeUDF(dataOut: DataOutputStream): Unit =
     PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets, profiler)
 }
@@ -128,7 +130,7 @@ class ArrowPythonWithNamedArgumentRunner(
     _schema: StructType,
     _timeZoneId: String,
     largeVarTypes: Boolean,
-    protected override val runnerConf: Map[String, String],
+    pythonRunnerConf: Map[String, String],
     pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
     sessionUUID: Option[String],
@@ -137,6 +139,8 @@ class ArrowPythonWithNamedArgumentRunner(
     funcs, evalType, argMetas.map(_.map(_.offset)), _schema, _timeZoneId, largeVarTypes,
     pythonMetrics, jobArtifactUUID, sessionUUID) {

+  override protected def runnerConf: Map[String, String] = super.runnerConf ++ pythonRunnerConf
+
   override protected def writeUDF(dataOut: DataOutputStream): Unit = {
     if (evalType == PythonEvalType.SQL_ARROW_BATCHED_UDF) {
       PythonWorkerUtils.writeUTF(schema.json, dataOut)
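A detail worth noting about the `super.runnerConf ++ pythonRunnerConf` form: on Scala's immutable maps, `++` gives precedence to the right-hand operand on key collisions, so the conf passed to the constructor still wins over anything inherited. A quick illustration:

```scala
val inherited = Map("spark.conf.a" -> "parent", "spark.conf.b" -> "1")
val passedIn  = Map("spark.conf.a" -> "child")

inherited ++ passedIn
// => Map("spark.conf.a" -> "child", "spark.conf.b" -> "1")
```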

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala

Lines changed: 3 additions & 1 deletion
@@ -39,7 +39,7 @@ class ArrowPythonUDTFRunner(
     protected override val schema: StructType,
     protected override val timeZoneId: String,
     protected override val largeVarTypes: Boolean,
-    protected override val runnerConf: Map[String, String],
+    pythonRunnerConf: Map[String, String],
     override val pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
     sessionUUID: Option[String])
@@ -49,6 +49,8 @@ class ArrowPythonUDTFRunner(
     with BatchedPythonArrowInput
     with BasicPythonArrowOutput {

+  override protected def runnerConf: Map[String, String] = super.runnerConf ++ pythonRunnerConf
+
   override protected def writeUDF(dataOut: DataOutputStream): Unit = {
     // For arrow-optimized Python UDTFs (@udtf(useArrow=True)), we need to write
     // the schema to the worker to support UDT (user-defined type).

sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala

Lines changed: 3 additions & 1 deletion
@@ -45,7 +45,7 @@ class CoGroupedArrowPythonRunner(
     rightSchema: StructType,
     timeZoneId: String,
     largeVarTypes: Boolean,
-    protected override val runnerConf: Map[String, String],
+    pythonRunnerConf: Map[String, String],
     override val pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
     sessionUUID: Option[String],
@@ -55,6 +55,8 @@ class CoGroupedArrowPythonRunner(
     funcs.map(_._1), evalType, argOffsets, jobArtifactUUID, pythonMetrics)
     with BasicPythonArrowOutput {

+  override protected def runnerConf: Map[String, String] = super.runnerConf ++ pythonRunnerConf
+
   override val envVars: util.Map[String, String] = {
     val envVars = new util.HashMap(funcs.head._1.funcs.head.envVars)
     sessionUUID.foreach { uuid =>

sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala

Lines changed: 5 additions & 3 deletions
@@ -113,9 +113,11 @@ class ApplyInPandasWithStatePythonRunner(
   // applyInPandasWithState has its own mechanism to construct the Arrow RecordBatch instance.
   // Configurations are both applied to executor and Python worker, set them to the worker conf
   // to let Python worker read the config properly.
-  override protected val runnerConf: Map[String, String] = initialRunnerConf +
-    (SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString) +
-    (SQLConf.ARROW_EXECUTION_MAX_BYTES_PER_BATCH.key -> arrowMaxBytesPerBatch.toString)
+  override protected def runnerConf: Map[String, String] =
+    super.runnerConf ++ initialRunnerConf ++ Map(
+      SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString,
+      SQLConf.ARROW_EXECUTION_MAX_BYTES_PER_BATCH.key -> arrowMaxBytesPerBatch.toString
+    )

   private val stateRowDeserializer = stateEncoder.createDeserializer()
sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala

Lines changed: 5 additions & 3 deletions
@@ -238,9 +238,11 @@ abstract class TransformWithStateInPySparkPythonBaseRunner[I](
   protected val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch
   protected val arrowMaxBytesPerBatch = sqlConf.arrowMaxBytesPerBatch

-  override protected val runnerConf: Map[String, String] = initialRunnerConf +
-    (SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString) +
-    (SQLConf.ARROW_EXECUTION_MAX_BYTES_PER_BATCH.key -> arrowMaxBytesPerBatch.toString)
+  override protected def runnerConf: Map[String, String] =
+    super.runnerConf ++ initialRunnerConf ++ Map(
+      SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString,
+      SQLConf.ARROW_EXECUTION_MAX_BYTES_PER_BATCH.key -> arrowMaxBytesPerBatch.toString
+    )

   // Use lazy val to initialize the fields before these are accessed in [[PythonArrowInput]]'s
   // constructor.
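In both streaming runners the Arrow batch-size keys are appended last, after `super.runnerConf` and `initialRunnerConf`, so the right-hand precedence of `++` guarantees they still override any same-named entries from the rest of the chain, preserving the behavior of the previous `+`-based construction.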
