@@ -17,29 +17,80 @@ package function
1717import (
1818 "fmt"
1919 "strings"
20+ "sync"
2021
2122 "gopkg.in/src-d/go-errors.v1"
2223
24+ regex "github.com/dolthub/go-icu-regex"
25+
2326 "github.com/dolthub/go-mysql-server/sql"
27+ "github.com/dolthub/go-mysql-server/sql/expression"
2428 "github.com/dolthub/go-mysql-server/sql/types"
2529)
2630
2731// RegexpReplace implements the REGEXP_REPLACE function.
2832// https://dev.mysql.com/doc/refman/8.0/en/regexp.html#function_regexp-replace
2933type RegexpReplace struct {
30- args []sql.Expression
34+ Text sql.Expression
35+ Pattern sql.Expression
36+ RText sql.Expression
37+ Position sql.Expression
38+ Occurrence sql.Expression
39+ Flags sql.Expression
40+
41+ cacheVal bool
42+ cachedVal any
43+ cacheRegex bool
44+ re regex.Regex
45+ compileOnce sync.Once
46+ compileErr error
3147}
3248
3349var _ sql.FunctionExpression = (* RegexpReplace )(nil )
3450var _ sql.CollationCoercible = (* RegexpReplace )(nil )
51+ var _ sql.Disposable = (* RegexpReplace )(nil )
3552
3653// NewRegexpReplace creates a new RegexpReplace expression.
3754func NewRegexpReplace (args ... sql.Expression ) (sql.Expression , error ) {
38- if len (args ) < 3 || len (args ) > 6 {
55+ var r * RegexpReplace
56+ switch len (args ) {
57+ case 6 :
58+ r = & RegexpReplace {
59+ Text : args [0 ],
60+ Pattern : args [1 ],
61+ RText : args [2 ],
62+ Position : args [3 ],
63+ Occurrence : args [4 ],
64+ Flags : args [5 ],
65+ }
66+ case 5 :
67+ r = & RegexpReplace {
68+ Text : args [0 ],
69+ Pattern : args [1 ],
70+ RText : args [2 ],
71+ Position : args [3 ],
72+ Occurrence : args [4 ],
73+ }
74+ case 4 :
75+ r = & RegexpReplace {
76+ Text : args [0 ],
77+ Pattern : args [1 ],
78+ RText : args [2 ],
79+ Position : args [3 ],
80+ Occurrence : expression .NewLiteral (0 , types .Int32 ),
81+ }
82+ case 3 :
83+ r = & RegexpReplace {
84+ Text : args [0 ],
85+ Pattern : args [1 ],
86+ RText : args [2 ],
87+ Position : expression .NewLiteral (1 , types .Int32 ),
88+ Occurrence : expression .NewLiteral (0 , types .Int32 ),
89+ }
90+ default :
3991 return nil , sql .ErrInvalidArgumentNumber .New ("regexp_replace" , "3,4,5 or 6" , len (args ))
4092 }
41-
42- return & RegexpReplace {args : args }, nil
93+ return r , nil
4394}
4495
4596// FunctionName implements sql.FunctionExpression
@@ -57,14 +108,11 @@ func (r *RegexpReplace) Type() sql.Type { return types.LongText }
57108
58109// CollationCoercibility implements the interface sql.CollationCoercible.
59110func (r * RegexpReplace ) CollationCoercibility (ctx * sql.Context ) (collation sql.CollationID , coercibility byte ) {
60- if len (r .args ) == 0 {
61- return sql .Collation_binary , 6
62- }
63- collation , coercibility = sql .GetCoercibility (ctx , r .args [0 ])
64- for i := 1 ; i < len (r .args ) && i < 3 ; i ++ {
65- nextCollation , nextCoercibility := sql .GetCoercibility (ctx , r .args [i ])
66- collation , coercibility = sql .ResolveCoercibility (collation , coercibility , nextCollation , nextCoercibility )
67- }
111+ collation , coercibility = sql .GetCoercibility (ctx , r .Text )
112+ nextCollation , nextCoercibility := sql .GetCoercibility (ctx , r .Pattern )
113+ collation , coercibility = sql .ResolveCoercibility (collation , coercibility , nextCollation , nextCoercibility )
114+ nextCollation , nextCoercibility = sql .GetCoercibility (ctx , r .RText )
115+ collation , coercibility = sql .ResolveCoercibility (collation , coercibility , nextCollation , nextCoercibility )
68116 return collation , coercibility
69117}
70118
@@ -73,152 +121,163 @@ func (r *RegexpReplace) IsNullable() bool { return true }
73121
74122// Children implements the sql.Expression interface.
75123func (r * RegexpReplace ) Children () []sql.Expression {
76- return r .args
124+ var children = []sql.Expression {r .Text , r .Pattern , r .RText , r .Position , r .Occurrence }
125+ if r .Flags != nil {
126+ children = append (children , r .Flags )
127+ }
128+ return children
77129}
78130
79131// Resolved implements the sql.Expression interface.
80132func (r * RegexpReplace ) Resolved () bool {
81- for _ , arg := range r . args {
82- if ! arg . Resolved () {
83- return false
84- }
85- }
86- return true
133+ return r . Text . Resolved () &&
134+ r . Pattern . Resolved () &&
135+ r . RText . Resolved () &&
136+ r . Position . Resolved () &&
137+ r . Occurrence . Resolved () &&
138+ ( r . Flags == nil || r . Flags . Resolved ())
87139}
88140
89141// WithChildren implements the sql.Expression interface.
90142func (r * RegexpReplace ) WithChildren (children ... sql.Expression ) (sql.Expression , error ) {
91- if len (children ) != len (r .args ) {
92- return nil , sql .ErrInvalidChildrenNumber .New (r , len (children ), len (r .args ))
143+ required := 3
144+ if r .Flags != nil {
145+ required = 4
146+ }
147+ if len (children ) != required {
148+ return nil , sql .ErrInvalidChildrenNumber .New (r , len (children ), required )
149+ }
150+
151+ // Copy over the regex instance, in case it has already been set to avoid leaking it.
152+ replace , err := NewRegexpReplace (children ... )
153+ if err != nil {
154+ if r .re != nil {
155+ if err = r .re .Close (); err != nil {
156+ return nil , err
157+ }
158+ }
159+ return nil , err
160+ }
161+ if r .re != nil {
162+ replace .(* RegexpReplace ).re = r .re
93163 }
94- return NewRegexpReplace ( children ... )
164+ return replace , nil
95165}
96166
97167func (r * RegexpReplace ) String () string {
98168 var args []string
99- for _ , e := range r .args {
169+ for _ , e := range r .Children () {
100170 args = append (args , e .String ())
101171 }
102172 return fmt .Sprintf ("%s(%s)" , r .FunctionName (), strings .Join (args , "," ))
103173}
104174
175+ func (r * RegexpReplace ) compile (ctx * sql.Context , row sql.Row ) {
176+ r .compileOnce .Do (func () {
177+ r .cacheRegex = canBeCached (r .Pattern , r .Flags )
178+ r .cacheVal = r .cacheRegex && canBeCached (r .Text , r .RText , r .Position , r .Occurrence )
179+ if r .cacheRegex {
180+ r .re , r .compileErr = compileRegex (ctx , r .Pattern , r .Text , r .Flags , r .FunctionName (), row )
181+ }
182+ })
183+ if ! r .cacheRegex {
184+ if r .re != nil {
185+ if r .compileErr = r .re .Close (); r .compileErr != nil {
186+ return
187+ }
188+ }
189+ r .re , r .compileErr = compileRegex (ctx , r .Pattern , r .Text , r .Flags , r .FunctionName (), row )
190+ }
191+ }
192+
105193// Eval implements the sql.Expression interface.
106194func (r * RegexpReplace ) Eval (ctx * sql.Context , row sql.Row ) (val interface {}, err error ) {
107- // Evaluate string value
108- str , err := r .args [0 ].Eval (ctx , row )
195+ span , ctx := ctx .Span ("function.RegexpReplace" )
196+ defer span .End ()
197+
198+ if r .cachedVal != nil {
199+ return r .cachedVal , nil
200+ }
201+
202+ r .compile (ctx , row )
203+ if r .compileErr != nil {
204+ return nil , r .compileErr
205+ }
206+ if r .re == nil {
207+ return nil , nil
208+ }
209+
210+ text , err := r .Text .Eval (ctx , row )
109211 if err != nil {
110212 return nil , err
111213 }
112- if str == nil {
214+ if text == nil {
113215 return nil , nil
114216 }
115- str , _ , err = types .LongText .Convert (ctx , str )
217+ text , _ , err = types .LongText .Convert (ctx , text )
116218 if err != nil {
117219 return nil , err
118220 }
119221
120- // Convert to string
121- _str := str .(string )
122-
123- // Handle flags
124- var flags sql.Expression = nil
125- if len (r .args ) == 6 {
126- flags = r .args [5 ]
127- }
128-
129- // Create regex, should handle null pattern and null flags
130- re , compileErr := compileRegex (ctx , r .args [1 ], r .args [0 ], flags , r .FunctionName (), row )
131- if compileErr != nil {
132- return nil , compileErr
222+ rText , err := r .RText .Eval (ctx , row )
223+ if err != nil {
224+ return nil , err
133225 }
134- if re == nil {
226+ if rText == nil {
135227 return nil , nil
136228 }
137- defer func () {
138- if nErr := re .Close (); err == nil {
139- err = nErr
140- }
141- }()
142- if err = re .SetMatchString (ctx , _str ); err != nil {
229+ rText , _ , err = types .LongText .Convert (ctx , rText )
230+ if err != nil {
143231 return nil , err
144232 }
145233
146- // Evaluate ReplaceStr
147- replaceStr , err := r .args [2 ].Eval (ctx , row )
234+ pos , err := r .Position .Eval (ctx , row )
148235 if err != nil {
149236 return nil , err
150237 }
151- if replaceStr == nil {
238+ if pos == nil {
152239 return nil , nil
153240 }
154- replaceStr , _ , err = types .LongText .Convert (ctx , replaceStr )
241+ pos , _ , err = types .Int32 .Convert (ctx , pos )
155242 if err != nil {
156243 return nil , err
157244 }
158-
159- // Convert to string
160- _replaceStr := replaceStr .(string )
161-
162- // Do nothing if str is empty
163- if len (_str ) == 0 {
164- return _str , nil
245+ if pos .(int32 ) <= 0 {
246+ return nil , sql .ErrInvalidArgumentDetails .New (r .FunctionName (), fmt .Sprintf ("%d" , pos .(int32 )))
165247 }
166248
167- // Default position is 1
168- _pos := 1
169-
170- // Check if position argument was provided
171- if len (r .args ) >= 4 {
172- // Evaluate position argument
173- pos , err := r .args [3 ].Eval (ctx , row )
174- if err != nil {
175- return nil , err
176- }
177- if pos == nil {
178- return nil , nil
179- }
180-
181- // Convert to int32
182- pos , _ , err = types .Int32 .Convert (ctx , pos )
183- if err != nil {
184- return nil , err
185- }
186- // Convert to int
187- _pos = int (pos .(int32 ))
249+ if len (text .(string )) != 0 && int (pos .(int32 )) > len (text .(string )) {
250+ return nil , errors .NewKind ("Index out of bounds for regular expression search." ).New ()
188251 }
189252
190- // Non-positive position throws incorrect parameter
191- if _pos <= 0 {
192- return nil , sql . ErrInvalidArgumentDetails . New ( r . FunctionName (), fmt . Sprintf ( "%d" , _pos ))
253+ occurrence , err := r . Occurrence . Eval ( ctx , row )
254+ if err != nil {
255+ return nil , err
193256 }
194-
195- // Handle out of bounds
196- if _pos > len (_str ) {
197- return nil , errors .NewKind ("Index out of bounds for regular expression search." ).New ()
257+ if occurrence == nil {
258+ return nil , nil
259+ }
260+ occurrence , _ , err = types .Int32 .Convert (ctx , occurrence )
261+ if err != nil {
262+ return nil , err
198263 }
199264
200- // Default occurrence is 0 (replace all occurrences)
201- _occ := 0
265+ err = r .re .SetMatchString (ctx , text .(string ))
266+ if err != nil {
267+ return nil , err
268+ }
202269
203- // Check if Occurrence argument was provided
204- if len (r .args ) >= 5 {
205- occ , err := r .args [4 ].Eval (ctx , row )
206- if err != nil {
207- return nil , err
208- }
209- if occ == nil {
210- return nil , nil
211- }
270+ result , err := r .re .Replace (ctx , rText .(string ), int (pos .(int32 )), int (occurrence .(int32 )))
271+ if err != nil {
272+ return nil , err
273+ }
212274
213- // Convert occurrence to int32
214- occ , _ , err = types .Int32 .Convert (ctx , occ )
215- if err != nil {
216- return nil , err
217- }
275+ return result , nil
276+ }
218277
219- // Convert to int
220- _occ = int (occ .(int32 ))
278+ // Dispose implements the sql.Disposable interface.
279+ func (r * RegexpReplace ) Dispose () {
280+ if r .re != nil {
281+ _ = r .re .Close ()
221282 }
222-
223- return re .Replace (ctx , _replaceStr , _pos , _occ )
224283}
0 commit comments