Skip to content

Commit e128bf9

Browse files
authored
Merge pull request #3299 from dolthub/nicktobey/where_panic
Prevent panic when using table functions in joins and subqueries
2 parents e0601da + 5f7870e commit e128bf9

File tree

7 files changed

+183
-101
lines changed

7 files changed

+183
-101
lines changed

enginetest/engine_only_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,7 @@ func TestTableFunctions(t *testing.T) {
582582
&databaseProvider,
583583
SimpleTableFunction{},
584584
memory.IntSequenceTable{},
585+
memory.LookupSequenceTable{},
585586
memory.PointLookupTable{},
586587
memory.TableFunc{},
587588
memory.ExponentialDistTable{},

enginetest/queries/table_func_scripts.go

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -146,27 +146,31 @@ var TableFunctionScriptTests = []ScriptTest{
146146
Query: "select seq.x from sequence_table('x', 5) seq",
147147
Expected: []sql.Row{{0}, {1}, {2}, {3}, {4}},
148148
},
149+
{
150+
Query: "select x from sequence_table('x', 5) where exists (select y from sequence_table('y', 3) where x = y)",
151+
Expected: []sql.Row{{0}, {1}, {2}},
152+
},
149153
{
150154
Query: "select not_seq.x from sequence_table('x', 5) as seq",
151155
ExpectedErr: sql.ErrTableNotFound,
152156
},
153157
{
154-
Query: "select /*+ MERGE_JOIN(seq1,seq2) JOIN_ORDER(seq2,seq1) */ seq1.x, seq2.y from sequence_table('x', 5) seq1 join sequence_table('y', 5) seq2 on seq1.x = seq2.y",
158+
Query: "select /*+ MERGE_JOIN(seq1,seq2) JOIN_ORDER(seq2,seq1) */ seq1.x, seq2.y from lookup_sequence_table('x', 5) seq1 join lookup_sequence_table('y', 5) seq2 on seq1.x = seq2.y",
155159
Expected: []sql.Row{{0, 0}, {1, 1}, {2, 2}, {3, 3}, {4, 4}},
156160
ExpectedIndexes: []string{"y", "x"},
157161
},
158162
{
159-
Query: "select /*+ LOOKUP_JOIN(seq1,seq2) JOIN_ORDER(seq2,seq1) */ seq1.x, seq2.y from sequence_table('x', 5) seq1 join sequence_table('y', 5) seq2 on seq1.x = seq2.y",
163+
Query: "select /*+ LOOKUP_JOIN(seq1,seq2) JOIN_ORDER(seq2,seq1) */ seq1.x, seq2.y from lookup_sequence_table('x', 5) seq1 join lookup_sequence_table('y', 5) seq2 on seq1.x = seq2.y",
160164
Expected: []sql.Row{{0, 0}, {1, 1}, {2, 2}, {3, 3}, {4, 4}},
161165
ExpectedIndexes: []string{"x"},
162166
},
163167
{
164-
Query: "select /*+ MERGE_JOIN(seq1,seq2) JOIN_ORDER(seq2,seq1) */ * from sequence_table('x', 5) seq1 join sequence_table('y', 5) seq2 on x = 0",
168+
Query: "select /*+ MERGE_JOIN(seq1,seq2) JOIN_ORDER(seq2,seq1) */ * from lookup_sequence_table('x', 5) seq1 join lookup_sequence_table('y', 5) seq2 on x = 0",
165169
Expected: []sql.Row{{0, 0}, {0, 1}, {0, 2}, {0, 3}, {0, 4}},
166170
ExpectedIndexes: []string{"x"},
167171
},
168172
{
169-
Query: "select /*+ LOOKUP_JOIN(seq1,seq2) */ * from sequence_table('x', 5) seq1 join sequence_table('y', 5) seq2 on x = 0",
173+
Query: "select /*+ LOOKUP_JOIN(seq1,seq2) */ * from lookup_sequence_table('x', 5) seq1 join lookup_sequence_table('y', 5) seq2 on x = 0",
170174
Expected: []sql.Row{{0, 0}, {0, 1}, {0, 2}, {0, 3}, {0, 4}},
171175
ExpectedIndexes: []string{"x"},
172176
},
@@ -187,14 +191,14 @@ var TableFunctionScriptTests = []ScriptTest{
187191
Expected: []sql.Row{{0}, {1}, {2}, {3}, {4}},
188192
},
189193
{
190-
Name: "sequence_table allows point lookups",
191-
Query: "select * from sequence_table('x', 5) where x = 2",
194+
Name: "lookup_sequence_table allows point lookups",
195+
Query: "select * from lookup_sequence_table('x', 5) where x = 2",
192196
Expected: []sql.Row{{2}},
193197
ExpectedIndexes: []string{"x"},
194198
},
195199
{
196-
Name: "sequence_table allows range lookups",
197-
Query: "select * from sequence_table('x', 5) where x >= 1 and x <= 3",
200+
Name: "lookup_sequence_table allows range lookups",
201+
Query: "select * from lookup_sequence_table('x', 5) where x >= 1 and x <= 3",
198202
Expected: []sql.Row{{1}, {2}, {3}},
199203
ExpectedIndexes: []string{"x"},
200204
},

memory/lookup_squence_table.go

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
package memory
2+
3+
import (
4+
"fmt"
5+
6+
"github.com/dolthub/go-mysql-server/sql"
7+
"github.com/dolthub/go-mysql-server/sql/expression"
8+
"github.com/dolthub/go-mysql-server/sql/types"
9+
)
10+
11+
var _ sql.TableFunction = LookupSequenceTable{}
12+
var _ sql.CollationCoercible = LookupSequenceTable{}
13+
var _ sql.ExecSourceRel = LookupSequenceTable{}
14+
var _ sql.IndexAddressable = LookupSequenceTable{}
15+
var _ sql.IndexedTable = LookupSequenceTable{}
16+
var _ sql.TableNode = LookupSequenceTable{}
17+
18+
// LookupSequenceTable is a variation of IntSequenceTable that supports lookups and implements sql.TableNode
19+
type LookupSequenceTable struct {
20+
IntSequenceTable
21+
}
22+
23+
func (s LookupSequenceTable) UnderlyingTable() sql.Table {
24+
return s
25+
}
26+
27+
func (s LookupSequenceTable) NewInstance(ctx *sql.Context, db sql.Database, args []sql.Expression) (sql.Node, error) {
28+
newIntSequenceTable, err := s.IntSequenceTable.NewInstance(ctx, db, args)
29+
if err != nil {
30+
return nil, err
31+
}
32+
return LookupSequenceTable{newIntSequenceTable.(IntSequenceTable)}, nil
33+
}
34+
35+
func (s LookupSequenceTable) String() string {
36+
return fmt.Sprintf("sequence(%s, %d)", s.name, s.Len)
37+
}
38+
39+
func (s LookupSequenceTable) DebugString() string {
40+
pr := sql.NewTreePrinter()
41+
_ = pr.WriteNode("sequence")
42+
children := []string{
43+
fmt.Sprintf("name: %s", s.name),
44+
fmt.Sprintf("len: %d", s.Len),
45+
}
46+
_ = pr.WriteChildren(children...)
47+
return pr.String()
48+
}
49+
50+
func (s LookupSequenceTable) Schema() sql.Schema {
51+
schema := []*sql.Column{
52+
{
53+
DatabaseSource: s.db.Name(),
54+
Source: s.Name(),
55+
Name: s.name,
56+
Type: types.Int64,
57+
},
58+
}
59+
60+
return schema
61+
}
62+
63+
func (s LookupSequenceTable) WithChildren(_ ...sql.Node) (sql.Node, error) {
64+
return s, nil
65+
}
66+
67+
func (s LookupSequenceTable) WithExpressions(e ...sql.Expression) (sql.Node, error) {
68+
return s, nil
69+
}
70+
71+
func (s LookupSequenceTable) WithDatabase(_ sql.Database) (sql.Node, error) {
72+
return s, nil
73+
}
74+
75+
func (s LookupSequenceTable) Name() string {
76+
return "lookup_sequence_table"
77+
}
78+
79+
func (s LookupSequenceTable) Description() string {
80+
return "a integer sequence that supports lookup operations"
81+
}
82+
83+
// Partitions is a sql.Table interface function that returns a partition of the data. This data has a single partition.
84+
func (s LookupSequenceTable) Partitions(ctx *sql.Context) (sql.PartitionIter, error) {
85+
return sql.PartitionsToPartitionIter(&sequencePartition{min: 0, max: int64(s.Len) - 1}), nil
86+
}
87+
88+
// PartitionRows is a sql.Table interface function that takes a partition and returns all rows in that partition.
89+
// This table has a partition for just schema changes, one for just data changes, and one for both.
90+
func (s LookupSequenceTable) PartitionRows(ctx *sql.Context, partition sql.Partition) (sql.RowIter, error) {
91+
sp, ok := partition.(*sequencePartition)
92+
if !ok {
93+
return &SequenceTableFnRowIter{i: 0, n: s.Len}, nil
94+
}
95+
min := int64(0)
96+
if sp.min > min {
97+
min = sp.min
98+
}
99+
max := int64(s.Len) - 1
100+
if sp.max < max {
101+
max = sp.max
102+
}
103+
104+
return &SequenceTableFnRowIter{i: min, n: max + 1}, nil
105+
}
106+
107+
// LookupPartitions is a sql.IndexedTable interface function that takes an index lookup and returns the set of corresponding partitions.
108+
func (s LookupSequenceTable) LookupPartitions(ctx *sql.Context, lookup sql.IndexLookup) (sql.PartitionIter, error) {
109+
lowerBound := lookup.Ranges.(sql.MySQLRangeCollection)[0][0].LowerBound
110+
below, ok := lowerBound.(sql.Below)
111+
if !ok {
112+
return s.Partitions(ctx)
113+
}
114+
upperBound := lookup.Ranges.(sql.MySQLRangeCollection)[0][0].UpperBound
115+
above, ok := upperBound.(sql.Above)
116+
if !ok {
117+
return s.Partitions(ctx)
118+
}
119+
min, _, err := s.Schema()[0].Type.Convert(ctx, below.Key)
120+
if err != nil {
121+
return nil, err
122+
}
123+
max, _, err := s.Schema()[0].Type.Convert(ctx, above.Key)
124+
if err != nil {
125+
return nil, err
126+
}
127+
return sql.PartitionsToPartitionIter(&sequencePartition{min: min.(int64), max: max.(int64)}), nil
128+
}
129+
130+
func (s LookupSequenceTable) IndexedAccess(ctx *sql.Context, lookup sql.IndexLookup) sql.IndexedTable {
131+
return s
132+
}
133+
134+
func (s LookupSequenceTable) PreciseMatch() bool {
135+
return true
136+
}
137+
138+
func (s LookupSequenceTable) GetIndexes(ctx *sql.Context) ([]sql.Index, error) {
139+
return []sql.Index{
140+
&Index{
141+
DB: s.db.Name(),
142+
DriverName: "",
143+
Tbl: nil,
144+
TableName: s.Name(),
145+
Exprs: []sql.Expression{
146+
expression.NewGetFieldWithTable(0, 0, types.Int64, s.db.Name(), s.Name(), s.name, false),
147+
},
148+
Name: s.name,
149+
Unique: true,
150+
Spatial: false,
151+
Fulltext: false,
152+
CommentStr: "",
153+
PrefixLens: nil,
154+
fulltextInfo: fulltextInfo{},
155+
},
156+
}, nil
157+
}

memory/point_lookup_table.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,16 @@ var _ sql.TableNode = PointLookupTable{}
1818
// PointLookupTable is a table whose indexes only support point lookups but not range scans.
1919
// It's used for testing optimizations on indexes.
2020
type PointLookupTable struct {
21-
IntSequenceTable
21+
LookupSequenceTable
2222
}
2323

2424
func (s PointLookupTable) UnderlyingTable() sql.Table {
2525
return s
2626
}
2727

2828
func (s PointLookupTable) NewInstance(ctx *sql.Context, db sql.Database, args []sql.Expression) (sql.Node, error) {
29-
node, err := s.IntSequenceTable.NewInstance(ctx, db, args)
30-
return PointLookupTable{node.(IntSequenceTable)}, err
29+
node, err := s.LookupSequenceTable.NewInstance(ctx, db, args)
30+
return PointLookupTable{node.(LookupSequenceTable)}, err
3131
}
3232

3333
func (s PointLookupTable) String() string {

memory/required_lookup_table.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ var _ sql.IndexRequired = RequiredLookupTable{}
1818

1919
// RequiredLookupTable is a table that will error if not executed as an index lookup
2020
type RequiredLookupTable struct {
21-
IntSequenceTable
21+
LookupSequenceTable
2222
indexOk bool
2323
}
2424

@@ -31,8 +31,8 @@ func (s RequiredLookupTable) UnderlyingTable() sql.Table {
3131
}
3232

3333
func (s RequiredLookupTable) NewInstance(ctx *sql.Context, db sql.Database, args []sql.Expression) (sql.Node, error) {
34-
node, err := s.IntSequenceTable.NewInstance(ctx, db, args)
35-
return RequiredLookupTable{IntSequenceTable: node.(IntSequenceTable)}, err
34+
node, err := s.LookupSequenceTable.NewInstance(ctx, db, args)
35+
return RequiredLookupTable{LookupSequenceTable: node.(LookupSequenceTable)}, err
3636
}
3737

3838
func (s RequiredLookupTable) String() string {
@@ -74,7 +74,7 @@ func (s RequiredLookupTable) Database() sql.Database {
7474
}
7575

7676
func (s RequiredLookupTable) IndexedAccess(ctx *sql.Context, lookup sql.IndexLookup) sql.IndexedTable {
77-
return RequiredLookupTable{indexOk: true, IntSequenceTable: s.IntSequenceTable}
77+
return RequiredLookupTable{indexOk: true, LookupSequenceTable: s.LookupSequenceTable}
7878
}
7979

8080
func (s RequiredLookupTable) RowIter(_ *sql.Context, _ sql.Row) (sql.RowIter, error) {
@@ -90,14 +90,14 @@ func (s RequiredLookupTable) Partitions(ctx *sql.Context) (sql.PartitionIter, er
9090
if !s.indexOk {
9191
return nil, fmt.Errorf("table requires index lookup")
9292
}
93-
return s.IntSequenceTable.Partitions(ctx)
93+
return s.LookupSequenceTable.Partitions(ctx)
9494
}
9595

9696
func (s RequiredLookupTable) PartitionRows(ctx *sql.Context, partition sql.Partition) (sql.RowIter, error) {
9797
if !s.indexOk {
9898
return nil, fmt.Errorf("table requires index lookup")
9999
}
100-
return s.IntSequenceTable.PartitionRows(ctx, partition)
100+
return s.LookupSequenceTable.PartitionRows(ctx, partition)
101101
}
102102

103103
func (s RequiredLookupTable) GetIndexes(ctx *sql.Context) (indexes []sql.Index, err error) {

memory/sequence_table.go

Lines changed: 0 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,6 @@ import (
1313
var _ sql.TableFunction = IntSequenceTable{}
1414
var _ sql.CollationCoercible = IntSequenceTable{}
1515
var _ sql.ExecSourceRel = IntSequenceTable{}
16-
var _ sql.IndexAddressable = IntSequenceTable{}
17-
var _ sql.IndexedTable = IntSequenceTable{}
18-
var _ sql.TableNode = IntSequenceTable{}
1916

2017
// IntSequenceTable a simple table function that returns a sequence
2118
// of integers.
@@ -25,10 +22,6 @@ type IntSequenceTable struct {
2522
Len int64
2623
}
2724

28-
func (s IntSequenceTable) UnderlyingTable() sql.Table {
29-
return s
30-
}
31-
3225
func (s IntSequenceTable) NewInstance(ctx *sql.Context, db sql.Database, args []sql.Expression) (sql.Node, error) {
3326
if len(args) != 2 {
3427
return nil, fmt.Errorf("sequence table expects 2 arguments: (name, len)")
@@ -164,79 +157,3 @@ type sequencePartition struct {
164157
func (s sequencePartition) Key() []byte {
165158
return binary.LittleEndian.AppendUint64(binary.LittleEndian.AppendUint64(nil, uint64(s.min)), uint64(s.max))
166159
}
167-
168-
// Partitions is a sql.Table interface function that returns a partition of the data. This data has a single partition.
169-
func (s IntSequenceTable) Partitions(ctx *sql.Context) (sql.PartitionIter, error) {
170-
return sql.PartitionsToPartitionIter(&sequencePartition{min: 0, max: int64(s.Len) - 1}), nil
171-
}
172-
173-
// PartitionRows is a sql.Table interface function that takes a partition and returns all rows in that partition.
174-
// This table has a partition for just schema changes, one for just data changes, and one for both.
175-
func (s IntSequenceTable) PartitionRows(ctx *sql.Context, partition sql.Partition) (sql.RowIter, error) {
176-
sp, ok := partition.(*sequencePartition)
177-
if !ok {
178-
return &SequenceTableFnRowIter{i: 0, n: s.Len}, nil
179-
}
180-
min := int64(0)
181-
if sp.min > min {
182-
min = sp.min
183-
}
184-
max := int64(s.Len) - 1
185-
if sp.max < max {
186-
max = sp.max
187-
}
188-
189-
return &SequenceTableFnRowIter{i: min, n: max + 1}, nil
190-
}
191-
192-
// LookupPartitions is a sql.IndexedTable interface function that takes an index lookup and returns the set of corresponding partitions.
193-
func (s IntSequenceTable) LookupPartitions(ctx *sql.Context, lookup sql.IndexLookup) (sql.PartitionIter, error) {
194-
lowerBound := lookup.Ranges.(sql.MySQLRangeCollection)[0][0].LowerBound
195-
below, ok := lowerBound.(sql.Below)
196-
if !ok {
197-
return s.Partitions(ctx)
198-
}
199-
upperBound := lookup.Ranges.(sql.MySQLRangeCollection)[0][0].UpperBound
200-
above, ok := upperBound.(sql.Above)
201-
if !ok {
202-
return s.Partitions(ctx)
203-
}
204-
min, _, err := s.Schema()[0].Type.Convert(ctx, below.Key)
205-
if err != nil {
206-
return nil, err
207-
}
208-
max, _, err := s.Schema()[0].Type.Convert(ctx, above.Key)
209-
if err != nil {
210-
return nil, err
211-
}
212-
return sql.PartitionsToPartitionIter(&sequencePartition{min: min.(int64), max: max.(int64)}), nil
213-
}
214-
215-
func (s IntSequenceTable) IndexedAccess(ctx *sql.Context, lookup sql.IndexLookup) sql.IndexedTable {
216-
return s
217-
}
218-
219-
func (s IntSequenceTable) PreciseMatch() bool {
220-
return true
221-
}
222-
223-
func (s IntSequenceTable) GetIndexes(ctx *sql.Context) ([]sql.Index, error) {
224-
return []sql.Index{
225-
&Index{
226-
DB: s.db.Name(),
227-
DriverName: "",
228-
Tbl: nil,
229-
TableName: s.Name(),
230-
Exprs: []sql.Expression{
231-
expression.NewGetFieldWithTable(0, 0, types.Int64, s.db.Name(), s.Name(), s.name, false),
232-
},
233-
Name: s.name,
234-
Unique: true,
235-
Spatial: false,
236-
Fulltext: false,
237-
CommentStr: "",
238-
PrefixLens: nil,
239-
fulltextInfo: fulltextInfo{},
240-
},
241-
}, nil
242-
}

sql/analyzer/indexed_joins.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -687,7 +687,10 @@ func addRightSemiJoins(ctx *sql.Context, m *memo.Memo) error {
687687
switch n := leftTab.(type) {
688688
case *plan.TableAlias:
689689
aliasName = n.Name()
690-
leftRt = n.Child.(sql.TableNode)
690+
leftRt, ok = n.Child.(sql.TableNode)
691+
if !ok {
692+
return nil
693+
}
691694
case sql.TableNode:
692695
leftRt = n
693696
}

0 commit comments

Comments
 (0)