@@ -40,32 +40,37 @@ import org.apache.spark.sql.types.StringType
40
40
*
41
41
* To register Scala UDF in SQL:
42
42
* {{{
43
- * registerTestUDF(TestScalaUDF(name = "udf_name"), spark)
43
+ * val scalaTestUDF = TestScalaUDF(name = "udf_name")
44
+ * registerTestUDF(scalaTestUDF, spark)
44
45
* }}}
45
46
*
46
47
* To register Python UDF in SQL:
47
48
* {{{
48
- * registerTestUDF(TestPythonUDF(name = "udf_name"), spark)
49
+ * val pythonTestUDF = TestPythonUDF(name = "udf_name")
50
+ * registerTestUDF(pythonTestUDF, spark)
49
51
* }}}
50
52
*
51
53
* To register Scalar Pandas UDF in SQL:
52
54
* {{{
53
- * registerTestUDF(TestScalarPandasUDF(name = "udf_name"), spark)
55
+ * val pandasTestUDF = TestScalarPandasUDF(name = "udf_name")
56
+ * registerTestUDF(pandasTestUDF, spark)
54
57
* }}}
55
58
*
56
59
* To use it in Scala API and SQL:
57
60
* {{{
58
61
* sql("SELECT udf_name(1)")
59
- * spark.select(expr("udf_name(1)")
62
+ * spark.range(10).select(expr("udf_name(id)")
63
+ * spark.range(10).select(pandasTestUDF($"id"))
60
64
* }}}
61
65
*/
62
66
object IntegratedUDFTestUtils extends SQLHelper {
63
67
import scala .sys .process ._
64
68
65
69
private lazy val pythonPath = sys.env.getOrElse(" PYTHONPATH" , " " )
66
70
private lazy val sparkHome = if (sys.props.contains(Tests .IS_TESTING .key)) {
67
- assert(sys.props.contains(" spark.test.home" ), " spark.test.home is not set." )
68
- sys.props(" spark.test.home" )
71
+ assert(sys.props.contains(" spark.test.home" ) ||
72
+ sys.env.contains(" SPARK_HOME" ), " spark.test.home or SPARK_HOME is not set." )
73
+ sys.props.getOrElse(" spark.test.home" , sys.env(" SPARK_HOME" ))
69
74
} else {
70
75
assert(sys.env.contains(" SPARK_HOME" ), " SPARK_HOME is not set." )
71
76
sys.env(" SPARK_HOME" )
@@ -186,14 +191,18 @@ object IntegratedUDFTestUtils extends SQLHelper {
186
191
/**
187
192
* A base trait for various UDFs defined in this object.
188
193
*/
189
- sealed trait TestUDF
194
+ sealed trait TestUDF {
195
+ def apply (exprs : Column * ): Column
196
+
197
+ val prettyName : String
198
+ }
190
199
191
200
/**
192
201
* A Python UDF that takes one column and returns a string column.
193
202
* Equivalent to `udf(lambda x: str(x), "string")`
194
203
*/
195
204
case class TestPythonUDF (name : String ) extends TestUDF {
196
- lazy val udf = UserDefinedPythonFunction (
205
+ private [ IntegratedUDFTestUtils ] lazy val udf = UserDefinedPythonFunction (
197
206
name = name,
198
207
func = PythonFunction (
199
208
command = pythonFunc,
@@ -206,14 +215,18 @@ object IntegratedUDFTestUtils extends SQLHelper {
206
215
dataType = StringType ,
207
216
pythonEvalType = PythonEvalType .SQL_BATCHED_UDF ,
208
217
udfDeterministic = true )
218
+
219
+ def apply (exprs : Column * ): Column = udf(exprs : _* )
220
+
221
+ val prettyName : String = " Regular Python UDF"
209
222
}
210
223
211
224
/**
212
225
* A Scalar Pandas UDF that takes one column and returns a string column.
213
226
* Equivalent to `pandas_udf(lambda x: x.apply(str), "string", PandasUDFType.SCALAR)`.
214
227
*/
215
228
case class TestScalarPandasUDF (name : String ) extends TestUDF {
216
- lazy val udf = UserDefinedPythonFunction (
229
+ private [ IntegratedUDFTestUtils ] lazy val udf = UserDefinedPythonFunction (
217
230
name = name,
218
231
func = PythonFunction (
219
232
command = pandasFunc,
@@ -226,17 +239,25 @@ object IntegratedUDFTestUtils extends SQLHelper {
226
239
dataType = StringType ,
227
240
pythonEvalType = PythonEvalType .SQL_SCALAR_PANDAS_UDF ,
228
241
udfDeterministic = true )
242
+
243
+ def apply (exprs : Column * ): Column = udf(exprs : _* )
244
+
245
+ val prettyName : String = " Scalar Pandas UDF"
229
246
}
230
247
231
248
/**
232
249
* A Scala UDF that takes one column and returns a string column.
233
250
* Equivalent to `udf((input: Any) => input.toString)`.
234
251
*/
235
252
case class TestScalaUDF (name : String ) extends TestUDF {
236
- lazy val udf = SparkUserDefinedFunction (
253
+ private [ IntegratedUDFTestUtils ] lazy val udf = SparkUserDefinedFunction (
237
254
(input : Any ) => input.toString,
238
255
StringType ,
239
256
inputSchemas = Seq .fill(1 )(None ))
257
+
258
+ def apply (exprs : Column * ): Column = udf(exprs : _* )
259
+
260
+ val prettyName : String = " Scala UDF"
240
261
}
241
262
242
263
/**
0 commit comments