@@ -21,7 +21,7 @@ package org.apache.comet.serde
2121
2222import java .util .Locale
2323
24- import org .apache .spark .sql .catalyst .expressions .{Attribute , Cast , Expression , InitCap , Length , Like , Literal , Lower , RLike , StringLPad , StringRepeat , StringRPad , Substring , Upper }
24+ import org .apache .spark .sql .catalyst .expressions .{Attribute , Cast , Expression , InitCap , Length , Like , Literal , Lower , RegExpReplace , RLike , StringLPad , StringRepeat , StringRPad , Substring , Upper }
2525import org .apache .spark .sql .types .{BinaryType , DataTypes , LongType , StringType }
2626
2727import org .apache .comet .CometConf
@@ -204,6 +204,44 @@ object CometStringLPad extends CometExpressionSerde[StringLPad] {
204204 }
205205}
206206
207+ object CometRegExpReplace extends CometExpressionSerde [RegExpReplace ] {
208+ override def getSupportLevel (expr : RegExpReplace ): SupportLevel = {
209+ if (! RegExp .isSupportedPattern(expr.regexp.toString) &&
210+ ! CometConf .COMET_REGEXP_ALLOW_INCOMPATIBLE .get()) {
211+ withInfo(
212+ expr,
213+ s " Regexp pattern ${expr.regexp} is not compatible with Spark. " +
214+ s " Set ${CometConf .COMET_REGEXP_ALLOW_INCOMPATIBLE .key}=true " +
215+ " to allow it anyway." )
216+ return Incompatible ()
217+ }
218+ expr.pos match {
219+ case Literal (value, DataTypes .IntegerType ) if value == 1 => Compatible ()
220+ case _ =>
221+ Unsupported (Some (" Comet only supports regexp_replace with an offset of 1 (no offset)." ))
222+ }
223+ }
224+
225+ override def convert (
226+ expr : RegExpReplace ,
227+ inputs : Seq [Attribute ],
228+ binding : Boolean ): Option [Expr ] = {
229+ val subjectExpr = exprToProtoInternal(expr.subject, inputs, binding)
230+ val patternExpr = exprToProtoInternal(expr.regexp, inputs, binding)
231+ val replacementExpr = exprToProtoInternal(expr.rep, inputs, binding)
232+ // DataFusion's regexp_replace stops at the first match. We need to add the 'g' flag
233+ // to apply the regex globally to match Spark behavior.
234+ val flagsExpr = exprToProtoInternal(Literal (" g" ), inputs, binding)
235+ val optExpr = scalarFunctionExprToProto(
236+ " regexp_replace" ,
237+ subjectExpr,
238+ patternExpr,
239+ replacementExpr,
240+ flagsExpr)
241+ optExprWithInfo(optExpr, expr, expr.subject, expr.regexp, expr.rep, expr.pos)
242+ }
243+ }
244+
207245trait CommonStringExprs {
208246
209247 def stringDecode (
0 commit comments