Skip to content

Commit 33ada9e

Browse files
authored
Merge pull request #1715 from dolthub/james/find
implement `find_in_set`
2 parents 487575e + 9598a36 commit 33ada9e

File tree

6 files changed

+357
-1
lines changed

6 files changed

+357
-1
lines changed

enginetest/enginetests.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1126,7 +1126,6 @@ func TestConvertPrepared(t *testing.T, harness Harness) {
11261126
TestPreparedQuery(t, harness, query, []sql.Row{{tt.ExpCnt}}, nil)
11271127
})
11281128
}
1129-
11301129
}
11311130

11321131
func TestScripts(t *testing.T, harness Harness) {

enginetest/queries/queries.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7500,6 +7500,28 @@ SELECT * FROM my_cte;`,
75007500
{4},
75017501
},
75027502
},
7503+
{
7504+
Query: "select find_in_set('second row', s) from mytable;",
7505+
Expected: []sql.Row{
7506+
{0},
7507+
{1},
7508+
{0},
7509+
},
7510+
},
7511+
{
7512+
Query: "select find_in_set(s, 'first row,second row,third row') from mytable;",
7513+
Expected: []sql.Row{
7514+
{1},
7515+
{2},
7516+
{3},
7517+
},
7518+
},
7519+
{
7520+
Query: "select i from mytable where find_in_set(s, 'first row,second row,third row') = 2;",
7521+
Expected: []sql.Row{
7522+
{2},
7523+
},
7524+
},
75037525
}
75047526

75057527
var KeylessQueries = []QueryTest{

enginetest/queries/script_queries.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2902,6 +2902,70 @@ var ScriptTests = []ScriptTest{
29022902
},
29032903
},
29042904
},
2905+
{
2906+
Name: "find_in_set tests",
2907+
SetUpScript: []string{
2908+
"create table set_tbl (i int primary key, s set('a','b','c'));",
2909+
"insert into set_tbl values (0, '');",
2910+
"insert into set_tbl values (1, 'a');",
2911+
"insert into set_tbl values (2, 'b');",
2912+
"insert into set_tbl values (3, 'c');",
2913+
"insert into set_tbl values (4, 'a,b');",
2914+
"insert into set_tbl values (6, 'b,c');",
2915+
"insert into set_tbl values (7, 'a,c');",
2916+
"insert into set_tbl values (8, 'a,b,c');",
2917+
2918+
"create table collate_tbl (i int primary key, s varchar(10) collate utf8mb4_0900_ai_ci);",
2919+
"insert into collate_tbl values (0, '');",
2920+
"insert into collate_tbl values (1, 'a');",
2921+
"insert into collate_tbl values (2, 'b');",
2922+
"insert into collate_tbl values (3, 'c');",
2923+
"insert into collate_tbl values (4, 'a,b');",
2924+
"insert into collate_tbl values (6, 'b,c');",
2925+
"insert into collate_tbl values (7, 'a,c');",
2926+
"insert into collate_tbl values (8, 'a,b,c');",
2927+
2928+
"create table enum_tbl (i int primary key, s enum('a','b','c'));",
2929+
"insert into enum_tbl values (0, 'a'), (1, 'b'), (2, 'c');",
2930+
"select i, s, find_in_set('a', s) from enum_tbl;",
2931+
},
2932+
Assertions: []ScriptTestAssertion{
2933+
{
2934+
Query: "select i, find_in_set('a', s) from set_tbl;",
2935+
Expected: []sql.Row{
2936+
{0, 0},
2937+
{1, 1},
2938+
{2, 0},
2939+
{3, 0},
2940+
{4, 1},
2941+
{6, 0},
2942+
{7, 1},
2943+
{8, 1},
2944+
},
2945+
},
2946+
{
2947+
Query: "select i, find_in_set('A', s) from collate_tbl;",
2948+
Expected: []sql.Row{
2949+
{0, 0},
2950+
{1, 1},
2951+
{2, 0},
2952+
{3, 0},
2953+
{4, 1},
2954+
{6, 0},
2955+
{7, 1},
2956+
{8, 1},
2957+
},
2958+
},
2959+
{
2960+
Query: "select i, find_in_set('a', s) from enum_tbl;",
2961+
Expected: []sql.Row{
2962+
{0, 1},
2963+
{1, 0},
2964+
{2, 0},
2965+
},
2966+
},
2967+
},
2968+
},
29052969
}
29062970

29072971
var SpatialScriptTests = []ScriptTest{
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
// Copyright 2023 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package function
16+
17+
import (
18+
"fmt"
19+
"strings"
20+
21+
"github.com/dolthub/go-mysql-server/sql"
22+
"github.com/dolthub/go-mysql-server/sql/expression"
23+
"github.com/dolthub/go-mysql-server/sql/types"
24+
)
25+
26+
// FindInSet takes out the specified unit(s) from the time expression.
27+
type FindInSet struct {
28+
expression.BinaryExpression
29+
}
30+
31+
var _ sql.FunctionExpression = (*FindInSet)(nil)
32+
var _ sql.CollationCoercible = (*FindInSet)(nil)
33+
34+
// NewFindInSet creates a new FindInSet expression.
35+
func NewFindInSet(e1, e2 sql.Expression) sql.Expression {
36+
return &FindInSet{
37+
expression.BinaryExpression{
38+
Left: e1,
39+
Right: e2,
40+
},
41+
}
42+
}
43+
44+
// FunctionName implements sql.FunctionExpression
45+
func (f *FindInSet) FunctionName() string {
46+
return "find_in_set"
47+
}
48+
49+
// Description implements sql.FunctionExpression
50+
func (f *FindInSet) Description() string {
51+
return "returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings"
52+
}
53+
54+
// Type implements the Expression interface.
55+
func (f *FindInSet) Type() sql.Type { return types.Int64 }
56+
57+
// CollationCoercibility implements the interface sql.CollationCoercible.
58+
func (*FindInSet) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
59+
return ctx.GetCollation(), 5
60+
}
61+
62+
func (f *FindInSet) String() string {
63+
return fmt.Sprintf("%s(%s from %s)", f.FunctionName(), f.Left, f.Right)
64+
}
65+
66+
// WithChildren implements the Expression interface.
67+
func (f *FindInSet) WithChildren(children ...sql.Expression) (sql.Expression, error) {
68+
if len(children) != 2 {
69+
return nil, sql.ErrInvalidChildrenNumber.New(f, len(children), 2)
70+
}
71+
return NewFindInSet(children[0], children[1]), nil
72+
}
73+
74+
// Eval implements the Expression interface.
75+
func (f *FindInSet) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
76+
if f.Left == nil || f.Right == nil {
77+
return nil, nil
78+
}
79+
80+
left, err := f.Left.Eval(ctx, row)
81+
if err != nil {
82+
return nil, err
83+
}
84+
85+
right, err := f.Right.Eval(ctx, row)
86+
if err != nil {
87+
return nil, err
88+
}
89+
90+
if left == nil || right == nil {
91+
return nil, nil
92+
}
93+
94+
lVal, _, err := types.LongText.Convert(left)
95+
if err != nil {
96+
return nil, err
97+
}
98+
l := lVal.(string)
99+
100+
// always returns 0 when left contains a comma
101+
if strings.Contains(l, ",") {
102+
return 0, nil
103+
}
104+
105+
var r string
106+
rType := f.Right.Type()
107+
if setType, ok := rType.(types.SetType); ok {
108+
// TODO: set type should take advantage of bit arithmetic
109+
r, err = setType.BitsToString(right.(uint64))
110+
if err != nil {
111+
return nil, err
112+
}
113+
} else if enumType, ok := rType.(types.EnumType); ok {
114+
r, ok = enumType.At(int(right.(uint16)))
115+
if !ok {
116+
return nil, fmt.Errorf("enum missing index %v", r)
117+
}
118+
} else {
119+
var rVal interface{}
120+
rVal, _, err = types.LongText.Convert(right)
121+
if err != nil {
122+
return nil, err
123+
}
124+
r = rVal.(string)
125+
}
126+
127+
leftColl, leftCoer := sql.GetCoercibility(ctx, f.Left)
128+
rightColl, rightCoer := sql.GetCoercibility(ctx, f.Right)
129+
collPref, _ := sql.ResolveCoercibility(leftColl, leftCoer, rightColl, rightCoer)
130+
131+
strType := types.CreateLongText(collPref)
132+
for i, r := range strings.Split(r, ",") {
133+
cmp, err := strType.Compare(l, r)
134+
if err != nil {
135+
return nil, err
136+
}
137+
if cmp == 0 {
138+
return i + 1, nil
139+
}
140+
}
141+
142+
return 0, nil
143+
}
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
// Copyright 2023 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package function
16+
17+
import (
18+
"testing"
19+
20+
"github.com/stretchr/testify/require"
21+
22+
"github.com/dolthub/go-mysql-server/sql"
23+
"github.com/dolthub/go-mysql-server/sql/expression"
24+
"github.com/dolthub/go-mysql-server/sql/types"
25+
)
26+
27+
func TestFindInSet(t *testing.T) {
28+
testCases := []struct {
29+
name string
30+
left string
31+
right string
32+
expected int
33+
skip bool
34+
}{
35+
{
36+
name: "string exists",
37+
left: "b",
38+
right: "a,b,c",
39+
expected: 2,
40+
},
41+
{
42+
name: "string does not exist",
43+
left: "abc",
44+
right: "a,b,c",
45+
expected: 0,
46+
},
47+
{
48+
name: "whitespace not removed",
49+
left: " b ",
50+
right: "a,b,c",
51+
expected: 0,
52+
},
53+
{
54+
name: "whitespace not removed 2",
55+
left: "b",
56+
right: " a , b , c ",
57+
expected: 0,
58+
},
59+
{
60+
name: "whitespace not removed 3",
61+
left: " a b ",
62+
right: "a, a b ,c",
63+
expected: 2,
64+
},
65+
{
66+
name: "comma bad",
67+
left: "b,",
68+
right: "a,b,c",
69+
expected: 0,
70+
},
71+
{
72+
name: "special characters ok",
73+
74+
75+
expected: 3,
76+
},
77+
{
78+
name: "look for empty string",
79+
left: "",
80+
right: "a,",
81+
expected: 2,
82+
},
83+
{
84+
name: "look in empty string",
85+
left: "a",
86+
right: "",
87+
expected: 0,
88+
},
89+
}
90+
91+
for _, tt := range testCases {
92+
t.Run(tt.name, func(t *testing.T) {
93+
if tt.skip {
94+
t.Skip()
95+
}
96+
require := require.New(t)
97+
f := NewFindInSet(expression.NewLiteral(tt.left, types.LongText), expression.NewLiteral(tt.right, types.LongText))
98+
v, err := f.Eval(sql.NewEmptyContext(), nil)
99+
require.NoError(err)
100+
require.Equal(tt.expected, v)
101+
})
102+
}
103+
104+
t.Run("test find in null set", func(t *testing.T) {
105+
require := require.New(t)
106+
f := NewFindInSet(expression.NewLiteral("a", types.LongText), expression.NewLiteral(nil, types.Null))
107+
v, err := f.Eval(sql.NewEmptyContext(), nil)
108+
require.NoError(err)
109+
require.Equal(nil, v)
110+
})
111+
112+
t.Run("find null in set", func(t *testing.T) {
113+
require := require.New(t)
114+
f := NewFindInSet(expression.NewLiteral("a", types.LongText), expression.NewLiteral(nil, types.Null))
115+
v, err := f.Eval(sql.NewEmptyContext(), nil)
116+
require.NoError(err)
117+
require.Equal(nil, v)
118+
})
119+
120+
t.Run("find number", func(t *testing.T) {
121+
require := require.New(t)
122+
f := NewFindInSet(expression.NewLiteral(500, types.Int64), expression.NewLiteral("1,2,3,500", types.Null))
123+
v, err := f.Eval(sql.NewEmptyContext(), nil)
124+
require.NoError(err)
125+
require.Equal(4, v)
126+
})
127+
}

sql/expression/function/registry.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ var BuiltIns = []sql.Function{
8181
sql.Function1{Name: "dayofyear", Fn: NewDayOfYear},
8282
sql.Function1{Name: "degrees", Fn: NewDegrees},
8383
sql.Function2{Name: "extract", Fn: NewExtract},
84+
sql.Function2{Name: "find_in_set", Fn: NewFindInSet},
8485
sql.Function1{Name: "first", Fn: func(e sql.Expression) sql.Expression { return aggregation.NewFirst(e) }},
8586
sql.Function1{Name: "floor", Fn: NewFloor},
8687
sql.Function0{Name: "found_rows", Fn: NewFoundRows},

0 commit comments

Comments
 (0)