Skip to content

Commit ede05ff

Browse files
authored
chore: Refactor serde for RegExpReplace (#2548)
1 parent acfd03c commit ede05ff

File tree

2 files changed

+40
-31
lines changed

2 files changed

+40
-31
lines changed

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
170170
classOf[Like] -> CometLike,
171171
classOf[Lower] -> CometLower,
172172
classOf[OctetLength] -> CometScalarFunction("octet_length"),
173+
classOf[RegExpReplace] -> CometRegExpReplace,
173174
classOf[Reverse] -> CometScalarFunction("reverse"),
174175
classOf[RLike] -> CometRLike,
175176
classOf[StartsWith] -> CometScalarFunction("starts_with"),
@@ -761,36 +762,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
761762
// `PromotePrecision` is just a wrapper, don't need to serialize it.
762763
exprToProtoInternal(child, inputs, binding)
763764

764-
case RegExpReplace(subject, pattern, replacement, startPosition) =>
765-
if (!RegExp.isSupportedPattern(pattern.toString) &&
766-
!CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.get()) {
767-
withInfo(
768-
expr,
769-
s"Regexp pattern $pattern is not compatible with Spark. " +
770-
s"Set ${CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key}=true " +
771-
"to allow it anyway.")
772-
return None
773-
}
774-
startPosition match {
775-
case Literal(value, DataTypes.IntegerType) if value == 1 =>
776-
val subjectExpr = exprToProtoInternal(subject, inputs, binding)
777-
val patternExpr = exprToProtoInternal(pattern, inputs, binding)
778-
val replacementExpr = exprToProtoInternal(replacement, inputs, binding)
779-
// DataFusion's regexp_replace stops at the first match. We need to add the 'g' flag
780-
// to apply the regex globally to match Spark behavior.
781-
val flagsExpr = exprToProtoInternal(Literal("g"), inputs, binding)
782-
val optExpr = scalarFunctionExprToProto(
783-
"regexp_replace",
784-
subjectExpr,
785-
patternExpr,
786-
replacementExpr,
787-
flagsExpr)
788-
optExprWithInfo(optExpr, expr, subject, pattern, replacement, startPosition)
789-
case _ =>
790-
withInfo(expr, "Comet only supports regexp_replace with an offset of 1 (no offset).")
791-
None
792-
}
793-
794765
// With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called to pad spaces for
795766
// char types.
796767
// See https://github.com/apache/spark/pull/38151

spark/src/main/scala/org/apache/comet/serde/strings.scala

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ package org.apache.comet.serde
2121

2222
import java.util.Locale
2323

24-
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Expression, InitCap, Length, Like, Literal, Lower, RLike, StringLPad, StringRepeat, StringRPad, Substring, Upper}
24+
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Expression, InitCap, Length, Like, Literal, Lower, RegExpReplace, RLike, StringLPad, StringRepeat, StringRPad, Substring, Upper}
2525
import org.apache.spark.sql.types.{BinaryType, DataTypes, LongType, StringType}
2626

2727
import org.apache.comet.CometConf
@@ -204,6 +204,44 @@ object CometStringLPad extends CometExpressionSerde[StringLPad] {
204204
}
205205
}
206206

207+
/**
 * Serde for Spark's `RegExpReplace`, emitted as a call to the `regexp_replace` scalar function.
 */
object CometRegExpReplace extends CometExpressionSerde[RegExpReplace] {

  /**
   * Classifies a `RegExpReplace` expression before conversion.
   *
   * The start offset is checked first because it is a hard limitation: `convert` never
   * serializes `expr.pos`, so any offset other than the integer literal 1 must always be
   * reported as `Unsupported`, even when the pattern is also flagged as incompatible.
   * NOTE(review): this assumes the caller may let `Incompatible` expressions through under
   * some configuration - confirm against the dispatch code in QueryPlanSerde.
   *
   * @param expr
   *   the Spark expression to classify
   * @return
   *   `Unsupported` when the offset is not the integer literal 1; otherwise `Incompatible`
   *   when the pattern fails `RegExp.isSupportedPattern` and
   *   `CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE` is not enabled; otherwise `Compatible`
   */
  override def getSupportLevel(expr: RegExpReplace): SupportLevel = {
    expr.pos match {
      case Literal(value, DataTypes.IntegerType) if value == 1 =>
        // Short-circuits exactly like the original check: the config is only read when the
        // pattern itself is not supported.
        val patternAllowed = RegExp.isSupportedPattern(expr.regexp.toString) ||
          CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.get()
        if (patternAllowed) {
          Compatible()
        } else {
          withInfo(
            expr,
            s"Regexp pattern ${expr.regexp} is not compatible with Spark. " +
              s"Set ${CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key}=true " +
              "to allow it anyway.")
          Incompatible()
        }
      case _ =>
        Unsupported(Some("Comet only supports regexp_replace with an offset of 1 (no offset)."))
    }
  }

  /**
   * Serializes the subject, pattern and replacement children and wraps them in a
   * `regexp_replace` scalar function call with the global flag.
   *
   * `expr.pos` is not serialized; it is only handed to `optExprWithInfo` together with the
   * other children.
   *
   * @param expr
   *   the Spark expression to convert
   * @param inputs
   *   attributes passed through to `exprToProtoInternal` for each child
   * @param binding
   *   passed through unchanged to `exprToProtoInternal`
   * @return
   *   the result of `optExprWithInfo` applied to the built scalar function expression
   */
  override def convert(
      expr: RegExpReplace,
      inputs: Seq[Attribute],
      binding: Boolean): Option[Expr] = {
    val subjectExpr = exprToProtoInternal(expr.subject, inputs, binding)
    val patternExpr = exprToProtoInternal(expr.regexp, inputs, binding)
    val replacementExpr = exprToProtoInternal(expr.rep, inputs, binding)
    // DataFusion's regexp_replace stops at the first match. We need to add the 'g' flag
    // to apply the regex globally to match Spark behavior.
    val flagsExpr = exprToProtoInternal(Literal("g"), inputs, binding)
    val optExpr = scalarFunctionExprToProto(
      "regexp_replace",
      subjectExpr,
      patternExpr,
      replacementExpr,
      flagsExpr)
    optExprWithInfo(optExpr, expr, expr.subject, expr.regexp, expr.rep, expr.pos)
  }
}
244+
207245
trait CommonStringExprs {
208246

209247
def stringDecode(

0 commit comments

Comments
 (0)