@@ -21,8 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult
2121import org .apache .spark .sql .catalyst .analysis .TypeCheckResult .DataTypeMismatch
2222import org .apache .spark .sql .catalyst .expressions ._
2323import org .apache .spark .sql .catalyst .expressions .Cast ._
24- import org .apache .spark .sql .catalyst .expressions .codegen .CodegenFallback
25- import org .apache .spark .sql .catalyst .util .GenericArrayData
24+ import org .apache .spark .sql .catalyst .expressions .objects .Invoke
2625import org .apache .spark .sql .internal .SQLConf
2726import org .apache .spark .sql .internal .types .StringTypeWithCollation
2827import org .apache .spark .sql .types ._
@@ -34,10 +33,9 @@ import org.apache.spark.unsafe.types.UTF8String
3433 * This is not the world's most efficient implementation due to type conversion, but works.
3534 */
3635abstract class XPathExtract
37- extends BinaryExpression with ExpectsInputTypes with CodegenFallback {
36+ extends BinaryExpression with RuntimeReplaceable with ExpectsInputTypes {
3837 override def left : Expression = xml
3938 override def right : Expression = path
40- override def nullIntolerant : Boolean = true
4139
4240 /** XPath expressions are always nullable, e.g. if the xml string is empty. */
4341 override def nullable : Boolean = true
@@ -60,12 +58,20 @@ abstract class XPathExtract
6058 }
6159 }
6260
63- @ transient protected lazy val xpathUtil = new UDFXPathUtil
64- @ transient protected lazy val pathString : String = path.eval().asInstanceOf [UTF8String ].toString
65-
6661 /** Concrete implementations need to override the following three methods. */
6762 def xml : Expression
6863 def path : Expression
64+
65+ @ transient protected lazy val pathUTF8String : UTF8String = path.eval().asInstanceOf [UTF8String ]
66+
67+ protected def evaluator : XPathEvaluator
68+
/**
 * Expression substituted for this one: an [[Invoke]] of `evaluate` on a literal that wraps
 * `evaluator`. Only the xml child is passed as an argument, and the call is typed with this
 * expression's own `dataType`.
 */
override def replacement: Expression = {
  val evaluatorLiteral = Literal.create(evaluator, ObjectType(classOf[XPathEvaluator]))
  Invoke(evaluatorLiteral, "evaluate", dataType, Seq(xml), Seq(xml.dataType))
}
6975}
7076
7177// scalastyle:off line.size.limit
@@ -81,11 +87,9 @@ abstract class XPathExtract
8187// scalastyle:on line.size.limit
8288case class XPathBoolean (xml : Expression , path : Expression ) extends XPathExtract with Predicate {
8389
84- override def prettyName : String = " xpath_boolean "
90+ @ transient override lazy val evaluator : XPathEvaluator = XPathBooleanEvaluator (pathUTF8String)
8591
86- override def nullSafeEval (xml : Any , path : Any ): Any = {
87- xpathUtil.evalBoolean(xml.asInstanceOf [UTF8String ].toString, pathString)
88- }
92+ override def prettyName : String = " xpath_boolean"
8993
/** Rebuilds this node from new children: the left child becomes `xml`, the right `path`. */
override protected def withNewChildrenInternal(
    newLeft: Expression, newRight: Expression): XPathBoolean = {
  copy(xml = newLeft, path = newRight)
}
@@ -103,14 +107,12 @@ case class XPathBoolean(xml: Expression, path: Expression) extends XPathExtract
103107 group = " xml_funcs" )
104108// scalastyle:on line.size.limit
case class XPathShort(xml: Expression, path: Expression) extends XPathExtract {

  /** Result type of this expression. */
  override def dataType: DataType = ShortType

  /** Function name reported for this expression. */
  override def prettyName: String = "xpath_short"

  /**
   * Evaluator supplied to the parent class. It is created lazily from `pathUTF8String`
   * and marked transient so it is not serialized with the expression.
   */
  @transient override lazy val evaluator: XPathEvaluator = XPathShortEvaluator(pathUTF8String)

  /** Rebuilds this node from new children: the left child becomes `xml`, the right `path`. */
  override protected def withNewChildrenInternal(
      newLeft: Expression, newRight: Expression): XPathShort = {
    copy(xml = newLeft, path = newRight)
  }
}
@@ -127,14 +129,12 @@ case class XPathShort(xml: Expression, path: Expression) extends XPathExtract {
127129 group = " xml_funcs" )
128130// scalastyle:on line.size.limit
case class XPathInt(xml: Expression, path: Expression) extends XPathExtract {

  /**
   * Evaluator supplied to the parent class. It is created lazily from `pathUTF8String`
   * and marked transient so it is not serialized with the expression.
   */
  @transient override lazy val evaluator: XPathEvaluator = XPathIntEvaluator(pathUTF8String)

  /** Function name reported for this expression. */
  override def prettyName: String = "xpath_int"

  /** Result type of this expression. */
  override def dataType: DataType = IntegerType

  /**
   * Rebuilds this node from new children: the left child becomes `xml`, the right `path`.
   * The result type is narrowed to `XPathInt` (a covariant override) so the signature
   * matches the sibling XPath expressions, which return their own concrete type.
   */
  override protected def withNewChildrenInternal(
      newLeft: Expression, newRight: Expression): XPathInt = copy(xml = newLeft, path = newRight)
}
@@ -151,14 +151,12 @@ case class XPathInt(xml: Expression, path: Expression) extends XPathExtract {
151151 group = " xml_funcs" )
152152// scalastyle:on line.size.limit
case class XPathLong(xml: Expression, path: Expression) extends XPathExtract {

  /** Result type of this expression. */
  override def dataType: DataType = LongType

  /** Function name reported for this expression. */
  override def prettyName: String = "xpath_long"

  /**
   * Evaluator supplied to the parent class. It is created lazily from `pathUTF8String`
   * and marked transient so it is not serialized with the expression.
   */
  @transient override lazy val evaluator: XPathEvaluator = XPathLongEvaluator(pathUTF8String)

  /** Rebuilds this node from new children: the left child becomes `xml`, the right `path`. */
  override protected def withNewChildrenInternal(
      newLeft: Expression, newRight: Expression): XPathLong = {
    copy(xml = newLeft, path = newRight)
  }
}
@@ -175,14 +173,12 @@ case class XPathLong(xml: Expression, path: Expression) extends XPathExtract {
175173 group = " xml_funcs" )
176174// scalastyle:on line.size.limit
case class XPathFloat(xml: Expression, path: Expression) extends XPathExtract {

  /** Result type of this expression. */
  override def dataType: DataType = FloatType

  /** Function name reported for this expression. */
  override def prettyName: String = "xpath_float"

  /**
   * Evaluator supplied to the parent class. It is created lazily from `pathUTF8String`
   * and marked transient so it is not serialized with the expression.
   */
  @transient override lazy val evaluator: XPathEvaluator = XPathFloatEvaluator(pathUTF8String)

  /** Rebuilds this node from new children: the left child becomes `xml`, the right `path`. */
  override protected def withNewChildrenInternal(
      newLeft: Expression, newRight: Expression): XPathFloat = {
    copy(xml = newLeft, path = newRight)
  }
}
@@ -199,15 +195,13 @@ case class XPathFloat(xml: Expression, path: Expression) extends XPathExtract {
199195 group = " xml_funcs" )
200196// scalastyle:on line.size.limit
case class XPathDouble(xml: Expression, path: Expression) extends XPathExtract {

  /** Result type of this expression. */
  override def dataType: DataType = DoubleType

  /**
   * Function name reported for this expression: the alias stored under
   * `FunctionRegistry.FUNC_ALIAS` when that tag is set, otherwise "xpath_double".
   */
  override def prettyName: String = {
    val alias = getTagValue(FunctionRegistry.FUNC_ALIAS)
    alias.getOrElse("xpath_double")
  }

  /**
   * Evaluator supplied to the parent class. It is created lazily from `pathUTF8String`
   * and marked transient so it is not serialized with the expression.
   */
  @transient override lazy val evaluator: XPathEvaluator = XPathDoubleEvaluator(pathUTF8String)

  /** Rebuilds this node from new children: the left child becomes `xml`, the right `path`. */
  override protected def withNewChildrenInternal(
      newLeft: Expression, newRight: Expression): XPathDouble = {
    copy(xml = newLeft, path = newRight)
  }
}
@@ -224,14 +218,12 @@ case class XPathDouble(xml: Expression, path: Expression) extends XPathExtract {
224218 group = " xml_funcs" )
225219// scalastyle:on line.size.limit
case class XPathString(xml: Expression, path: Expression) extends XPathExtract {

  /**
   * Evaluator supplied to the parent class. It is created lazily from `pathUTF8String`
   * and marked transient so it is not serialized with the expression.
   */
  @transient override lazy val evaluator: XPathEvaluator = XPathStringEvaluator(pathUTF8String)

  /** Function name reported for this expression. */
  override def prettyName: String = "xpath_string"

  /** Result type: the session's default string type, read from the active `SQLConf`. */
  override def dataType: DataType = SQLConf.get.defaultStringType

  /**
   * Rebuilds this node from new children: the left child becomes `xml`, the right `path`.
   * The result type is narrowed to `XPathString` (a covariant override) so the signature
   * matches the sibling XPath expressions, which return their own concrete type.
   */
  override protected def withNewChildrenInternal(
      newLeft: Expression, newRight: Expression): XPathString =
    copy(xml = newLeft, path = newRight)
}
@@ -250,24 +242,12 @@ case class XPathString(xml: Expression, path: Expression) extends XPathExtract {
250242 group = " xml_funcs" )
251243// scalastyle:on line.size.limit
case class XPathList(xml: Expression, path: Expression) extends XPathExtract {

  /** Result type: an array whose elements use the session's default string type. */
  override def dataType: DataType = {
    val elementType = SQLConf.get.defaultStringType
    ArrayType(elementType)
  }

  /** Function name reported for this expression. */
  override def prettyName: String = "xpath"

  /**
   * Evaluator supplied to the parent class. It is created lazily from `pathUTF8String`
   * and marked transient so it is not serialized with the expression.
   */
  @transient override lazy val evaluator: XPathEvaluator = XPathListEvaluator(pathUTF8String)

  /** Rebuilds this node from new children: the left child becomes `xml`, the right `path`. */
  override protected def withNewChildrenInternal(
      newLeft: Expression, newRight: Expression): XPathList = {
    copy(xml = newLeft, path = newRight)
  }
}
0 commit comments