Commit 7c2c84a
[SPARK-53243][PYTHON][SQL] List the supported eval types in arrow nodes
### What changes were proposed in this pull request?
List the supported eval types in arrow nodes.

### Why are the changes needed?
To validate the eval types and make the code more readable.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #51970 from zhengruifeng/arrow_check_eval_type.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
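The change is the same in all three nodes: statements in a Scala case class body run at construction time, so each exec node now rejects an unsupported eval type as soon as it is instantiated, instead of failing later during execution. A minimal standalone sketch of the pattern, assuming the same Spark imports (`ExampleArrowNode` is hypothetical, not part of this commit):

```scala
import org.apache.spark.SparkException
import org.apache.spark.api.python.PythonEvalType

// Hypothetical node showing the fail-fast pattern from this commit: the
// case class body executes on construction, so an unexpected eval type
// raises an internal error before any query execution begins.
case class ExampleArrowNode(evalType: Int) {
  if (!supportedPythonEvalTypes.contains(evalType)) {
    throw SparkException.internalError(s"Unexpected eval type $evalType")
  }

  // A def (rather than a val declared below the check) avoids
  // initialization-order pitfalls, since the class body runs top to bottom.
  private def supportedPythonEvalTypes: Array[Int] =
    Array(
      PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
      PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)
}
```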
1 parent 7fe2f5e commit 7c2c84a

3 files changed: +41, -11 lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowAggregatePythonExec.scala

Lines changed: 11 additions & 4 deletions
```diff
@@ -21,8 +21,8 @@ import java.io.File
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.{JobArtifactSet, SparkEnv, TaskContext}
-import org.apache.spark.api.python.ChainedPythonFunctions
+import org.apache.spark.{JobArtifactSet, SparkEnv, SparkException, TaskContext}
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
@@ -47,8 +47,10 @@ case class ArrowAggregatePythonExec(
     aggExpressions: Seq[AggregateExpression],
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan,
-    evalType: Int)
-  extends UnaryExecNode with PythonSQLMetrics {
+    evalType: Int) extends UnaryExecNode with PythonSQLMetrics {
+  if (!supportedPythonEvalTypes.contains(evalType)) {
+    throw SparkException.internalError(s"Unexpected eval type $evalType")
+  }
 
   override val output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
 
@@ -217,6 +219,11 @@ case class ArrowAggregatePythonExec(
 
       newIter
     }
+
+  private def supportedPythonEvalTypes: Array[Int] =
+    Array(
+      PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
+      PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)
 }
 
 object ArrowAggregatePythonExec {
```

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala

Lines changed: 18 additions & 5 deletions
```diff
@@ -19,8 +19,8 @@ package org.apache.spark.sql.execution.python
 
 import scala.jdk.CollectionConverters._
 
-import org.apache.spark.{JobArtifactSet, TaskContext}
-import org.apache.spark.api.python.ChainedPythonFunctions
+import org.apache.spark.{JobArtifactSet, SparkException, TaskContext}
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.errors.QueryExecutionErrors
@@ -62,9 +62,14 @@ private[spark] class BatchIterator[T](iter: Iterator[T], batchSize: Int)
 /**
  * A physical plan that evaluates a [[PythonUDF]].
  */
-case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute], child: SparkPlan,
-    evalType: Int)
-  extends EvalPythonExec with PythonSQLMetrics {
+case class ArrowEvalPythonExec(
+    udfs: Seq[PythonUDF],
+    resultAttrs: Seq[Attribute],
+    child: SparkPlan,
+    evalType: Int) extends EvalPythonExec with PythonSQLMetrics {
+  if (!supportedPythonEvalTypes.contains(evalType)) {
+    throw SparkException.internalError(s"Unexpected eval type $evalType")
+  }
 
   private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
 
@@ -85,6 +90,14 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute]
 
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
     copy(child = newChild)
+
+  private def supportedPythonEvalTypes: Array[Int] =
+    Array(
+      PythonEvalType.SQL_ARROW_BATCHED_UDF,
+      PythonEvalType.SQL_SCALAR_ARROW_UDF,
+      PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
+      PythonEvalType.SQL_SCALAR_PANDAS_UDF,
+      PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF)
 }
 
 class ArrowEvalPythonEvaluatorFactory(
```
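A side note on the data structure, as a hedged alternative rather than what the commit does: a `Set` would give constant-time membership tests, though with at most five entries the commit's `Array.contains` linear scan is just as practical. A sketch of that alternative (`SupportedScalarEvalTypes` is an assumed name):

```scala
import org.apache.spark.api.python.PythonEvalType

// Sketch of an alternative (assumption, not the commit's code): keep the
// supported scalar eval types in a Set so membership checks are O(1).
object SupportedScalarEvalTypes {
  val values: Set[Int] = Set(
    PythonEvalType.SQL_ARROW_BATCHED_UDF,
    PythonEvalType.SQL_SCALAR_ARROW_UDF,
    PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
    PythonEvalType.SQL_SCALAR_PANDAS_UDF,
    PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF)

  def isSupported(evalType: Int): Boolean = values.contains(evalType)
}
```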

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonExec.scala

Lines changed: 12 additions & 2 deletions
```diff
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.python
 
+import org.apache.spark.SparkException
+import org.apache.spark.api.python.PythonEvalType
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
@@ -74,8 +76,11 @@ case class ArrowWindowPythonExec(
     partitionSpec: Seq[Expression],
     orderSpec: Seq[SortOrder],
     child: SparkPlan,
-    evalType: Int)
-  extends WindowExecBase with PythonSQLMetrics {
+    evalType: Int) extends WindowExecBase with PythonSQLMetrics {
+  if (!supportedPythonEvalTypes.contains(evalType)) {
+    throw SparkException.internalError(s"Unexpected eval type $evalType")
+  }
+
   override lazy val metrics: Map[String, SQLMetric] = pythonMetrics ++ Map(
     "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size")
   )
@@ -105,6 +110,11 @@ case class ArrowWindowPythonExec(
 
   override protected def withNewChildInternal(newChild: SparkPlan): ArrowWindowPythonExec =
     copy(child = newChild)
+
+  private def supportedPythonEvalTypes: Array[Int] =
+    Array(
+      PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF,
+      PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF)
 }
 
 object ArrowWindowPythonExec {
```
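Because the validation throws from the constructor, the behavior is directly testable at node-construction time. A hedged test-helper sketch (the helper and its message check are assumptions; the PR itself relies on existing tests):

```scala
import org.apache.spark.SparkException

// Hypothetical helper: assert that constructing a node with an unsupported
// eval type fails fast with the internal error added in this commit.
def assertUnexpectedEvalType(block: => Unit): Unit = {
  try {
    block
    throw new AssertionError("expected SparkException.internalError")
  } catch {
    // SparkException.internalError interpolates the supplied message, so it
    // should contain the "Unexpected eval type" text (formatting assumed).
    case e: SparkException if e.getMessage.contains("Unexpected eval type") =>
      () // expected outcome: the node rejected the eval type at construction
  }
}
```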
