Skip to content

Commit 248b1e8

Browse files
authored
Merge pull request #3210 from dolthub/angela/lateraljoin
Copy parent row in lateralJoinIterator.buildRow
2 parents a77c705 + 731c9da commit 248b1e8

File tree

2 files changed

+80
-63
lines changed

2 files changed

+80
-63
lines changed

enginetest/queries/join_queries.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1429,4 +1429,20 @@ LATERAL (
14291429
},
14301430
},
14311431
},
1432+
{
1433+
// https://github.com/dolthub/dolt/issues/9820
1434+
Name: "lateral cross join with subquery",
1435+
SetUpScript: []string{
1436+
"create table t0(c0 boolean)",
1437+
"create table t1(c0 int)",
1438+
"insert into t0 values (true)",
1439+
"insert into t1 values(0)",
1440+
},
1441+
Assertions: []ScriptTestAssertion{
1442+
{
1443+
Query: "select v.c0, t1.c0 from t0 cross join lateral (select 1 as c0) as v join t1 on v.c0 > t1.c0",
1444+
Expected: []sql.Row{{1, 0}},
1445+
},
1446+
},
1447+
},
14321448
}

sql/rowexec/join_iters.go

Lines changed: 64 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -688,7 +688,7 @@ func (i *crossJoinIterator) Close(ctx *sql.Context) (err error) {
688688
return err
689689
}
690690

691-
// lateralJoinIter is an iterator that performs a lateral join.
691+
// lateralJoinIterator is an iterator that performs a lateral join.
692692
// A LateralJoin is a join where the right side is a subquery that can reference the left side, like through a filter.
693693
// MySQL Docs: https://dev.mysql.com/doc/refman/8.0/en/lateral-derived-tables.html
694694
// Example:
@@ -716,18 +716,18 @@ func (i *crossJoinIterator) Close(ctx *sql.Context) (err error) {
716716
// +---+---+
717717
// cond is passed to the filter iter to be evaluated.
718718
type lateralJoinIterator struct {
719-
lIter sql.RowIter
720-
rIter sql.RowIter
721-
rNode sql.Node
722-
cond sql.Expression
723-
b sql.NodeExecBuilder
724-
pRow sql.Row
725-
lRow sql.Row
726-
rRow sql.Row
727-
rowSize int
728-
scopeLen int
729-
jType plan.JoinType
730-
foundMatch bool
719+
primary sql.RowIter
720+
secondary sql.RowIter
721+
secondaryNode sql.Node
722+
cond sql.Expression
723+
b sql.NodeExecBuilder
724+
parentRow sql.Row
725+
primaryRow sql.Row
726+
secondaryRow sql.Row
727+
rowSize int
728+
scopeLen int
729+
jType plan.JoinType
730+
foundMatch bool
731731
}
732732

733733
func newLateralJoinIter(ctx *sql.Context, b sql.NodeExecBuilder, j *plan.JoinNode, row sql.Row) (sql.RowIter, error) {
@@ -755,105 +755,106 @@ func newLateralJoinIter(ctx *sql.Context, b sql.NodeExecBuilder, j *plan.JoinNod
755755
}
756756

757757
return sql.NewSpanIter(span, &lateralJoinIterator{
758-
pRow: row,
759-
lIter: l,
760-
rNode: j.Right(),
761-
cond: j.Filter,
762-
jType: j.Op,
763-
rowSize: len(row) + len(j.Left().Schema()) + len(j.Right().Schema()),
764-
scopeLen: j.ScopeLen,
765-
b: b,
758+
parentRow: row,
759+
primary: l,
760+
secondaryNode: j.Right(),
761+
cond: j.Filter,
762+
jType: j.Op,
763+
rowSize: len(row) + len(j.Left().Schema()) + len(j.Right().Schema()),
764+
scopeLen: j.ScopeLen,
765+
b: b,
766766
}), nil
767767
}
768768

769-
func (i *lateralJoinIterator) loadLeft(ctx *sql.Context) error {
770-
if i.lRow == nil {
771-
lRow, err := i.lIter.Next(ctx)
769+
func (i *lateralJoinIterator) loadPrimary(ctx *sql.Context) error {
770+
if i.primaryRow == nil {
771+
lRow, err := i.primary.Next(ctx)
772772
if err != nil {
773773
return err
774774
}
775-
i.lRow = lRow
775+
i.primaryRow = lRow
776776
i.foundMatch = false
777777
}
778778
return nil
779779
}
780780

781-
func (i *lateralJoinIterator) buildRight(ctx *sql.Context) error {
782-
if i.rIter == nil {
783-
prepended, _, err := transform.Node(i.rNode, plan.PrependRowInPlan(i.lRow, true))
781+
func (i *lateralJoinIterator) buildSecondary(ctx *sql.Context) error {
782+
if i.secondary == nil {
783+
prepended, _, err := transform.Node(i.secondaryNode, plan.PrependRowInPlan(i.primaryRow, true))
784784
if err != nil {
785785
return err
786786
}
787-
iter, err := i.b.Build(ctx, prepended, i.lRow)
787+
iter, err := i.b.Build(ctx, prepended, i.primaryRow)
788788
if err != nil {
789789
return err
790790
}
791-
i.rIter = iter
791+
i.secondary = iter
792792
}
793793
return nil
794794
}
795795

796-
func (i *lateralJoinIterator) loadRight(ctx *sql.Context) error {
797-
if i.rRow == nil {
798-
rRow, err := i.rIter.Next(ctx)
796+
func (i *lateralJoinIterator) loadSecondary(ctx *sql.Context) error {
797+
if i.secondaryRow == nil {
798+
sRow, err := i.secondary.Next(ctx)
799799
if err != nil {
800800
return err
801801
}
802-
i.rRow = rRow[len(i.lRow):]
802+
i.secondaryRow = sRow[len(i.primaryRow):]
803803
}
804804
return nil
805805
}
806806

807-
func (i *lateralJoinIterator) buildRow(lRow, rRow sql.Row) sql.Row {
807+
func (i *lateralJoinIterator) buildRow(primaryRow, secondaryRow sql.Row) sql.Row {
808808
row := make(sql.Row, i.rowSize)
809-
copy(row, lRow)
810-
copy(row[len(lRow):], rRow)
809+
copy(row, i.parentRow)
810+
copy(row[len(i.parentRow):], primaryRow)
811+
copy(row[len(i.parentRow)+len(primaryRow):], secondaryRow)
811812
return row
812813
}
813814

814815
func (i *lateralJoinIterator) removeParentRow(r sql.Row) sql.Row {
815-
copy(r[i.scopeLen:], r[len(i.pRow):])
816-
r = r[:len(r)-len(i.pRow)+i.scopeLen]
816+
copy(r[i.scopeLen:], r[len(i.parentRow):])
817+
r = r[:len(r)-len(i.parentRow)+i.scopeLen]
817818
return r
818819
}
819820

820821
func (i *lateralJoinIterator) reset(ctx *sql.Context) (err error) {
821-
if i.rIter != nil {
822-
err = i.rIter.Close(ctx)
823-
i.rIter = nil
822+
if i.secondary != nil {
823+
err = i.secondary.Close(ctx)
824+
i.secondary = nil
824825
}
825-
i.lRow = nil
826-
i.rRow = nil
826+
i.primaryRow = nil
827+
i.secondaryRow = nil
827828
return
828829
}
829830

830831
func (i *lateralJoinIterator) Next(ctx *sql.Context) (sql.Row, error) {
831832
for {
832-
if err := i.loadLeft(ctx); err != nil {
833+
if err := i.loadPrimary(ctx); err != nil {
833834
return nil, err
834835
}
835-
if err := i.buildRight(ctx); err != nil {
836+
if err := i.buildSecondary(ctx); err != nil {
836837
return nil, err
837838
}
838-
if err := i.loadRight(ctx); err != nil {
839+
if err := i.loadSecondary(ctx); err != nil {
839840
if errors.Is(err, io.EOF) {
840841
if !i.foundMatch && i.jType == plan.JoinTypeLateralLeft {
841-
res := i.buildRow(i.lRow, nil)
842-
if rerr := i.reset(ctx); rerr != nil {
843-
return nil, rerr
842+
res := i.buildRow(i.primaryRow, nil)
843+
if resetErr := i.reset(ctx); resetErr != nil {
844+
return nil, resetErr
844845
}
845846
return i.removeParentRow(res), nil
846847
}
847-
if rerr := i.reset(ctx); rerr != nil {
848-
return nil, rerr
848+
if resetErr := i.reset(ctx); resetErr != nil {
849+
return nil, resetErr
849850
}
850851
continue
851852
}
852853
return nil, err
853854
}
854855

855-
row := i.buildRow(i.lRow, i.rRow)
856-
i.rRow = nil
856+
row := i.buildRow(i.primaryRow, i.secondaryRow)
857+
i.secondaryRow = nil
857858
if i.cond != nil {
858859
if res, err := sql.EvaluateCondition(ctx, i.cond, row); err != nil {
859860
return nil, err
@@ -868,18 +869,18 @@ func (i *lateralJoinIterator) Next(ctx *sql.Context) (sql.Row, error) {
868869
}
869870

870871
func (i *lateralJoinIterator) Close(ctx *sql.Context) error {
871-
var lerr, rerr error
872-
if i.lIter != nil {
873-
lerr = i.lIter.Close(ctx)
872+
var pErr, sErr error
873+
if i.primary != nil {
874+
pErr = i.primary.Close(ctx)
874875
}
875-
if i.rIter != nil {
876-
rerr = i.rIter.Close(ctx)
876+
if i.secondary != nil {
877+
sErr = i.secondary.Close(ctx)
877878
}
878-
if lerr != nil {
879-
return lerr
879+
if pErr != nil {
880+
return pErr
880881
}
881-
if rerr != nil {
882-
return rerr
882+
if sErr != nil {
883+
return sErr
883884
}
884885
return nil
885886
}

0 commit comments

Comments
 (0)