diff --git a/enginetest/queries/regex_queries.go b/enginetest/queries/regex_queries.go
index 3d19a004ad..9c132c95fc 100644
--- a/enginetest/queries/regex_queries.go
+++ b/enginetest/queries/regex_queries.go
@@ -23,8 +23,7 @@ package queries
 import (
 	"gopkg.in/src-d/go-errors.v1"
 
-	regex "github.com/dolthub/go-icu-regex"
-
+	"github.com/dolthub/go-mysql-server/internal/regex"
 	"github.com/dolthub/go-mysql-server/sql"
 )
 
diff --git a/internal/regex/regex_cgo.go b/internal/regex/regex_cgo.go
new file mode 100644
index 0000000000..035c9a3c5c
--- /dev/null
+++ b/internal/regex/regex_cgo.go
@@ -0,0 +1,38 @@
+//go:build cgo && !gms_pure_go
+
+package regex
+
+import regex "github.com/dolthub/go-icu-regex"
+
+// Regex is the go-icu-regex matcher, re-exported so that callers depend only
+// on this package.
+type Regex = regex.Regex
+
+// Error kinds re-exported from go-icu-regex.
+var (
+	ErrRegexNotYetSet = regex.ErrRegexNotYetSet
+	ErrMatchNotYetSet = regex.ErrMatchNotYetSet
+	ErrInvalidRegex   = regex.ErrInvalidRegex
+)
+
+// RegexFlags is the go-icu-regex flag type, re-exported.
+type RegexFlags = regex.RegexFlags
+
+// Flag values re-exported from go-icu-regex.
+const (
+	RegexFlags_None                     = regex.RegexFlags_None
+	RegexFlags_Case_Insensitive         = regex.RegexFlags_Case_Insensitive
+	RegexFlags_Comments                 = regex.RegexFlags_Comments
+	RegexFlags_Dot_All                  = regex.RegexFlags_Dot_All
+	RegexFlags_Literal                  = regex.RegexFlags_Literal
+	RegexFlags_Multiline                = regex.RegexFlags_Multiline
+	RegexFlags_Unix_Lines               = regex.RegexFlags_Unix_Lines
+	RegexFlags_Unicode_Word             = regex.RegexFlags_Unicode_Word
+	RegexFlags_Error_On_Unknown_Escapes = regex.RegexFlags_Error_On_Unknown_Escapes
+)
+
+// CreateRegex returns a go-icu-regex Regex created with the given string
+// buffer size in bytes.
+func CreateRegex(stringBufferInBytes uint32) Regex {
+	return regex.CreateRegex(stringBufferInBytes)
+}
diff --git a/internal/regex/regex_pure.go b/internal/regex/regex_pure.go
new file mode 100644
index 0000000000..7cba07fbef
--- /dev/null
+++ b/internal/regex/regex_pure.go
@@ -0,0 +1,242 @@
+//go:build !cgo || gms_pure_go
+
+// Package regex provides the regular expression engine behind the SQL REGEXP
+// functions. This file is the pure-Go implementation, built on the standard
+// library regexp package, that is compiled when cgo is disabled or the
+// gms_pure_go build tag is set.
+package regex
+
+import (
+	"context"
+	"regexp"
+
+	"gopkg.in/src-d/go-errors.v1"
+)
+
+// Regex is a reusable matcher: SetRegexString installs the pattern,
+// SetMatchString installs the subject, and the remaining methods query that
+// pair. Positions are 1-based byte offsets, except for the start argument of
+// Matches, which is 0-based.
+type Regex interface {
+	SetRegexString(ctx context.Context, regexStr string, flags RegexFlags) error
+	SetMatchString(ctx context.Context, matchStr string) error
+	IndexOf(ctx context.Context, start int, occurrence int, endIndex bool) (int, error)
+	Matches(ctx context.Context, start int, occurrence int) (bool, error)
+	Replace(ctx context.Context, replacementStr string, position int, occurrence int) (string, error)
+	Substring(ctx context.Context, start int, occurrence int) (string, bool, error)
+	Close() error
+}
+
+// Error kinds returned by the Regex methods.
+var (
+	ErrRegexNotYetSet = errors.NewKind("SetRegexString must be called before any other function")
+	ErrMatchNotYetSet = errors.NewKind("SetMatchString must be called as there is nothing to match against")
+	ErrInvalidRegex   = errors.NewKind("the given regular expression is invalid")
+
+	// errStartOutOfBounds reports a start position past the end of the match string.
+	errStartOutOfBounds = errors.NewKind("start position %d is out of bounds for a match string of %d bytes")
+)
+
+// RegexFlags is a bit set of options passed to SetRegexString.
+type RegexFlags uint32
+
+// Flag bits. Only Case_Insensitive, Multiline and Dot_All are honored by this
+// implementation; the others are accepted and ignored.
+const (
+	RegexFlags_None                     RegexFlags = 0
+	RegexFlags_Case_Insensitive         RegexFlags = 2
+	RegexFlags_Comments                 RegexFlags = 4
+	RegexFlags_Dot_All                  RegexFlags = 32
+	RegexFlags_Literal                  RegexFlags = 16
+	RegexFlags_Multiline                RegexFlags = 8
+	RegexFlags_Unix_Lines               RegexFlags = 1
+	RegexFlags_Unicode_Word             RegexFlags = 256
+	RegexFlags_Error_On_Unknown_Escapes RegexFlags = 512
+)
+
+// CreateRegex returns a new, empty Regex. stringBufferInBytes is not used by
+// this implementation.
+func CreateRegex(stringBufferInBytes uint32) Regex {
+	return &privateRegex{}
+}
+
+// privateRegex is the regexp-backed Regex. It caches the match locations of
+// the most recent search so repeated queries with the same start are cheap.
+type privateRegex struct {
+	re   *regexp.Regexp // compiled pattern; nil until SetRegexString succeeds
+	str  string         // current match string
+	sset bool           // whether str has been set for the current pattern
+
+	done  bool    // whether start and locs describe a completed search
+	start int     // 1-based start position of the cached search
+	locs  [][]int // match locations, relative to str[start-1:]
+}
+
+var _ Regex = (*privateRegex)(nil)
+
+// SetRegexString compiles regexStr and makes it the active pattern, marking
+// any previously set match string as unset. Of the flags, only Case_Insensitive,
+// Multiline and Dot_All are translated (to the i, m and s inline flags);
+// the rest are ignored. It returns ErrInvalidRegex if compilation fails.
+func (pr *privateRegex) SetRegexString(ctx context.Context, regexStr string, flags RegexFlags) (err error) {
+	var flg = "(?"
+	if flags&RegexFlags_Case_Insensitive != 0 {
+		flg += "i"
+	}
+	if flags&RegexFlags_Multiline != 0 {
+		flg += "m"
+	}
+	if flags&RegexFlags_Dot_All != 0 {
+		flg += "s"
+	}
+	if len(flg) > 2 {
+		flg += ")"
+	} else {
+		flg = "" // no flags requested, so emit no "(?)" prefix
+	}
+
+	pr.done = false
+	pr.sset = false
+	pr.re, err = regexp.Compile(flg + regexStr)
+	if err != nil {
+		return ErrInvalidRegex.New()
+	}
+	return nil
+}
+
+// SetMatchString sets the string that later queries search. It returns
+// ErrRegexNotYetSet if no pattern has been compiled.
+func (pr *privateRegex) SetMatchString(ctx context.Context, matchStr string) (err error) {
+	if pr.re == nil {
+		return ErrRegexNotYetSet.New()
+	}
+	pr.done = false
+	pr.str = matchStr
+	pr.sset = true
+	return nil
+}
+
+// do runs the search from the 1-based position start, unless the cached
+// result is already for that position. A start below 1 is treated as 1; a
+// start more than one past the end of the match string is an error rather
+// than an out-of-range slice panic.
+func (pr *privateRegex) do(start int) error {
+	if pr.re == nil {
+		return ErrRegexNotYetSet.New()
+	}
+	if !pr.sset {
+		return ErrMatchNotYetSet.New()
+	}
+	if start < 1 {
+		start = 1
+	}
+	if start > len(pr.str)+1 {
+		return errStartOutOfBounds.New(start, len(pr.str))
+	}
+	if pr.done && pr.start == start {
+		return nil
+	}
+	pr.locs = pr.re.FindAllStringIndex(pr.str[start-1:], -1)
+	pr.start = start
+	pr.done = true
+	return nil
+}
+
+// location returns the [begin, end) pair of the 1-based occurrence within the
+// cached search, or nil if there are fewer matches. Occurrences below 1 select
+// the first match.
+func (pr *privateRegex) location(occurrence int) []int {
+	occurrence--
+	if occurrence < 0 {
+		occurrence = 0
+	}
+	if len(pr.locs) < occurrence+1 {
+		return nil
+	}
+	return pr.locs[occurrence]
+}
+
+// IndexOf returns the 1-based position of the given occurrence when searching
+// from start: the position of its first byte, or, if endIndex is set, the
+// position just past its last byte. It returns 0 when there is no such match.
+func (pr *privateRegex) IndexOf(ctx context.Context, start int, occurrence int, endIndex bool) (int, error) {
+	err := pr.do(start)
+	if err != nil {
+		return 0, err
+	}
+	loc := pr.location(occurrence)
+	if loc == nil {
+		return 0, nil
+	}
+	pos := loc[0]
+	if endIndex {
+		pos = loc[1]
+	}
+	return pos + pr.start, nil
+}
+
+// Matches reports whether the given occurrence exists when searching from the
+// 0-based position start.
+func (pr *privateRegex) Matches(ctx context.Context, start int, occurrence int) (bool, error) {
+	err := pr.do(start + 1) // start+1: issue #10 (https://github.com/dolthub/go-icu-regex/issues/10)
+	if err != nil {
+		return false, err
+	}
+	loc := pr.location(occurrence)
+	return loc != nil, nil
+}
+
+// Replace returns the match string with matches found from start replaced by
+// replacement. An occurrence of 0 replaces every match; any other value
+// replaces only that occurrence. The replacement is inserted verbatim, and the
+// match string is returned unchanged when nothing matches.
+func (pr *privateRegex) Replace(ctx context.Context, replacement string, start int, occurrence int) (string, error) {
+	err := pr.do(start)
+	if err != nil {
+		return "", err
+	}
+
+	var locs [][]int
+	if occurrence == 0 {
+		locs = pr.locs
+	} else {
+		loc := pr.location(occurrence)
+		if loc != nil {
+			locs = [][]int{loc}
+		}
+	}
+
+	// Match offsets are relative to str[start-1:]; offs shifts them back.
+	offs := pr.start - 1
+	pos := offs
+	ret := make([]byte, 0, len(pr.str))
+	ret = append(ret, pr.str[:pos]...)
+	for _, loc := range locs {
+		ret = append(ret, pr.str[pos:loc[0]+offs]...)
+		ret = append(ret, replacement...)
+		pos = loc[1] + offs
+	}
+	ret = append(ret, pr.str[pos:]...)
+	return string(ret), nil
+}
+
+// Substring returns the text of the given occurrence when searching from
+// start, and whether such a match exists.
+func (pr *privateRegex) Substring(ctx context.Context, start int, occurrence int) (string, bool, error) {
+	err := pr.do(start)
+	if err != nil {
+		return "", false, err
+	}
+	loc := pr.location(occurrence)
+	if loc == nil {
+		return "", false, nil
+	}
+	return pr.str[loc[0]+pr.start-1 : loc[1]+pr.start-1], true, nil
+}
+
+// Close releases the compiled pattern and all cached state. The Regex can be
+// reused after another SetRegexString call.
+func (pr *privateRegex) Close() (err error) {
+	*pr = privateRegex{}
+	return nil
+}
diff --git a/internal/regex/regex_test.go b/internal/regex/regex_test.go
new file mode 100644
index 0000000000..ae1072b5a4
--- /dev/null
+++ b/internal/regex/regex_test.go
@@ -0,0 +1,419 @@
+package regex_test
+
+import (
+	"fmt"
+	"testing"
+
+	"github.com/dolthub/go-mysql-server/internal/regex"
+)
+
+func TestMatches(t *testing.T) {
+	tests := map[int]struct {
+		reg   string
+		flag  regex.RegexFlags
+		str   string
+		start int
+		occ   int
+		exp   bool
+	}{
+		0: {
+			reg: `abc+.*this st`,
+			str: "Find the abc in this string",
+			exp: true,
+		},
+		1: {
+			reg: `abc+.*this st`,
+			str: "Find the abc in this here string",
+			exp: false,
+		},
+		2: {
+			reg: `[a-zA-Z0-9]{5} \w{4} aab`,
+			str: "Words like aab don't exist",
+			exp: true,
+		},
+		3: {
+			reg: `^[aA]bcd[eE]$`,
+			str: "abcde",
+			exp: true,
+		},
+		4: {
+			reg: `^[aA]bcd[eE]$`,
+			str: "Abcde",
+			exp: true,
+		},
+		5: {
+			reg: `^[aA]bcd[eE]$`,
+			str: "AbcdE",
+			exp: true,
+		},
+	}
+
+	for _, test := range tests {
+		name := fmt.Sprintf(`%q/%q`, test.reg, test.str)
+
+		t.Run(name, func(t *testing.T) {
+			re := regex.CreateRegex(1024)
+			defer re.Close()
+			if err := re.SetRegexString(t.Context(), test.reg, test.flag); err != nil {
+				t.Error(err)
+			}
+			if err := re.SetMatchString(t.Context(), test.str); err != nil {
+				t.Error(err)
+			}
+
+			r, err := re.Matches(t.Context(), test.start, test.occ)
+			if err != nil {
+				t.Error(err)
+			}
+			if r != test.exp {
+				t.Errorf("Matches = %v, wants %v", r, test.exp)
+			}
+		})
+	}
+}
+
+func TestReplace(t *testing.T) {
+	re := regex.CreateRegex(1024)
+	defer re.Close()
+	if err := re.SetRegexString(t.Context(), `[a-z]+`, regex.RegexFlags_None); err != nil {
+		t.Fatal(err)
+	}
+	if err := re.SetMatchString(t.Context(), "abc def ghi"); err != nil {
+		t.Fatal(err)
+	}
+
+	tests := map[int]struct {
+		pos int
+		occ int
+		exp string
+	}{
+		0: {
+			pos: 1,
+			occ: 2,
+			exp: "abc X ghi",
+		},
+		1: {
+			pos: 1,
+			occ: 3,
+			exp: "abc def X",
+		},
+		2: {
+			pos: 1,
+			occ: 0,
+			exp: "X X X",
+		},
+		4: {
+			pos: 1,
+			occ: 4,
+			exp: "abc def ghi",
+		},
+	}
+	for _, test := range tests {
+		name := fmt.Sprintf("[%d,%d]", test.pos, test.occ)
+
+		t.Run(name, func(t *testing.T) {
+			r, err := re.Replace(t.Context(), "X", test.pos, test.occ)
+			if err != nil {
+				t.Error(err)
+			}
+			if r != test.exp {
+				t.Errorf("Replace = %q, wants %q", r, test.exp)
+			}
+		})
+	}
+}
+
+func TestIndexOf(t *testing.T) {
+	re := regex.CreateRegex(1024)
+	defer re.Close()
+	if err := re.SetRegexString(t.Context(), `[a-j]+`, regex.RegexFlags_None); err != nil {
+		t.Fatal(err)
+	}
+
+	tests := map[int]struct {
+		str   string
+		start int
+		occ   int
+		endi  bool
+		exp   int
+	}{
+		0: {
+			str:   "abc def ghi",
+			start: 1,
+			occ:   1,
+			endi:  false,
+			exp:   1,
+		},
+		1: {
+			str:   "abc def ghi",
+			start: 4,
+			occ:   1,
+			endi:  false,
+			exp:   5,
+		},
+		2: {
+			str:   "abc def ghi",
+			start: 8,
+			occ:   1,
+			endi:  false,
+			exp:   9,
+		},
+		3: {
+			str:   "abc def ghi",
+			start: 1,
+			occ:   3,
+			endi:  false,
+			exp:   9,
+		},
+		4: {
+			str:   "abc def ghi",
+			start: 1,
+			occ:   4,
+			endi:  false,
+			exp:   0,
+		},
+		5: {
+			str:   "abc def ghi",
+			start: 1,
+			occ:   1,
+			endi:  true,
+			exp:   4,
+		},
+		6: {
+			str:   "abc def ghi",
+			start: 4,
+			occ:   1,
+			endi:  true,
+			exp:   8,
+		},
+		7: {
+			str:   "abc def ghi",
+			start: 8,
+			occ:   1,
+			endi:  true,
+			exp:   12,
+		},
+		8: {
+			str:   "abc def ghi",
+			start: 1,
+			occ:   2,
+			endi:  true,
+			exp:   8,
+		},
+		9: {
+			str:   "abc def ghi",
+			start: 1,
+			occ:   3,
+			endi:  true,
+			exp:   12,
+		},
+		10: {
+			str:   "abc def ghi",
+			start: 1,
+			occ:   4,
+			endi:  true,
+			exp:   0,
+		},
+		11: {
+			str:   "klmno fghij abcde",
+			start: 1,
+			occ:   1,
+			endi:  false,
+			exp:   7,
+		},
+		12: {
+			str:   "klmno fghij abcde",
+			start: 1,
+			occ:   1,
+			endi:  true,
+			exp:   12,
+		},
+	}
+
+	for _, test := range tests {
+		name := fmt.Sprintf("%q/[%d,%d,%v]", test.str, test.start, test.occ, test.endi)
+		t.Run(name, func(t *testing.T) {
+			if err := re.SetMatchString(t.Context(), test.str); err != nil {
+				t.Fatal(err)
+			}
+			r, err := re.IndexOf(t.Context(), test.start, test.occ, test.endi)
+			if err != nil {
+				t.Error(err)
+			}
+			if r != test.exp {
+				t.Errorf("IndexOf = %v, wants %v", r, test.exp)
+			}
+		})
+	}
+}
+
+func TestSubstring(t *testing.T) {
+	re := regex.CreateRegex(1024)
+	defer re.Close()
+	if err := re.SetRegexString(t.Context(), `[a-z]+`, regex.RegexFlags_None); err != nil {
+		t.Fatal(err)
+	}
+
+	tests := map[int]struct {
+		str   string
+		start int
+		occ   int
+		expb  bool
+		exps  string
+	}{
+		0: {
+			str:   "abc def ghi",
+			start: 1,
+			occ:   1,
+			expb:  true,
+			exps:  "abc",
+		},
+		1: {
+			str:   "abc def ghi",
+			start: 4,
+			occ:   1,
+			expb:  true,
+			exps:  "def",
+		},
+		2: {
+			str:   "abc def ghi",
+			start: 8,
+			occ:   1,
+			expb:  true,
+			exps:  "ghi",
+		},
+		3: {
+			str:   "abc def ghi",
+			start: 1,
+			occ:   2,
+			expb:  true,
+			exps:  "def",
+		},
+		4: {
+			str:   "abc def ghi",
+			start: 1,
+			occ:   3,
+			expb:  true,
+			exps:  "ghi",
+		},
+		5: {
+			str:   "abc def ghi",
+			start: 1,
+			occ:   4,
+			expb:  false,
+			exps:  "",
+		},
+		6: {
+			str:   "ghx dey abz",
+			start: 1,
+			occ:   1,
+			expb:  true,
+			exps:  "ghx",
+		},
+	}
+
+	for _, test := range tests {
+		name := fmt.Sprintf("%q/[%d,%d]", test.str, test.start, test.occ)
+		t.Run(name, func(t *testing.T) {
+			if err := re.SetMatchString(t.Context(), test.str); err != nil {
+				t.Fatal(err)
+			}
+			rs, rb, err := re.Substring(t.Context(), test.start, test.occ)
+			if err != nil {
+				t.Error(err)
+			}
+			if rs != test.exps || rb != test.expb {
+				t.Errorf("Substring = (%q, %v), wants (%q, %v)", rs, rb, test.exps, test.expb)
+			}
+		})
+	}
+}
+
+func TestCaseSensitivity(t *testing.T) {
+	tests := map[int]struct {
+		reg  string
+		flag regex.RegexFlags
+		str  string
+		exp  bool
+	}{
+		0: {
+			reg:  `abc`,
+			flag: regex.RegexFlags_Case_Insensitive,
+			str:  "ABC",
+			exp:  true,
+		},
+		1: {
+			reg:  `abc`,
+			flag: regex.RegexFlags_None,
+			str:  "ABC",
+			exp:  false,
+		},
+	}
+
+	for _, test := range tests {
+		name := fmt.Sprintf("%q(%v)/%q", test.reg, test.flag, test.str)
+		t.Run(name, func(t *testing.T) {
+			re := regex.CreateRegex(1024)
+			defer re.Close()
+			if err := re.SetRegexString(t.Context(), test.reg, test.flag); err != nil {
+				t.Fatal(err)
+			}
+			if err := re.SetMatchString(t.Context(), test.str); err != nil {
+				t.Fatal(err)
+			}
+			r, err := re.Matches(t.Context(), 0, 0)
+			if err != nil {
+				t.Fatalf("Matches error: %v", err)
+			}
+			if r != test.exp {
+				t.Fatalf("Matches = %v, wants %v", r, test.exp)
+			}
+		})
+	}
+}
+
+func TestReplace2(t *testing.T) {
+	re := regex.CreateRegex(1024)
+	defer re.Close()
+	if err := re.SetRegexString(t.Context(), `[0-4]`, regex.RegexFlags_None); err != nil {
+		t.Fatal(err)
+	}
+	if err := re.SetMatchString(t.Context(), "0123456789"); err != nil {
+		t.Fatal(err)
+	}
+
+	tests := map[int]struct {
+		pos int
+		occ int
+		exp string
+	}{
+		0: {
+			pos: 1,
+			occ: 0,
+			exp: "XXXXX56789",
+		},
+		1: {
+			pos: 2,
+			occ: 0,
+			exp: "0XXXX56789",
+		},
+		2: {
+			pos: 3,
+			occ: 2,
+			exp: "012X456789",
+		},
+	}
+
+	for _, test := range tests {
+		name := fmt.Sprintf("[%d,%d]", test.pos, test.occ)
+		t.Run(name, func(t *testing.T) {
+			r, err := re.Replace(t.Context(), "X", test.pos, test.occ)
+			if err != nil {
+				t.Error(err)
+			}
+			if r != test.exp {
+				t.Errorf("Replace = %q, wants %q", r, test.exp)
+			}
+		})
+	}
+}
diff --git a/sql/expression/function/regexp_instr.go b/sql/expression/function/regexp_instr.go
index 62111eb719..28ab91a114 100644
--- a/sql/expression/function/regexp_instr.go
+++ b/sql/expression/function/regexp_instr.go
@@ -19,8 +19,7 @@ import (
 	"strings"
 	"sync"
 
-	regex "github.com/dolthub/go-icu-regex"
-
+	"github.com/dolthub/go-mysql-server/internal/regex"
 	"github.com/dolthub/go-mysql-server/sql"
 	"github.com/dolthub/go-mysql-server/sql/expression"
 	"github.com/dolthub/go-mysql-server/sql/types"
diff --git a/sql/expression/function/regexp_like.go b/sql/expression/function/regexp_like.go
index b53209b584..176ea33c96 100644
--- a/sql/expression/function/regexp_like.go
+++ b/sql/expression/function/regexp_like.go
@@ -19,9 +19,9 @@ import (
 	"strings"
 	"sync"
 
-	regex "github.com/dolthub/go-icu-regex"
 	"gopkg.in/src-d/go-errors.v1"
 
+	"github.com/dolthub/go-mysql-server/internal/regex"
 	"github.com/dolthub/go-mysql-server/sql"
 	"github.com/dolthub/go-mysql-server/sql/expression"
 	"github.com/dolthub/go-mysql-server/sql/types"
diff --git a/sql/expression/function/regexp_replace.go b/sql/expression/function/regexp_replace.go
index 55b5bdefd0..eb763a4aa1 100644
--- a/sql/expression/function/regexp_replace.go
+++ b/sql/expression/function/regexp_replace.go
@@ -19,9 +19,9 @@ import (
 	"strings"
 	"sync"
 
-	regex "github.com/dolthub/go-icu-regex"
 	"gopkg.in/src-d/go-errors.v1"
 
+	"github.com/dolthub/go-mysql-server/internal/regex"
 	"github.com/dolthub/go-mysql-server/sql"
 	"github.com/dolthub/go-mysql-server/sql/expression"
 	"github.com/dolthub/go-mysql-server/sql/types"
diff --git a/sql/expression/function/regexp_substr.go b/sql/expression/function/regexp_substr.go
index 24ddded1c2..6eee430f47 100644
--- a/sql/expression/function/regexp_substr.go
+++ b/sql/expression/function/regexp_substr.go
@@ -19,8 +19,7 @@ import (
 	"strings"
 	"sync"
 
-	regex "github.com/dolthub/go-icu-regex"
-
+	"github.com/dolthub/go-mysql-server/internal/regex"
 	"github.com/dolthub/go-mysql-server/sql"
 	"github.com/dolthub/go-mysql-server/sql/expression"
 	"github.com/dolthub/go-mysql-server/sql/types"