diff --git a/enginetest/engine_only_test.go b/enginetest/engine_only_test.go index ca4f005784..86aa287be1 100644 --- a/enginetest/engine_only_test.go +++ b/enginetest/engine_only_test.go @@ -582,6 +582,7 @@ func TestTableFunctions(t *testing.T) { &databaseProvider, SimpleTableFunction{}, memory.IntSequenceTable{}, + memory.LookupSequenceTable{}, memory.PointLookupTable{}, memory.TableFunc{}, memory.ExponentialDistTable{}, diff --git a/enginetest/queries/table_func_scripts.go b/enginetest/queries/table_func_scripts.go index ebf307fc94..82793eff31 100644 --- a/enginetest/queries/table_func_scripts.go +++ b/enginetest/queries/table_func_scripts.go @@ -146,27 +146,31 @@ var TableFunctionScriptTests = []ScriptTest{ Query: "select seq.x from sequence_table('x', 5) seq", Expected: []sql.Row{{0}, {1}, {2}, {3}, {4}}, }, + { + Query: "select x from sequence_table('x', 5) where exists (select y from sequence_table('y', 3) where x = y)", + Expected: []sql.Row{{0}, {1}, {2}}, + }, { Query: "select not_seq.x from sequence_table('x', 5) as seq", ExpectedErr: sql.ErrTableNotFound, }, { - Query: "select /*+ MERGE_JOIN(seq1,seq2) JOIN_ORDER(seq2,seq1) */ seq1.x, seq2.y from sequence_table('x', 5) seq1 join sequence_table('y', 5) seq2 on seq1.x = seq2.y", + Query: "select /*+ MERGE_JOIN(seq1,seq2) JOIN_ORDER(seq2,seq1) */ seq1.x, seq2.y from lookup_sequence_table('x', 5) seq1 join lookup_sequence_table('y', 5) seq2 on seq1.x = seq2.y", Expected: []sql.Row{{0, 0}, {1, 1}, {2, 2}, {3, 3}, {4, 4}}, ExpectedIndexes: []string{"y", "x"}, }, { - Query: "select /*+ LOOKUP_JOIN(seq1,seq2) JOIN_ORDER(seq2,seq1) */ seq1.x, seq2.y from sequence_table('x', 5) seq1 join sequence_table('y', 5) seq2 on seq1.x = seq2.y", + Query: "select /*+ LOOKUP_JOIN(seq1,seq2) JOIN_ORDER(seq2,seq1) */ seq1.x, seq2.y from lookup_sequence_table('x', 5) seq1 join lookup_sequence_table('y', 5) seq2 on seq1.x = seq2.y", Expected: []sql.Row{{0, 0}, {1, 1}, {2, 2}, {3, 3}, {4, 4}}, ExpectedIndexes: []string{"x"}, }, { - Query: "select /*+ MERGE_JOIN(seq1,seq2) JOIN_ORDER(seq2,seq1) */ * from sequence_table('x', 5) seq1 join sequence_table('y', 5) seq2 on x = 0", + Query: "select /*+ MERGE_JOIN(seq1,seq2) JOIN_ORDER(seq2,seq1) */ * from lookup_sequence_table('x', 5) seq1 join lookup_sequence_table('y', 5) seq2 on x = 0", Expected: []sql.Row{{0, 0}, {0, 1}, {0, 2}, {0, 3}, {0, 4}}, ExpectedIndexes: []string{"x"}, }, { - Query: "select /*+ LOOKUP_JOIN(seq1,seq2) */ * from sequence_table('x', 5) seq1 join sequence_table('y', 5) seq2 on x = 0", + Query: "select /*+ LOOKUP_JOIN(seq1,seq2) */ * from lookup_sequence_table('x', 5) seq1 join lookup_sequence_table('y', 5) seq2 on x = 0", Expected: []sql.Row{{0, 0}, {0, 1}, {0, 2}, {0, 3}, {0, 4}}, ExpectedIndexes: []string{"x"}, }, @@ -187,14 +191,14 @@ var TableFunctionScriptTests = []ScriptTest{ Expected: []sql.Row{{0}, {1}, {2}, {3}, {4}}, }, { - Name: "sequence_table allows point lookups", - Query: "select * from sequence_table('x', 5) where x = 2", + Name: "lookup_sequence_table allows point lookups", + Query: "select * from lookup_sequence_table('x', 5) where x = 2", Expected: []sql.Row{{2}}, ExpectedIndexes: []string{"x"}, }, { - Name: "sequence_table allows range lookups", - Query: "select * from sequence_table('x', 5) where x >= 1 and x <= 3", + Name: "lookup_sequence_table allows range lookups", + Query: "select * from lookup_sequence_table('x', 5) where x >= 1 and x <= 3", Expected: []sql.Row{{1}, {2}, {3}}, ExpectedIndexes: []string{"x"}, }, diff --git a/memory/lookup_squence_table.go b/memory/lookup_squence_table.go new file mode 100644 index 0000000000..ce482662c5 --- /dev/null +++ b/memory/lookup_squence_table.go @@ -0,0 +1,157 @@ +package memory + +import ( + "fmt" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" + "github.com/dolthub/go-mysql-server/sql/types" +) + +var _ sql.TableFunction = LookupSequenceTable{} +var _ sql.CollationCoercible = LookupSequenceTable{} +var _ sql.ExecSourceRel = LookupSequenceTable{} +var _ sql.IndexAddressable = LookupSequenceTable{} +var _ sql.IndexedTable = LookupSequenceTable{} +var _ sql.TableNode = LookupSequenceTable{} + +// LookupSequenceTable is a variation of IntSequenceTable that supports lookups and implements sql.TableNode +type LookupSequenceTable struct { + IntSequenceTable +} + +func (s LookupSequenceTable) UnderlyingTable() sql.Table { + return s +} + +func (s LookupSequenceTable) NewInstance(ctx *sql.Context, db sql.Database, args []sql.Expression) (sql.Node, error) { + newIntSequenceTable, err := s.IntSequenceTable.NewInstance(ctx, db, args) + if err != nil { + return nil, err + } + return LookupSequenceTable{newIntSequenceTable.(IntSequenceTable)}, nil +} + +func (s LookupSequenceTable) String() string { + return fmt.Sprintf("sequence(%s, %d)", s.name, s.Len) +} + +func (s LookupSequenceTable) DebugString() string { + pr := sql.NewTreePrinter() + _ = pr.WriteNode("sequence") + children := []string{ + fmt.Sprintf("name: %s", s.name), + fmt.Sprintf("len: %d", s.Len), + } + _ = pr.WriteChildren(children...) + return pr.String() +} + +func (s LookupSequenceTable) Schema() sql.Schema { + schema := []*sql.Column{ + { + DatabaseSource: s.db.Name(), + Source: s.Name(), + Name: s.name, + Type: types.Int64, + }, + } + + return schema +} + +func (s LookupSequenceTable) WithChildren(_ ...sql.Node) (sql.Node, error) { + return s, nil +} + +func (s LookupSequenceTable) WithExpressions(e ...sql.Expression) (sql.Node, error) { + return s, nil +} + +func (s LookupSequenceTable) WithDatabase(_ sql.Database) (sql.Node, error) { + return s, nil +} + +func (s LookupSequenceTable) Name() string { + return "lookup_sequence_table" +} + +func (s LookupSequenceTable) Description() string { + return "a integer sequence that supports lookup operations" +} + +// Partitions is a sql.Table interface function that returns a partition of the data. This data has a single partition. +func (s LookupSequenceTable) Partitions(ctx *sql.Context) (sql.PartitionIter, error) { + return sql.PartitionsToPartitionIter(&sequencePartition{min: 0, max: int64(s.Len) - 1}), nil +} + +// PartitionRows is a sql.Table interface function that takes a partition and returns all rows in that partition. +// This table has a partition for just schema changes, one for just data changes, and one for both. +func (s LookupSequenceTable) PartitionRows(ctx *sql.Context, partition sql.Partition) (sql.RowIter, error) { + sp, ok := partition.(*sequencePartition) + if !ok { + return &SequenceTableFnRowIter{i: 0, n: s.Len}, nil + } + min := int64(0) + if sp.min > min { + min = sp.min + } + max := int64(s.Len) - 1 + if sp.max < max { + max = sp.max + } + + return &SequenceTableFnRowIter{i: min, n: max + 1}, nil +} + +// LookupPartitions is a sql.IndexedTable interface function that takes an index lookup and returns the set of corresponding partitions. +func (s LookupSequenceTable) LookupPartitions(ctx *sql.Context, lookup sql.IndexLookup) (sql.PartitionIter, error) { + lowerBound := lookup.Ranges.(sql.MySQLRangeCollection)[0][0].LowerBound + below, ok := lowerBound.(sql.Below) + if !ok { + return s.Partitions(ctx) + } + upperBound := lookup.Ranges.(sql.MySQLRangeCollection)[0][0].UpperBound + above, ok := upperBound.(sql.Above) + if !ok { + return s.Partitions(ctx) + } + min, _, err := s.Schema()[0].Type.Convert(ctx, below.Key) + if err != nil { + return nil, err + } + max, _, err := s.Schema()[0].Type.Convert(ctx, above.Key) + if err != nil { + return nil, err + } + return sql.PartitionsToPartitionIter(&sequencePartition{min: min.(int64), max: max.(int64)}), nil +} + +func (s LookupSequenceTable) IndexedAccess(ctx *sql.Context, lookup sql.IndexLookup) sql.IndexedTable { + return s +} + +func (s LookupSequenceTable) PreciseMatch() bool { + return true +} + +func (s LookupSequenceTable) GetIndexes(ctx *sql.Context) ([]sql.Index, error) { + return []sql.Index{ + &Index{ + DB: s.db.Name(), + DriverName: "", + Tbl: nil, + TableName: s.Name(), + Exprs: []sql.Expression{ + expression.NewGetFieldWithTable(0, 0, types.Int64, s.db.Name(), s.Name(), s.name, false), + }, + Name: s.name, + Unique: true, + Spatial: false, + Fulltext: false, + CommentStr: "", + PrefixLens: nil, + fulltextInfo: fulltextInfo{}, + }, + }, nil +} diff --git a/memory/point_lookup_table.go b/memory/point_lookup_table.go index 910bba37b7..8660da13c6 100644 --- a/memory/point_lookup_table.go +++ b/memory/point_lookup_table.go @@ -18,7 +18,7 @@ var _ sql.TableNode = PointLookupTable{} // PointLookupTable is a table whose indexes only support point lookups but not range scans. // It's used for testing optimizations on indexes. type PointLookupTable struct { - IntSequenceTable + LookupSequenceTable } func (s PointLookupTable) UnderlyingTable() sql.Table { @@ -26,8 +26,8 @@ func (s PointLookupTable) UnderlyingTable() sql.Table { } func (s PointLookupTable) NewInstance(ctx *sql.Context, db sql.Database, args []sql.Expression) (sql.Node, error) { - node, err := s.IntSequenceTable.NewInstance(ctx, db, args) - return PointLookupTable{node.(IntSequenceTable)}, err + node, err := s.LookupSequenceTable.NewInstance(ctx, db, args) + return PointLookupTable{node.(LookupSequenceTable)}, err } func (s PointLookupTable) String() string { diff --git a/memory/required_lookup_table.go b/memory/required_lookup_table.go index 60c556d951..bb7ba87cdd 100644 --- a/memory/required_lookup_table.go +++ b/memory/required_lookup_table.go @@ -18,7 +18,7 @@ var _ sql.IndexRequired = RequiredLookupTable{} // RequiredLookupTable is a table that will error if not executed as an index lookup type RequiredLookupTable struct { - IntSequenceTable + LookupSequenceTable indexOk bool } @@ -31,8 +31,8 @@ func (s RequiredLookupTable) UnderlyingTable() sql.Table { } func (s RequiredLookupTable) NewInstance(ctx *sql.Context, db sql.Database, args []sql.Expression) (sql.Node, error) { - node, err := s.IntSequenceTable.NewInstance(ctx, db, args) - return RequiredLookupTable{IntSequenceTable: node.(IntSequenceTable)}, err + node, err := s.LookupSequenceTable.NewInstance(ctx, db, args) + return RequiredLookupTable{LookupSequenceTable: node.(LookupSequenceTable)}, err } func (s RequiredLookupTable) String() string { @@ -74,7 +74,7 @@ func (s RequiredLookupTable) Database() sql.Database { } func (s RequiredLookupTable) IndexedAccess(ctx *sql.Context, lookup sql.IndexLookup) sql.IndexedTable { - return RequiredLookupTable{indexOk: true, IntSequenceTable: s.IntSequenceTable} + return RequiredLookupTable{indexOk: true, LookupSequenceTable: s.LookupSequenceTable} } func (s RequiredLookupTable) RowIter(_ *sql.Context, _ sql.Row) (sql.RowIter, error) { @@ -90,14 +90,14 @@ func (s RequiredLookupTable) Partitions(ctx *sql.Context) (sql.PartitionIter, er if !s.indexOk { return nil, fmt.Errorf("table requires index lookup") } - return s.IntSequenceTable.Partitions(ctx) + return s.LookupSequenceTable.Partitions(ctx) } func (s RequiredLookupTable) PartitionRows(ctx *sql.Context, partition sql.Partition) (sql.RowIter, error) { if !s.indexOk { return nil, fmt.Errorf("table requires index lookup") } - return s.IntSequenceTable.PartitionRows(ctx, partition) + return s.LookupSequenceTable.PartitionRows(ctx, partition) } func (s RequiredLookupTable) GetIndexes(ctx *sql.Context) (indexes []sql.Index, err error) { diff --git a/memory/sequence_table.go b/memory/sequence_table.go index d30dd4afda..c7ce06ab27 100644 --- a/memory/sequence_table.go +++ b/memory/sequence_table.go @@ -13,9 +13,6 @@ import ( var _ sql.TableFunction = IntSequenceTable{} var _ sql.CollationCoercible = IntSequenceTable{} var _ sql.ExecSourceRel = IntSequenceTable{} -var _ sql.IndexAddressable = IntSequenceTable{} -var _ sql.IndexedTable = IntSequenceTable{} -var _ sql.TableNode = IntSequenceTable{} // IntSequenceTable a simple table function that returns a sequence // of integers. @@ -25,10 +22,6 @@ type IntSequenceTable struct { Len int64 } -func (s IntSequenceTable) UnderlyingTable() sql.Table { - return s -} - func (s IntSequenceTable) NewInstance(ctx *sql.Context, db sql.Database, args []sql.Expression) (sql.Node, error) { if len(args) != 2 { return nil, fmt.Errorf("sequence table expects 2 arguments: (name, len)") @@ -164,79 +157,3 @@ type sequencePartition struct { func (s sequencePartition) Key() []byte { return binary.LittleEndian.AppendUint64(binary.LittleEndian.AppendUint64(nil, uint64(s.min)), uint64(s.max)) } - -// Partitions is a sql.Table interface function that returns a partition of the data. This data has a single partition. -func (s IntSequenceTable) Partitions(ctx *sql.Context) (sql.PartitionIter, error) { - return sql.PartitionsToPartitionIter(&sequencePartition{min: 0, max: int64(s.Len) - 1}), nil -} - -// PartitionRows is a sql.Table interface function that takes a partition and returns all rows in that partition. -// This table has a partition for just schema changes, one for just data changes, and one for both. -func (s IntSequenceTable) PartitionRows(ctx *sql.Context, partition sql.Partition) (sql.RowIter, error) { - sp, ok := partition.(*sequencePartition) - if !ok { - return &SequenceTableFnRowIter{i: 0, n: s.Len}, nil - } - min := int64(0) - if sp.min > min { - min = sp.min - } - max := int64(s.Len) - 1 - if sp.max < max { - max = sp.max - } - - return &SequenceTableFnRowIter{i: min, n: max + 1}, nil -} - -// LookupPartitions is a sql.IndexedTable interface function that takes an index lookup and returns the set of corresponding partitions. -func (s IntSequenceTable) LookupPartitions(ctx *sql.Context, lookup sql.IndexLookup) (sql.PartitionIter, error) { - lowerBound := lookup.Ranges.(sql.MySQLRangeCollection)[0][0].LowerBound - below, ok := lowerBound.(sql.Below) - if !ok { - return s.Partitions(ctx) - } - upperBound := lookup.Ranges.(sql.MySQLRangeCollection)[0][0].UpperBound - above, ok := upperBound.(sql.Above) - if !ok { - return s.Partitions(ctx) - } - min, _, err := s.Schema()[0].Type.Convert(ctx, below.Key) - if err != nil { - return nil, err - } - max, _, err := s.Schema()[0].Type.Convert(ctx, above.Key) - if err != nil { - return nil, err - } - return sql.PartitionsToPartitionIter(&sequencePartition{min: min.(int64), max: max.(int64)}), nil -} - -func (s IntSequenceTable) IndexedAccess(ctx *sql.Context, lookup sql.IndexLookup) sql.IndexedTable { - return s -} - -func (s IntSequenceTable) PreciseMatch() bool { - return true -} - -func (s IntSequenceTable) GetIndexes(ctx *sql.Context) ([]sql.Index, error) { - return []sql.Index{ - &Index{ - DB: s.db.Name(), - DriverName: "", - Tbl: nil, - TableName: s.Name(), - Exprs: []sql.Expression{ - expression.NewGetFieldWithTable(0, 0, types.Int64, s.db.Name(), s.Name(), s.name, false), - }, - Name: s.name, - Unique: true, - Spatial: false, - Fulltext: false, - CommentStr: "", - PrefixLens: nil, - fulltextInfo: fulltextInfo{}, - }, - }, nil -} diff --git a/sql/analyzer/indexed_joins.go b/sql/analyzer/indexed_joins.go index d701726dc9..2a382af9ef 100644 --- a/sql/analyzer/indexed_joins.go +++ b/sql/analyzer/indexed_joins.go @@ -687,7 +687,10 @@ func addRightSemiJoins(ctx *sql.Context, m *memo.Memo) error { switch n := leftTab.(type) { case *plan.TableAlias: aliasName = n.Name() - leftRt = n.Child.(sql.TableNode) + leftRt, ok = n.Child.(sql.TableNode) + if !ok { + return nil + } case sql.TableNode: leftRt = n }