Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions enginetest/queries/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -8178,6 +8178,12 @@ ORDER BY 1;`,
{3},
},
},
{
Query: "SELECT * FROM xy JOIN LATERAL (SELECT * FROM uv WHERE xy.x+1 = uv.u) uv2",
Expected: []sql.Row{
{0, 2, 1, 1}, {1, 0, 2, 2}, {2, 1, 3, 2},
},
},
{
Query: `
select * from mytable,
Expand Down
8 changes: 8 additions & 0 deletions enginetest/queries/table_func_scripts.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,14 @@ var TableFunctionScriptTests = []ScriptTest{
Query: "select x from sequence_table('x', 5) where exists (select y from sequence_table('y', 3) where x = y)",
Expected: []sql.Row{{0}, {1}, {2}},
},
{
Query: "select * from sequence_table('x', 3) l join lateral (select * from sequence_table('y', l.x)) r",
Expected: []sql.Row{{1, 0}, {2, 0}, {2, 1}},
},
{
Query: "select * from sequence_table('x', 3) l where exists (select * from sequence_table('y', l.x))",
Expected: []sql.Row{{1}, {2}},
},
{
Query: "select not_seq.x from sequence_table('x', 5) as seq",
ExpectedErr: sql.ErrTableNotFound,
Expand Down
17 changes: 13 additions & 4 deletions memory/lookup_squence_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ var _ sql.TableNode = LookupSequenceTable{}
// LookupSequenceTable is a variation of IntSequenceTable that supports lookups and implements sql.TableNode
type LookupSequenceTable struct {
IntSequenceTable
length int64
}

func (s LookupSequenceTable) UnderlyingTable() sql.Table {
Expand All @@ -29,7 +30,15 @@ func (s LookupSequenceTable) NewInstance(ctx *sql.Context, db sql.Database, args
if err != nil {
return nil, err
}
return LookupSequenceTable{newIntSequenceTable.(IntSequenceTable)}, nil
lenExp, ok := args[1].(*expression.Literal)
if !ok {
return nil, fmt.Errorf("sequence table expects arguments to be literal expressions")
}
length, _, err := types.Int64.Convert(ctx, lenExp.Value())
if err != nil {
return nil, err
}
return LookupSequenceTable{newIntSequenceTable.(IntSequenceTable), length.(int64)}, nil
}

func (s LookupSequenceTable) String() string {
Expand Down Expand Up @@ -82,21 +91,21 @@ func (s LookupSequenceTable) Description() string {

// Partitions is a sql.Table interface function that returns a partition of the data. This data has a single partition.
func (s LookupSequenceTable) Partitions(ctx *sql.Context) (sql.PartitionIter, error) {
return sql.PartitionsToPartitionIter(&sequencePartition{min: 0, max: int64(s.Len) - 1}), nil
return sql.PartitionsToPartitionIter(&sequencePartition{min: 0, max: s.length - 1}), nil
}

// PartitionRows is a sql.Table interface function that takes a partition and returns all rows in that partition.
// This table has a partition for just schema changes, one for just data changes, and one for both.
func (s LookupSequenceTable) PartitionRows(ctx *sql.Context, partition sql.Partition) (sql.RowIter, error) {
sp, ok := partition.(*sequencePartition)
if !ok {
return &SequenceTableFnRowIter{i: 0, n: s.Len}, nil
return &SequenceTableFnRowIter{i: 0, n: s.length}, nil
}
min := int64(0)
if sp.min > min {
min = sp.min
}
max := int64(s.Len) - 1
max := int64(s.length) - 1
if sp.max < max {
max = sp.max
}
Expand Down
39 changes: 26 additions & 13 deletions memory/sequence_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ var _ sql.ExecSourceRel = IntSequenceTable{}
type IntSequenceTable struct {
db sql.Database
name string
Len int64
Len sql.Expression
}

func (s IntSequenceTable) NewInstance(ctx *sql.Context, db sql.Database, args []sql.Expression) (sql.Node, error) {
Expand All @@ -34,15 +34,11 @@ func (s IntSequenceTable) NewInstance(ctx *sql.Context, db sql.Database, args []
if !ok {
return nil, fmt.Errorf("sequence table expects 1st argument to be column name")
}
lenExp, ok := args[1].(*expression.Literal)
if !ok {
return nil, fmt.Errorf("sequence table expects arguments to be literal expressions")
}
length, _, err := types.Int64.Convert(ctx, lenExp.Value())
if !ok {
return nil, fmt.Errorf("%w; sequence table expects 2nd argument to be a sequence length integer", err)
lenExp := args[1]
if !sql.IsNumberType(lenExp.Type()) {
return nil, fmt.Errorf("sequence table expects length argument to be a number")
}
return IntSequenceTable{db: db, name: name, Len: length.(int64)}, nil
return IntSequenceTable{db: db, name: name, Len: lenExp}, nil
}

func (s IntSequenceTable) Resolved() bool {
Expand Down Expand Up @@ -85,8 +81,20 @@ func (s IntSequenceTable) Children() []sql.Node {
return []sql.Node{}
}

func (s IntSequenceTable) RowIter(_ *sql.Context, _ sql.Row) (sql.RowIter, error) {
rowIter := &SequenceTableFnRowIter{i: 0, n: s.Len}
func (s IntSequenceTable) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) {
iterLen, err := s.Len.Eval(ctx, row)
if err != nil {
return nil, err
}
iterLenVal, ok, err := types.Int64.Convert(ctx, iterLen)
if err != nil {
return nil, err
}
if !ok {
return nil, fmt.Errorf("sequence table expects integer argument")
}

rowIter := &SequenceTableFnRowIter{i: 0, n: iterLenVal.(int64)}
return rowIter, nil
}

Expand All @@ -105,11 +113,16 @@ func (IntSequenceTable) Collation() sql.CollationID {
}

func (s IntSequenceTable) Expressions() []sql.Expression {
return []sql.Expression{}
return []sql.Expression{s.Len}
}

func (s IntSequenceTable) WithExpressions(e ...sql.Expression) (sql.Node, error) {
return s, nil
if len(e) != 1 {
return nil, sql.ErrInvalidChildrenNumber.New(s, len(e), 1)
}
newSequenceTable := s
newSequenceTable.Len = e[0]
return newSequenceTable, nil
}

func (s IntSequenceTable) Database() sql.Database {
Expand Down
6 changes: 5 additions & 1 deletion sql/analyzer/unnest_exists_subqueries.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,11 @@ func unnestExistSubqueries(ctx *sql.Context, scope *plan.Scope, a *Analyzer, fil
ret = plan.NewAntiJoinIncludingNulls(ret, s.inner, cond).WithComment(comment)
qFlags.Set(sql.QFlagInnerJoin)
case plan.JoinTypeSemi:
ret = plan.NewCrossJoin(ret, s.inner).WithComment(comment)
if sq.Correlated().Empty() {
ret = plan.NewCrossJoin(ret, s.inner).WithComment(comment)
} else {
ret = plan.NewLateralCrossJoin(ret, s.inner).WithComment(comment)
}
qFlags.Set(sql.QFlagCrossJoin)
default:
return filter, transform.SameTree, fmt.Errorf("hoistSelectExists failed on unexpected join type")
Expand Down
2 changes: 1 addition & 1 deletion sql/memo/exec_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ func (b *ExecBuilder) buildMergeJoin(j *MergeJoin, children ...sql.Node) (sql.No

func (b *ExecBuilder) buildLateralJoin(j *LateralJoin, children ...sql.Node) (sql.Node, error) {
if len(j.Filter) == 0 {
return plan.NewCrossJoin(children[0], children[1]), nil
return plan.NewLateralCrossJoin(children[0], children[1]), nil
}
filters := b.buildFilterConjunction(j.Filter...)
return plan.NewJoin(children[0], children[1], j.Op.AsLateral(), filters), nil
Expand Down
4 changes: 4 additions & 0 deletions sql/plan/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,10 @@ func NewCrossJoin(left, right sql.Node) *JoinNode {
return NewJoin(left, right, JoinTypeCross, nil)
}

func NewLateralCrossJoin(left, right sql.Node) *JoinNode {
return NewJoin(left, right, JoinTypeLateralCross, nil)
}

// NaturalJoin is a join that automatically joins by all the columns with the
// same name.
// NaturalJoin is a placeholder node, it should be transformed into an INNER
Expand Down
4 changes: 2 additions & 2 deletions sql/planbuilder/scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ func (s *scope) resolveColumn(db, table, col string, checkParent, chooseFirst bo
return scopeColumn{}, false
}

if s.parent.activeSubquery != nil {
s.parent.activeSubquery.addOutOfScope(c.id)
if s.activeSubquery != nil {
s.activeSubquery.addOutOfScope(c.id)
}
return c, true
}
Expand Down
96 changes: 49 additions & 47 deletions sql/rowexec/join_iters.go
Original file line number Diff line number Diff line change
Expand Up @@ -734,16 +734,19 @@ type lateralJoinIterator struct {
secondaryNode sql.Node
cond sql.Expression
b sql.NodeExecBuilder
parentRow sql.Row
primaryRow sql.Row
secondaryRow sql.Row
rowSize int
scopeLen int
jType plan.JoinType
foundMatch bool
// primaryRow contains the parent row concatenated with the current row from the primary child,
// and is used to build the secondary child iter.
primaryRow sql.Row
// secondaryRow contains the current row from the secondary child.
secondaryRow sql.Row
rowSize int
scopeLen int
parentLen int
jType plan.JoinType
foundMatch bool
}

func newLateralJoinIter(ctx *sql.Context, b sql.NodeExecBuilder, j *plan.JoinNode, row sql.Row) (sql.RowIter, error) {
func newLateralJoinIter(ctx *sql.Context, b sql.NodeExecBuilder, j *plan.JoinNode, parentRow sql.Row) (sql.RowIter, error) {
var left, right string
if leftTable, ok := j.Left().(sql.Nameable); ok {
left = leftTable.Name()
Expand All @@ -761,73 +764,72 @@ func newLateralJoinIter(ctx *sql.Context, b sql.NodeExecBuilder, j *plan.JoinNod
attribute.String("right", right),
))

l, err := b.Build(ctx, j.Left(), row)
l, err := b.Build(ctx, j.Left(), parentRow)
if err != nil {
span.End()
return nil, err
}

parentLen := len(parentRow)

primaryRow := make(sql.Row, parentLen+len(j.Left().Schema()))
copy(primaryRow, parentRow)

return sql.NewSpanIter(span, &lateralJoinIterator{
parentRow: row,
primaryRow: primaryRow,
parentLen: len(parentRow),
primary: l,
secondaryNode: j.Right(),
cond: j.Filter,
jType: j.Op,
rowSize: len(row) + len(j.Left().Schema()) + len(j.Right().Schema()),
rowSize: len(parentRow) + len(j.Left().Schema()) + len(j.Right().Schema()),
scopeLen: j.ScopeLen,
b: b,
}), nil
}

func (i *lateralJoinIterator) loadPrimary(ctx *sql.Context) error {
if i.primaryRow == nil {
lRow, err := i.primary.Next(ctx)
if err != nil {
return err
}
i.primaryRow = lRow
i.foundMatch = false
lRow, err := i.primary.Next(ctx)
if err != nil {
return err
}
copy(i.primaryRow[i.parentLen:], lRow)
i.foundMatch = false
return nil
}

func (i *lateralJoinIterator) buildSecondary(ctx *sql.Context) error {
if i.secondary == nil {
prepended, _, err := transform.Node(i.secondaryNode, plan.PrependRowInPlan(i.primaryRow, true))
if err != nil {
return err
}
iter, err := i.b.Build(ctx, prepended, i.primaryRow)
if err != nil {
return err
}
i.secondary = iter
prepended, _, err := transform.Node(i.secondaryNode, plan.PrependRowInPlan(i.primaryRow, true))
if err != nil {
return err
}
iter, err := i.b.Build(ctx, prepended, i.primaryRow)
if err != nil {
return err
}
i.secondary = iter
return nil
}

func (i *lateralJoinIterator) loadSecondary(ctx *sql.Context) error {
if i.secondaryRow == nil {
sRow, err := i.secondary.Next(ctx)
if err != nil {
return err
}
i.secondaryRow = sRow[len(i.primaryRow):]
sRow, err := i.secondary.Next(ctx)
if err != nil {
return err
}
i.secondaryRow = sRow[len(i.primaryRow):]
return nil
}

func (i *lateralJoinIterator) buildRow(primaryRow, secondaryRow sql.Row) sql.Row {
row := make(sql.Row, i.rowSize)
copy(row, i.parentRow)
copy(row[len(i.parentRow):], primaryRow)
copy(row[len(i.parentRow)+len(primaryRow):], secondaryRow)
copy(row, primaryRow)
copy(row[len(primaryRow):], secondaryRow)
return row
}

func (i *lateralJoinIterator) removeParentRow(r sql.Row) sql.Row {
copy(r[i.scopeLen:], r[len(i.parentRow):])
r = r[:len(r)-len(i.parentRow)+i.scopeLen]
copy(r[i.scopeLen:], r[i.parentLen:])
r = r[:len(r)-i.parentLen+i.scopeLen]
return r
}

Expand All @@ -836,18 +838,20 @@ func (i *lateralJoinIterator) reset(ctx *sql.Context) (err error) {
err = i.secondary.Close(ctx)
i.secondary = nil
}
i.primaryRow = nil
i.secondaryRow = nil
return
}

func (i *lateralJoinIterator) Next(ctx *sql.Context) (sql.Row, error) {
for {
if err := i.loadPrimary(ctx); err != nil {
return nil, err
}
if err := i.buildSecondary(ctx); err != nil {
return nil, err
// secondary being nil means we've exhausted all secondary rows for the current primary.
if i.secondary == nil {
if err := i.loadPrimary(ctx); err != nil {
return nil, err
}
if err := i.buildSecondary(ctx); err != nil {
return nil, err
}
}
if err := i.loadSecondary(ctx); err != nil {
if errors.Is(err, io.EOF) {
Expand All @@ -865,9 +869,7 @@ func (i *lateralJoinIterator) Next(ctx *sql.Context) (sql.Row, error) {
}
return nil, err
}

row := i.buildRow(i.primaryRow, i.secondaryRow)
i.secondaryRow = nil
if i.cond != nil {
if res, err := sql.EvaluateCondition(ctx, i.cond, row); err != nil {
return nil, err
Expand Down