Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions enginetest/queries/join_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -1429,4 +1429,20 @@ LATERAL (
},
},
},
{
// https://github.com/dolthub/dolt/issues/9820
Name: "lateral cross join with subquery",
SetUpScript: []string{
"create table t0(c0 boolean)",
"create table t1(c0 int)",
"insert into t0 values (true)",
"insert into t1 values(0)",
},
Assertions: []ScriptTestAssertion{
{
Query: "select v.c0, t1.c0 from t0 cross join lateral (select 1 as c0) as v join t1 on v.c0 > t1.c0",
Expected: []sql.Row{{1, 0}},
},
},
},
}
127 changes: 64 additions & 63 deletions sql/rowexec/join_iters.go
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,7 @@ func (i *crossJoinIterator) Close(ctx *sql.Context) (err error) {
return err
}

// lateralJoinIter is an iterator that performs a lateral join.
// lateralJoinIterator is an iterator that performs a lateral join.
// A LateralJoin is a join where the right side is a subquery that can reference the left side, like through a filter.
// MySQL Docs: https://dev.mysql.com/doc/refman/8.0/en/lateral-derived-tables.html
// Example:
Expand Down Expand Up @@ -716,18 +716,18 @@ func (i *crossJoinIterator) Close(ctx *sql.Context) (err error) {
// +---+---+
// cond is passed to the filter iter to be evaluated.
type lateralJoinIterator struct {
lIter sql.RowIter
rIter sql.RowIter
rNode sql.Node
cond sql.Expression
b sql.NodeExecBuilder
pRow sql.Row
lRow sql.Row
rRow sql.Row
rowSize int
scopeLen int
jType plan.JoinType
foundMatch bool
primary sql.RowIter
secondary sql.RowIter
secondaryNode sql.Node
cond sql.Expression
b sql.NodeExecBuilder
parentRow sql.Row
primaryRow sql.Row
secondaryRow sql.Row
rowSize int
scopeLen int
jType plan.JoinType
foundMatch bool
}

func newLateralJoinIter(ctx *sql.Context, b sql.NodeExecBuilder, j *plan.JoinNode, row sql.Row) (sql.RowIter, error) {
Expand Down Expand Up @@ -755,105 +755,106 @@ func newLateralJoinIter(ctx *sql.Context, b sql.NodeExecBuilder, j *plan.JoinNod
}

return sql.NewSpanIter(span, &lateralJoinIterator{
pRow: row,
lIter: l,
rNode: j.Right(),
cond: j.Filter,
jType: j.Op,
rowSize: len(row) + len(j.Left().Schema()) + len(j.Right().Schema()),
scopeLen: j.ScopeLen,
b: b,
parentRow: row,
primary: l,
secondaryNode: j.Right(),
cond: j.Filter,
jType: j.Op,
rowSize: len(row) + len(j.Left().Schema()) + len(j.Right().Schema()),
scopeLen: j.ScopeLen,
b: b,
}), nil
}

func (i *lateralJoinIterator) loadLeft(ctx *sql.Context) error {
if i.lRow == nil {
lRow, err := i.lIter.Next(ctx)
func (i *lateralJoinIterator) loadPrimary(ctx *sql.Context) error {
if i.primaryRow == nil {
lRow, err := i.primary.Next(ctx)
if err != nil {
return err
}
i.lRow = lRow
i.primaryRow = lRow
i.foundMatch = false
}
return nil
}

func (i *lateralJoinIterator) buildRight(ctx *sql.Context) error {
if i.rIter == nil {
prepended, _, err := transform.Node(i.rNode, plan.PrependRowInPlan(i.lRow, true))
func (i *lateralJoinIterator) buildSecondary(ctx *sql.Context) error {
if i.secondary == nil {
prepended, _, err := transform.Node(i.secondaryNode, plan.PrependRowInPlan(i.primaryRow, true))
if err != nil {
return err
}
iter, err := i.b.Build(ctx, prepended, i.lRow)
iter, err := i.b.Build(ctx, prepended, i.primaryRow)
if err != nil {
return err
}
i.rIter = iter
i.secondary = iter
}
return nil
}

func (i *lateralJoinIterator) loadRight(ctx *sql.Context) error {
if i.rRow == nil {
rRow, err := i.rIter.Next(ctx)
func (i *lateralJoinIterator) loadSecondary(ctx *sql.Context) error {
if i.secondaryRow == nil {
sRow, err := i.secondary.Next(ctx)
if err != nil {
return err
}
i.rRow = rRow[len(i.lRow):]
i.secondaryRow = sRow[len(i.primaryRow):]
}
return nil
}

func (i *lateralJoinIterator) buildRow(lRow, rRow sql.Row) sql.Row {
func (i *lateralJoinIterator) buildRow(primaryRow, secondaryRow sql.Row) sql.Row {
row := make(sql.Row, i.rowSize)
copy(row, lRow)
copy(row[len(lRow):], rRow)
copy(row, i.parentRow)
copy(row[len(i.parentRow):], primaryRow)
copy(row[len(i.parentRow)+len(primaryRow):], secondaryRow)
return row
}

func (i *lateralJoinIterator) removeParentRow(r sql.Row) sql.Row {
copy(r[i.scopeLen:], r[len(i.pRow):])
r = r[:len(r)-len(i.pRow)+i.scopeLen]
copy(r[i.scopeLen:], r[len(i.parentRow):])
r = r[:len(r)-len(i.parentRow)+i.scopeLen]
return r
}

func (i *lateralJoinIterator) reset(ctx *sql.Context) (err error) {
if i.rIter != nil {
err = i.rIter.Close(ctx)
i.rIter = nil
if i.secondary != nil {
err = i.secondary.Close(ctx)
i.secondary = nil
}
i.lRow = nil
i.rRow = nil
i.primaryRow = nil
i.secondaryRow = nil
return
}

func (i *lateralJoinIterator) Next(ctx *sql.Context) (sql.Row, error) {
for {
if err := i.loadLeft(ctx); err != nil {
if err := i.loadPrimary(ctx); err != nil {
return nil, err
}
if err := i.buildRight(ctx); err != nil {
if err := i.buildSecondary(ctx); err != nil {
return nil, err
}
if err := i.loadRight(ctx); err != nil {
if err := i.loadSecondary(ctx); err != nil {
if errors.Is(err, io.EOF) {
if !i.foundMatch && i.jType == plan.JoinTypeLateralLeft {
res := i.buildRow(i.lRow, nil)
if rerr := i.reset(ctx); rerr != nil {
return nil, rerr
res := i.buildRow(i.primaryRow, nil)
if resetErr := i.reset(ctx); resetErr != nil {
return nil, resetErr
}
return i.removeParentRow(res), nil
}
if rerr := i.reset(ctx); rerr != nil {
return nil, rerr
if resetErr := i.reset(ctx); resetErr != nil {
return nil, resetErr
}
continue
}
return nil, err
}

row := i.buildRow(i.lRow, i.rRow)
i.rRow = nil
row := i.buildRow(i.primaryRow, i.secondaryRow)
i.secondaryRow = nil
if i.cond != nil {
if res, err := sql.EvaluateCondition(ctx, i.cond, row); err != nil {
return nil, err
Expand All @@ -868,18 +869,18 @@ func (i *lateralJoinIterator) Next(ctx *sql.Context) (sql.Row, error) {
}

func (i *lateralJoinIterator) Close(ctx *sql.Context) error {
var lerr, rerr error
if i.lIter != nil {
lerr = i.lIter.Close(ctx)
var pErr, sErr error
if i.primary != nil {
pErr = i.primary.Close(ctx)
}
if i.rIter != nil {
rerr = i.rIter.Close(ctx)
if i.secondary != nil {
sErr = i.secondary.Close(ctx)
}
if lerr != nil {
return lerr
if pErr != nil {
return pErr
}
if rerr != nil {
return rerr
if sErr != nil {
return sErr
}
return nil
}
Loading