Skip to content

Commit ac61f7d

Browse files
committed
[SPARK-27893][SQL][PYTHON][FOLLOW-UP] Allow Scalar Pandas and Python UDFs to be tested with Scala test base
## What changes were proposed in this pull request?

After this PR, we can test Pandas and Python UDFs as below **on the Scala side**:

```scala
import IntegratedUDFTestUtils._
val pandasTestUDF = TestScalarPandasUDF("udf")
spark.range(10).select(pandasTestUDF($"id")).show()
```

## How was this patch tested?

Manually tested.

Closes apache#24945 from HyukjinKwon/SPARK-27893-followup.

Authored-by: HyukjinKwon <[email protected]>
Signed-off-by: HyukjinKwon <[email protected]>
1 parent 1d36b89 commit ac61f7d

File tree

2 files changed

+35
-25
lines changed

2 files changed

+35
-25
lines changed

sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,32 +40,37 @@ import org.apache.spark.sql.types.StringType
4040
*
4141
* To register Scala UDF in SQL:
4242
* {{{
43-
* registerTestUDF(TestScalaUDF(name = "udf_name"), spark)
43+
* val scalaTestUDF = TestScalaUDF(name = "udf_name")
44+
* registerTestUDF(scalaTestUDF, spark)
4445
* }}}
4546
*
4647
* To register Python UDF in SQL:
4748
* {{{
48-
* registerTestUDF(TestPythonUDF(name = "udf_name"), spark)
49+
* val pythonTestUDF = TestPythonUDF(name = "udf_name")
50+
* registerTestUDF(pythonTestUDF, spark)
4951
* }}}
5052
*
5153
* To register Scalar Pandas UDF in SQL:
5254
* {{{
53-
* registerTestUDF(TestScalarPandasUDF(name = "udf_name"), spark)
55+
* val pandasTestUDF = TestScalarPandasUDF(name = "udf_name")
56+
* registerTestUDF(pandasTestUDF, spark)
5457
* }}}
5558
*
5659
* To use it in Scala API and SQL:
5760
* {{{
5861
* sql("SELECT udf_name(1)")
59-
* spark.select(expr("udf_name(1)")
62+
* spark.range(10).select(expr("udf_name(id)"))
63+
* spark.range(10).select(pandasTestUDF($"id"))
6064
* }}}
6165
*/
6266
object IntegratedUDFTestUtils extends SQLHelper {
6367
import scala.sys.process._
6468

6569
private lazy val pythonPath = sys.env.getOrElse("PYTHONPATH", "")
6670
private lazy val sparkHome = if (sys.props.contains(Tests.IS_TESTING.key)) {
67-
assert(sys.props.contains("spark.test.home"), "spark.test.home is not set.")
68-
sys.props("spark.test.home")
71+
assert(sys.props.contains("spark.test.home") ||
72+
sys.env.contains("SPARK_HOME"), "spark.test.home or SPARK_HOME is not set.")
73+
sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME"))
6974
} else {
7075
assert(sys.env.contains("SPARK_HOME"), "SPARK_HOME is not set.")
7176
sys.env("SPARK_HOME")
@@ -186,14 +191,18 @@ object IntegratedUDFTestUtils extends SQLHelper {
186191
/**
187192
* A base trait for various UDFs defined in this object.
188193
*/
189-
sealed trait TestUDF
194+
sealed trait TestUDF {
195+
def apply(exprs: Column*): Column
196+
197+
val prettyName: String
198+
}
190199

191200
/**
192201
* A Python UDF that takes one column and returns a string column.
193202
* Equivalent to `udf(lambda x: str(x), "string")`
194203
*/
195204
case class TestPythonUDF(name: String) extends TestUDF {
196-
lazy val udf = UserDefinedPythonFunction(
205+
private[IntegratedUDFTestUtils] lazy val udf = UserDefinedPythonFunction(
197206
name = name,
198207
func = PythonFunction(
199208
command = pythonFunc,
@@ -206,14 +215,18 @@ object IntegratedUDFTestUtils extends SQLHelper {
206215
dataType = StringType,
207216
pythonEvalType = PythonEvalType.SQL_BATCHED_UDF,
208217
udfDeterministic = true)
218+
219+
def apply(exprs: Column*): Column = udf(exprs: _*)
220+
221+
val prettyName: String = "Regular Python UDF"
209222
}
210223

211224
/**
212225
* A Scalar Pandas UDF that takes one column and returns a string column.
213226
* Equivalent to `pandas_udf(lambda x: x.apply(str), "string", PandasUDFType.SCALAR)`.
214227
*/
215228
case class TestScalarPandasUDF(name: String) extends TestUDF {
216-
lazy val udf = UserDefinedPythonFunction(
229+
private[IntegratedUDFTestUtils] lazy val udf = UserDefinedPythonFunction(
217230
name = name,
218231
func = PythonFunction(
219232
command = pandasFunc,
@@ -226,17 +239,25 @@ object IntegratedUDFTestUtils extends SQLHelper {
226239
dataType = StringType,
227240
pythonEvalType = PythonEvalType.SQL_SCALAR_PANDAS_UDF,
228241
udfDeterministic = true)
242+
243+
def apply(exprs: Column*): Column = udf(exprs: _*)
244+
245+
val prettyName: String = "Scalar Pandas UDF"
229246
}
230247

231248
/**
232249
* A Scala UDF that takes one column and returns a string column.
233250
* Equivalent to `udf((input: Any) => input.toString)`.
234251
*/
235252
case class TestScalaUDF(name: String) extends TestUDF {
236-
lazy val udf = SparkUserDefinedFunction(
253+
private[IntegratedUDFTestUtils] lazy val udf = SparkUserDefinedFunction(
237254
(input: Any) => input.toString,
238255
StringType,
239256
inputSchemas = Seq.fill(1)(None))
257+
258+
def apply(exprs: Column*): Column = udf(exprs: _*)
259+
260+
val prettyName: String = "Scala UDF"
240261
}
241262

242263
/**

sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -383,24 +383,13 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext {
383383
val testCaseName = absPath.stripPrefix(inputFilePath).stripPrefix(File.separator)
384384

385385
if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}udf")) {
386-
Seq(
386+
Seq(TestScalaUDF("udf"), TestPythonUDF("udf"), TestScalarPandasUDF("udf")).map { udf =>
387387
UDFTestCase(
388-
s"$testCaseName - Scala UDF",
388+
s"$testCaseName - ${udf.prettyName}",
389389
absPath,
390390
resultFile,
391-
TestScalaUDF(name = "udf")),
392-
393-
UDFTestCase(
394-
s"$testCaseName - Python UDF",
395-
absPath,
396-
resultFile,
397-
TestPythonUDF(name = "udf")),
398-
399-
UDFTestCase(
400-
s"$testCaseName - Scalar Pandas UDF",
401-
absPath,
402-
resultFile,
403-
TestScalarPandasUDF(name = "udf")))
391+
udf)
392+
}
404393
} else if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}pgSQL")) {
405394
PgSQLTestCase(testCaseName, absPath, resultFile) :: Nil
406395
} else {

0 commit comments

Comments
 (0)