Skip to content

Commit 656ece1

Browse files
panbingkun authored and MaxGekk committed
[SPARK-50081][SQL] Codegen Support for XPath*(by Invoke & RuntimeReplaceable)
### What changes were proposed in this pull request? This PR adds `Codegen` support for the `xpath*` functions, including: - `xpath_boolean` - `xpath_short` - `xpath_int` - `xpath_long` - `xpath_float` - `xpath_double` - `xpath_string` - `xpath` ### Why are the changes needed? - Improve codegen coverage. - Simplify the code. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA and the existing UTs (e.g. `XPathFunctionsSuite`, `XPathExpressionSuite`, `CollationSQLExpressionsSuite`#`*XPath*`, `CollationExpressionWalkerSuite`). ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48610 from panbingkun/xpath_codegen. Lead-authored-by: panbingkun <[email protected]> Co-authored-by: panbingkun <[email protected]> Signed-off-by: Max Gekk <[email protected]>
1 parent 779a526 commit 656ece1

File tree

12 files changed

+162
-66
lines changed

12 files changed

+162
-66
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XmlExpressionEvalUtils.scala

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions.xml
1919

20+
import org.apache.spark.sql.catalyst.util.GenericArrayData
2021
import org.apache.spark.sql.catalyst.xml.XmlInferSchema
2122
import org.apache.spark.sql.internal.SQLConf
22-
import org.apache.spark.sql.types.{ArrayType, DataType, StructType}
23+
import org.apache.spark.sql.types._
2324
import org.apache.spark.unsafe.types.UTF8String
2425

2526
object XmlExpressionEvalUtils {
@@ -40,3 +41,82 @@ object XmlExpressionEvalUtils {
4041
UTF8String.fromString(dataType.sql)
4142
}
4243
}
44+
45+
/**
 * Common base for objects that evaluate a fixed XPath expression against an XML string.
 *
 * Implementations supply the XPath via `path` and the type-specific extraction via
 * `doEvaluate`; callers go through `evaluate`, which performs the shared input checks.
 */
trait XPathEvaluator {

  /** The XPath expression applied to every XML input. */
  protected val path: UTF8String

  /** XPath helper, created lazily on first use and excluded from serialization. */
  @transient protected lazy val xpathUtil: UDFXPathUtil = new UDFXPathUtil

  /**
   * Evaluates `path` against `xml`.
   *
   * @param xml the XML document text
   * @return null when either `xml` or `path` is null or empty; otherwise the result of
   *         `doEvaluate(xml)`
   */
  final def evaluate(xml: UTF8String): Any = {
    // A value is unusable when it is null or its string form is empty.
    def missing(value: UTF8String): Boolean = value == null || value.toString.isEmpty

    if (missing(xml) || missing(path)) null else doEvaluate(xml)
  }

  /**
   * Type-specific evaluation. When reached through `evaluate`, both `xml` and `path`
   * are non-null and non-empty.
   */
  def doEvaluate(xml: UTF8String): Any
}
58+
59+
/** Evaluates `path` against an XML string and returns the result of `evalBoolean`. */
case class XPathBooleanEvaluator(path: UTF8String) extends XPathEvaluator {
  override def doEvaluate(xml: UTF8String): Any = {
    val document = xml.toString
    val expression = path.toString
    xpathUtil.evalBoolean(document, expression)
  }
}
64+
65+
/** Evaluates `path` against an XML string as a number and narrows it with `shortValue()`. */
case class XPathShortEvaluator(path: UTF8String) extends XPathEvaluator {
  override def doEvaluate(xml: UTF8String): Any = {
    val number = xpathUtil.evalNumber(xml.toString, path.toString)
    // The null branch is cast to Short so both branches share the same static type.
    if (number ne null) number.shortValue() else null.asInstanceOf[Short]
  }
}
71+
72+
/** Evaluates `path` against an XML string as a number and narrows it with `intValue()`. */
case class XPathIntEvaluator(path: UTF8String) extends XPathEvaluator {
  override def doEvaluate(xml: UTF8String): Any = {
    val number = xpathUtil.evalNumber(xml.toString, path.toString)
    // The null branch is cast to Int so both branches share the same static type.
    if (number ne null) number.intValue() else null.asInstanceOf[Int]
  }
}
78+
79+
/** Evaluates `path` against an XML string as a number and converts it with `longValue()`. */
case class XPathLongEvaluator(path: UTF8String) extends XPathEvaluator {
  override def doEvaluate(xml: UTF8String): Any = {
    val number = xpathUtil.evalNumber(xml.toString, path.toString)
    // The null branch is cast to Long so both branches share the same static type.
    if (number ne null) number.longValue() else null.asInstanceOf[Long]
  }
}
85+
86+
/** Evaluates `path` against an XML string as a number and converts it with `floatValue()`. */
case class XPathFloatEvaluator(path: UTF8String) extends XPathEvaluator {
  override def doEvaluate(xml: UTF8String): Any = {
    val number = xpathUtil.evalNumber(xml.toString, path.toString)
    // The null branch is cast to Float so both branches share the same static type.
    if (number ne null) number.floatValue() else null.asInstanceOf[Float]
  }
}
92+
93+
/** Evaluates `path` against an XML string as a number and converts it with `doubleValue()`. */
case class XPathDoubleEvaluator(path: UTF8String) extends XPathEvaluator {
  override def doEvaluate(xml: UTF8String): Any = {
    val number = xpathUtil.evalNumber(xml.toString, path.toString)
    // The null branch is cast to Double so both branches share the same static type.
    if (number ne null) number.doubleValue() else null.asInstanceOf[Double]
  }
}
99+
100+
/** Evaluates `path` against an XML string and wraps the `evalString` result as a UTF8String. */
case class XPathStringEvaluator(path: UTF8String) extends XPathEvaluator {
  override def doEvaluate(xml: UTF8String): Any =
    UTF8String.fromString(xpathUtil.evalString(xml.toString, path.toString))
}
106+
107+
/**
 * Evaluates `path` against an XML string as a node list and returns the node values
 * as an array of UTF8String wrapped in [[GenericArrayData]].
 */
case class XPathListEvaluator(path: UTF8String) extends XPathEvaluator {
  override def doEvaluate(xml: UTF8String): Any = {
    val nodes = xpathUtil.evalNodeList(xml.toString, path.toString)
    if (nodes eq null) {
      // No node list was produced: the result is null rather than an empty array.
      null
    } else {
      // One entry per node, in node-list order, holding that node's value.
      val values = Array.tabulate[AnyRef](nodes.getLength) { idx =>
        UTF8String.fromString(nodes.item(idx).getNodeValue)
      }
      new GenericArrayData(values)
    }
  }
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala

Lines changed: 36 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult
2121
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
2222
import org.apache.spark.sql.catalyst.expressions._
2323
import org.apache.spark.sql.catalyst.expressions.Cast._
24-
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
25-
import org.apache.spark.sql.catalyst.util.GenericArrayData
24+
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
2625
import org.apache.spark.sql.internal.SQLConf
2726
import org.apache.spark.sql.internal.types.StringTypeWithCollation
2827
import org.apache.spark.sql.types._
@@ -34,10 +33,9 @@ import org.apache.spark.unsafe.types.UTF8String
3433
* This is not the world's most efficient implementation due to type conversion, but works.
3534
*/
3635
abstract class XPathExtract
37-
extends BinaryExpression with ExpectsInputTypes with CodegenFallback {
36+
extends BinaryExpression with RuntimeReplaceable with ExpectsInputTypes {
3837
override def left: Expression = xml
3938
override def right: Expression = path
40-
override def nullIntolerant: Boolean = true
4139

4240
/** XPath expressions are always nullable, e.g. if the xml string is empty. */
4341
override def nullable: Boolean = true
@@ -60,12 +58,20 @@ abstract class XPathExtract
6058
}
6159
}
6260

63-
@transient protected lazy val xpathUtil = new UDFXPathUtil
64-
@transient protected lazy val pathString: String = path.eval().asInstanceOf[UTF8String].toString
65-
6661
/** Concrete implementations need to override the following three methods. */
6762
def xml: Expression
6863
def path: Expression
64+
65+
@transient protected lazy val pathUTF8String: UTF8String = path.eval().asInstanceOf[UTF8String]
66+
67+
protected def evaluator: XPathEvaluator
68+
69+
override def replacement: Expression = Invoke(
70+
Literal.create(evaluator, ObjectType(classOf[XPathEvaluator])),
71+
"evaluate",
72+
dataType,
73+
Seq(xml),
74+
Seq(xml.dataType))
6975
}
7076

7177
// scalastyle:off line.size.limit
@@ -81,11 +87,9 @@ abstract class XPathExtract
8187
// scalastyle:on line.size.limit
8288
case class XPathBoolean(xml: Expression, path: Expression) extends XPathExtract with Predicate {
8389

84-
override def prettyName: String = "xpath_boolean"
90+
@transient override lazy val evaluator: XPathEvaluator = XPathBooleanEvaluator(pathUTF8String)
8591

86-
override def nullSafeEval(xml: Any, path: Any): Any = {
87-
xpathUtil.evalBoolean(xml.asInstanceOf[UTF8String].toString, pathString)
88-
}
92+
override def prettyName: String = "xpath_boolean"
8993

9094
override protected def withNewChildrenInternal(
9195
newLeft: Expression, newRight: Expression): XPathBoolean = copy(xml = newLeft, path = newRight)
@@ -103,14 +107,12 @@ case class XPathBoolean(xml: Expression, path: Expression) extends XPathExtract
103107
group = "xml_funcs")
104108
// scalastyle:on line.size.limit
105109
case class XPathShort(xml: Expression, path: Expression) extends XPathExtract {
110+
111+
@transient override lazy val evaluator: XPathEvaluator = XPathShortEvaluator(pathUTF8String)
112+
106113
override def prettyName: String = "xpath_short"
107114
override def dataType: DataType = ShortType
108115

109-
override def nullSafeEval(xml: Any, path: Any): Any = {
110-
val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, pathString)
111-
if (ret eq null) null else ret.shortValue()
112-
}
113-
114116
override protected def withNewChildrenInternal(
115117
newLeft: Expression, newRight: Expression): XPathShort = copy(xml = newLeft, path = newRight)
116118
}
@@ -127,14 +129,12 @@ case class XPathShort(xml: Expression, path: Expression) extends XPathExtract {
127129
group = "xml_funcs")
128130
// scalastyle:on line.size.limit
129131
case class XPathInt(xml: Expression, path: Expression) extends XPathExtract {
132+
133+
@transient override lazy val evaluator: XPathEvaluator = XPathIntEvaluator(pathUTF8String)
134+
130135
override def prettyName: String = "xpath_int"
131136
override def dataType: DataType = IntegerType
132137

133-
override def nullSafeEval(xml: Any, path: Any): Any = {
134-
val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, pathString)
135-
if (ret eq null) null else ret.intValue()
136-
}
137-
138138
override protected def withNewChildrenInternal(
139139
newLeft: Expression, newRight: Expression): Expression = copy(xml = newLeft, path = newRight)
140140
}
@@ -151,14 +151,12 @@ case class XPathInt(xml: Expression, path: Expression) extends XPathExtract {
151151
group = "xml_funcs")
152152
// scalastyle:on line.size.limit
153153
case class XPathLong(xml: Expression, path: Expression) extends XPathExtract {
154+
155+
@transient override lazy val evaluator: XPathEvaluator = XPathLongEvaluator(pathUTF8String)
156+
154157
override def prettyName: String = "xpath_long"
155158
override def dataType: DataType = LongType
156159

157-
override def nullSafeEval(xml: Any, path: Any): Any = {
158-
val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, pathString)
159-
if (ret eq null) null else ret.longValue()
160-
}
161-
162160
override protected def withNewChildrenInternal(
163161
newLeft: Expression, newRight: Expression): XPathLong = copy(xml = newLeft, path = newRight)
164162
}
@@ -175,14 +173,12 @@ case class XPathLong(xml: Expression, path: Expression) extends XPathExtract {
175173
group = "xml_funcs")
176174
// scalastyle:on line.size.limit
177175
case class XPathFloat(xml: Expression, path: Expression) extends XPathExtract {
176+
177+
@transient override lazy val evaluator: XPathEvaluator = XPathFloatEvaluator(pathUTF8String)
178+
178179
override def prettyName: String = "xpath_float"
179180
override def dataType: DataType = FloatType
180181

181-
override def nullSafeEval(xml: Any, path: Any): Any = {
182-
val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, pathString)
183-
if (ret eq null) null else ret.floatValue()
184-
}
185-
186182
override protected def withNewChildrenInternal(
187183
newLeft: Expression, newRight: Expression): XPathFloat = copy(xml = newLeft, path = newRight)
188184
}
@@ -199,15 +195,13 @@ case class XPathFloat(xml: Expression, path: Expression) extends XPathExtract {
199195
group = "xml_funcs")
200196
// scalastyle:on line.size.limit
201197
case class XPathDouble(xml: Expression, path: Expression) extends XPathExtract {
198+
199+
@transient override lazy val evaluator: XPathEvaluator = XPathDoubleEvaluator(pathUTF8String)
200+
202201
override def prettyName: String =
203202
getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("xpath_double")
204203
override def dataType: DataType = DoubleType
205204

206-
override def nullSafeEval(xml: Any, path: Any): Any = {
207-
val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, pathString)
208-
if (ret eq null) null else ret.doubleValue()
209-
}
210-
211205
override protected def withNewChildrenInternal(
212206
newLeft: Expression, newRight: Expression): XPathDouble = copy(xml = newLeft, path = newRight)
213207
}
@@ -224,14 +218,12 @@ case class XPathDouble(xml: Expression, path: Expression) extends XPathExtract {
224218
group = "xml_funcs")
225219
// scalastyle:on line.size.limit
226220
case class XPathString(xml: Expression, path: Expression) extends XPathExtract {
221+
222+
@transient override lazy val evaluator: XPathEvaluator = XPathStringEvaluator(pathUTF8String)
223+
227224
override def prettyName: String = "xpath_string"
228225
override def dataType: DataType = SQLConf.get.defaultStringType
229226

230-
override def nullSafeEval(xml: Any, path: Any): Any = {
231-
val ret = xpathUtil.evalString(xml.asInstanceOf[UTF8String].toString, pathString)
232-
UTF8String.fromString(ret)
233-
}
234-
235227
override protected def withNewChildrenInternal(
236228
newLeft: Expression, newRight: Expression): Expression = copy(xml = newLeft, path = newRight)
237229
}
@@ -250,24 +242,12 @@ case class XPathString(xml: Expression, path: Expression) extends XPathExtract {
250242
group = "xml_funcs")
251243
// scalastyle:on line.size.limit
252244
case class XPathList(xml: Expression, path: Expression) extends XPathExtract {
245+
246+
@transient override lazy val evaluator: XPathEvaluator = XPathListEvaluator(pathUTF8String)
247+
253248
override def prettyName: String = "xpath"
254249
override def dataType: DataType = ArrayType(SQLConf.get.defaultStringType)
255250

256-
override def nullSafeEval(xml: Any, path: Any): Any = {
257-
val nodeList = xpathUtil.evalNodeList(xml.asInstanceOf[UTF8String].toString, pathString)
258-
if (nodeList ne null) {
259-
val ret = new Array[AnyRef](nodeList.getLength)
260-
var i = 0
261-
while (i < nodeList.getLength) {
262-
ret(i) = UTF8String.fromString(nodeList.item(i).getNodeValue)
263-
i += 1
264-
}
265-
new GenericArrayData(ret)
266-
} else {
267-
null
268-
}
269-
}
270-
271251
override protected def withNewChildrenInternal(
272252
newLeft: Expression, newRight: Expression): XPathList = copy(xml = newLeft, path = newRight)
273253
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
Project [xpath(s#0, a/b/text()) AS xpath(s, a/b/text())#0]
1+
Project [invoke(XPathListEvaluator(a/b/text()).evaluate(s#0)) AS xpath(s, a/b/text())#0]
22
+- LocalRelation <empty>, [d#0, t#0, s#0, x#0L, wt#0]
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
Project [xpath_boolean(s#0, a/b) AS xpath_boolean(s, a/b)#0]
1+
Project [invoke(XPathBooleanEvaluator(a/b).evaluate(s#0)) AS xpath_boolean(s, a/b)#0]
22
+- LocalRelation <empty>, [d#0, t#0, s#0, x#0L, wt#0]
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
Project [xpath_double(s#0, a/b) AS xpath_double(s, a/b)#0]
1+
Project [invoke(XPathDoubleEvaluator(a/b).evaluate(s#0)) AS xpath_double(s, a/b)#0]
22
+- LocalRelation <empty>, [d#0, t#0, s#0, x#0L, wt#0]
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
Project [xpath_float(s#0, a/b) AS xpath_float(s, a/b)#0]
1+
Project [invoke(XPathFloatEvaluator(a/b).evaluate(s#0)) AS xpath_float(s, a/b)#0]
22
+- LocalRelation <empty>, [d#0, t#0, s#0, x#0L, wt#0]
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
Project [xpath_int(s#0, a/b) AS xpath_int(s, a/b)#0]
1+
Project [invoke(XPathIntEvaluator(a/b).evaluate(s#0)) AS xpath_int(s, a/b)#0]
22
+- LocalRelation <empty>, [d#0, t#0, s#0, x#0L, wt#0]
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
Project [xpath_long(s#0, a/b) AS xpath_long(s, a/b)#0L]
1+
Project [invoke(XPathLongEvaluator(a/b).evaluate(s#0)) AS xpath_long(s, a/b)#0L]
22
+- LocalRelation <empty>, [d#0, t#0, s#0, x#0L, wt#0]
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
Project [xpath_number(s#0, a/b) AS xpath_number(s, a/b)#0]
1+
Project [invoke(XPathDoubleEvaluator(a/b).evaluate(s#0)) AS xpath_number(s, a/b)#0]
22
+- LocalRelation <empty>, [d#0, t#0, s#0, x#0L, wt#0]
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
Project [xpath_short(s#0, a/b) AS xpath_short(s, a/b)#0]
1+
Project [invoke(XPathShortEvaluator(a/b).evaluate(s#0)) AS xpath_short(s, a/b)#0]
22
+- LocalRelation <empty>, [d#0, t#0, s#0, x#0L, wt#0]

0 commit comments

Comments
 (0)