Skip to content

Commit e299ed1

Browse files
author
James Cor
committed
fixup replace
1 parent 856c070 commit e299ed1

File tree

4 files changed

+204
-108
lines changed

4 files changed

+204
-108
lines changed

sql/expression/function/regexp_replace.go

Lines changed: 165 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -17,29 +17,80 @@ package function
1717
import (
1818
"fmt"
1919
"strings"
20+
"sync"
2021

2122
"gopkg.in/src-d/go-errors.v1"
2223

24+
regex "github.com/dolthub/go-icu-regex"
25+
2326
"github.com/dolthub/go-mysql-server/sql"
27+
"github.com/dolthub/go-mysql-server/sql/expression"
2428
"github.com/dolthub/go-mysql-server/sql/types"
2529
)
2630

2731
// RegexpReplace implements the REGEXP_REPLACE function.
2832
// https://dev.mysql.com/doc/refman/8.0/en/regexp.html#function_regexp-replace
2933
type RegexpReplace struct {
30-
args []sql.Expression
34+
Text sql.Expression
35+
Pattern sql.Expression
36+
RText sql.Expression
37+
Position sql.Expression
38+
Occurrence sql.Expression
39+
Flags sql.Expression
40+
41+
cacheVal bool
42+
cachedVal any
43+
cacheRegex bool
44+
re regex.Regex
45+
compileOnce sync.Once
46+
compileErr error
3147
}
3248

3349
var _ sql.FunctionExpression = (*RegexpReplace)(nil)
3450
var _ sql.CollationCoercible = (*RegexpReplace)(nil)
51+
var _ sql.Disposable = (*RegexpReplace)(nil)
3552

3653
// NewRegexpReplace creates a new RegexpReplace expression.
3754
func NewRegexpReplace(args ...sql.Expression) (sql.Expression, error) {
38-
if len(args) < 3 || len(args) > 6 {
55+
var r *RegexpReplace
56+
switch len(args) {
57+
case 6:
58+
r = &RegexpReplace{
59+
Text: args[0],
60+
Pattern: args[1],
61+
RText: args[2],
62+
Position: args[3],
63+
Occurrence: args[4],
64+
Flags: args[5],
65+
}
66+
case 5:
67+
r = &RegexpReplace{
68+
Text: args[0],
69+
Pattern: args[1],
70+
RText: args[2],
71+
Position: args[3],
72+
Occurrence: args[4],
73+
}
74+
case 4:
75+
r = &RegexpReplace{
76+
Text: args[0],
77+
Pattern: args[1],
78+
RText: args[2],
79+
Position: args[3],
80+
Occurrence: expression.NewLiteral(0, types.Int32),
81+
}
82+
case 3:
83+
r = &RegexpReplace{
84+
Text: args[0],
85+
Pattern: args[1],
86+
RText: args[2],
87+
Position: expression.NewLiteral(1, types.Int32),
88+
Occurrence: expression.NewLiteral(0, types.Int32),
89+
}
90+
default:
3991
return nil, sql.ErrInvalidArgumentNumber.New("regexp_replace", "3,4,5 or 6", len(args))
4092
}
41-
42-
return &RegexpReplace{args: args}, nil
93+
return r, nil
4394
}
4495

4596
// FunctionName implements sql.FunctionExpression
@@ -57,14 +108,11 @@ func (r *RegexpReplace) Type() sql.Type { return types.LongText }
57108

58109
// CollationCoercibility implements the interface sql.CollationCoercible.
59110
func (r *RegexpReplace) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
60-
if len(r.args) == 0 {
61-
return sql.Collation_binary, 6
62-
}
63-
collation, coercibility = sql.GetCoercibility(ctx, r.args[0])
64-
for i := 1; i < len(r.args) && i < 3; i++ {
65-
nextCollation, nextCoercibility := sql.GetCoercibility(ctx, r.args[i])
66-
collation, coercibility = sql.ResolveCoercibility(collation, coercibility, nextCollation, nextCoercibility)
67-
}
111+
collation, coercibility = sql.GetCoercibility(ctx, r.Text)
112+
nextCollation, nextCoercibility := sql.GetCoercibility(ctx, r.Pattern)
113+
collation, coercibility = sql.ResolveCoercibility(collation, coercibility, nextCollation, nextCoercibility)
114+
nextCollation, nextCoercibility = sql.GetCoercibility(ctx, r.RText)
115+
collation, coercibility = sql.ResolveCoercibility(collation, coercibility, nextCollation, nextCoercibility)
68116
return collation, coercibility
69117
}
70118

@@ -73,152 +121,163 @@ func (r *RegexpReplace) IsNullable() bool { return true }
73121

74122
// Children implements the sql.Expression interface.
75123
func (r *RegexpReplace) Children() []sql.Expression {
76-
return r.args
124+
var children = []sql.Expression{r.Text, r.Pattern, r.RText, r.Position, r.Occurrence}
125+
if r.Flags != nil {
126+
children = append(children, r.Flags)
127+
}
128+
return children
77129
}
78130

79131
// Resolved implements the sql.Expression interface.
80132
func (r *RegexpReplace) Resolved() bool {
81-
for _, arg := range r.args {
82-
if !arg.Resolved() {
83-
return false
84-
}
85-
}
86-
return true
133+
return r.Text.Resolved() &&
134+
r.Pattern.Resolved() &&
135+
r.RText.Resolved() &&
136+
r.Position.Resolved() &&
137+
r.Occurrence.Resolved() &&
138+
(r.Flags == nil || r.Flags.Resolved())
87139
}
88140

89141
// WithChildren implements the sql.Expression interface.
90142
func (r *RegexpReplace) WithChildren(children ...sql.Expression) (sql.Expression, error) {
91-
if len(children) != len(r.args) {
92-
return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), len(r.args))
143+
required := 3
144+
if r.Flags != nil {
145+
required = 4
146+
}
147+
if len(children) != required {
148+
return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), required)
149+
}
150+
151+
// Copy over the regex instance, in case it has already been set to avoid leaking it.
152+
replace, err := NewRegexpReplace(children...)
153+
if err != nil {
154+
if r.re != nil {
155+
if err = r.re.Close(); err != nil {
156+
return nil, err
157+
}
158+
}
159+
return nil, err
160+
}
161+
if r.re != nil {
162+
replace.(*RegexpReplace).re = r.re
93163
}
94-
return NewRegexpReplace(children...)
164+
return replace, nil
95165
}
96166

97167
func (r *RegexpReplace) String() string {
98168
var args []string
99-
for _, e := range r.args {
169+
for _, e := range r.Children() {
100170
args = append(args, e.String())
101171
}
102172
return fmt.Sprintf("%s(%s)", r.FunctionName(), strings.Join(args, ","))
103173
}
104174

175+
func (r *RegexpReplace) compile(ctx *sql.Context, row sql.Row) {
176+
r.compileOnce.Do(func() {
177+
r.cacheRegex = canBeCached(r.Pattern, r.Flags)
178+
r.cacheVal = r.cacheRegex && canBeCached(r.Text, r.RText, r.Position, r.Occurrence)
179+
if r.cacheRegex {
180+
r.re, r.compileErr = compileRegex(ctx, r.Pattern, r.Text, r.Flags, r.FunctionName(), row)
181+
}
182+
})
183+
if !r.cacheRegex {
184+
if r.re != nil {
185+
if r.compileErr = r.re.Close(); r.compileErr != nil {
186+
return
187+
}
188+
}
189+
r.re, r.compileErr = compileRegex(ctx, r.Pattern, r.Text, r.Flags, r.FunctionName(), row)
190+
}
191+
}
192+
105193
// Eval implements the sql.Expression interface.
106194
func (r *RegexpReplace) Eval(ctx *sql.Context, row sql.Row) (val interface{}, err error) {
107-
// Evaluate string value
108-
str, err := r.args[0].Eval(ctx, row)
195+
span, ctx := ctx.Span("function.RegexpReplace")
196+
defer span.End()
197+
198+
if r.cachedVal != nil {
199+
return r.cachedVal, nil
200+
}
201+
202+
r.compile(ctx, row)
203+
if r.compileErr != nil {
204+
return nil, r.compileErr
205+
}
206+
if r.re == nil {
207+
return nil, nil
208+
}
209+
210+
text, err := r.Text.Eval(ctx, row)
109211
if err != nil {
110212
return nil, err
111213
}
112-
if str == nil {
214+
if text == nil {
113215
return nil, nil
114216
}
115-
str, _, err = types.LongText.Convert(ctx, str)
217+
text, _, err = types.LongText.Convert(ctx, text)
116218
if err != nil {
117219
return nil, err
118220
}
119221

120-
// Convert to string
121-
_str := str.(string)
122-
123-
// Handle flags
124-
var flags sql.Expression = nil
125-
if len(r.args) == 6 {
126-
flags = r.args[5]
127-
}
128-
129-
// Create regex, should handle null pattern and null flags
130-
re, compileErr := compileRegex(ctx, r.args[1], r.args[0], flags, r.FunctionName(), row)
131-
if compileErr != nil {
132-
return nil, compileErr
222+
rText, err := r.RText.Eval(ctx, row)
223+
if err != nil {
224+
return nil, err
133225
}
134-
if re == nil {
226+
if rText == nil {
135227
return nil, nil
136228
}
137-
defer func() {
138-
if nErr := re.Close(); err == nil {
139-
err = nErr
140-
}
141-
}()
142-
if err = re.SetMatchString(ctx, _str); err != nil {
229+
rText, _, err = types.LongText.Convert(ctx, rText)
230+
if err != nil {
143231
return nil, err
144232
}
145233

146-
// Evaluate ReplaceStr
147-
replaceStr, err := r.args[2].Eval(ctx, row)
234+
pos, err := r.Position.Eval(ctx, row)
148235
if err != nil {
149236
return nil, err
150237
}
151-
if replaceStr == nil {
238+
if pos == nil {
152239
return nil, nil
153240
}
154-
replaceStr, _, err = types.LongText.Convert(ctx, replaceStr)
241+
pos, _, err = types.Int32.Convert(ctx, pos)
155242
if err != nil {
156243
return nil, err
157244
}
158-
159-
// Convert to string
160-
_replaceStr := replaceStr.(string)
161-
162-
// Do nothing if str is empty
163-
if len(_str) == 0 {
164-
return _str, nil
245+
if pos.(int32) <= 0 {
246+
return nil, sql.ErrInvalidArgumentDetails.New(r.FunctionName(), fmt.Sprintf("%d", pos.(int32)))
165247
}
166248

167-
// Default position is 1
168-
_pos := 1
169-
170-
// Check if position argument was provided
171-
if len(r.args) >= 4 {
172-
// Evaluate position argument
173-
pos, err := r.args[3].Eval(ctx, row)
174-
if err != nil {
175-
return nil, err
176-
}
177-
if pos == nil {
178-
return nil, nil
179-
}
180-
181-
// Convert to int32
182-
pos, _, err = types.Int32.Convert(ctx, pos)
183-
if err != nil {
184-
return nil, err
185-
}
186-
// Convert to int
187-
_pos = int(pos.(int32))
249+
if len(text.(string)) != 0 && int(pos.(int32)) > len(text.(string)) {
250+
return nil, errors.NewKind("Index out of bounds for regular expression search.").New()
188251
}
189252

190-
// Non-positive position throws incorrect parameter
191-
if _pos <= 0 {
192-
return nil, sql.ErrInvalidArgumentDetails.New(r.FunctionName(), fmt.Sprintf("%d", _pos))
253+
occurrence, err := r.Occurrence.Eval(ctx, row)
254+
if err != nil {
255+
return nil, err
193256
}
194-
195-
// Handle out of bounds
196-
if _pos > len(_str) {
197-
return nil, errors.NewKind("Index out of bounds for regular expression search.").New()
257+
if occurrence == nil {
258+
return nil, nil
259+
}
260+
occurrence, _, err = types.Int32.Convert(ctx, occurrence)
261+
if err != nil {
262+
return nil, err
198263
}
199264

200-
// Default occurrence is 0 (replace all occurrences)
201-
_occ := 0
265+
err = r.re.SetMatchString(ctx, text.(string))
266+
if err != nil {
267+
return nil, err
268+
}
202269

203-
// Check if Occurrence argument was provided
204-
if len(r.args) >= 5 {
205-
occ, err := r.args[4].Eval(ctx, row)
206-
if err != nil {
207-
return nil, err
208-
}
209-
if occ == nil {
210-
return nil, nil
211-
}
270+
result, err := r.re.Replace(ctx, rText.(string), int(pos.(int32)), int(occurrence.(int32)))
271+
if err != nil {
272+
return nil, err
273+
}
212274

213-
// Convert occurrence to int32
214-
occ, _, err = types.Int32.Convert(ctx, occ)
215-
if err != nil {
216-
return nil, err
217-
}
275+
return result, nil
276+
}
218277

219-
// Convert to int
220-
_occ = int(occ.(int32))
278+
// Dispose implements the sql.Disposable interface.
279+
func (r *RegexpReplace) Dispose() {
280+
if r.re != nil {
281+
_ = r.re.Close()
221282
}
222-
223-
return re.Replace(ctx, _replaceStr, _pos, _occ)
224283
}

0 commit comments

Comments
 (0)