Skip to content

Commit b3e1e88

Browse files
authored
Merge pull request #3304 from dolthub/nicktobey/lateral_fix
Support table functions with non-literal arguments in subqueries and lateral joins
2 parents 7b8fb20 + 0e2b6da commit b3e1e88

File tree

9 files changed

+114
-68
lines changed

9 files changed

+114
-68
lines changed

enginetest/queries/queries.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8178,6 +8178,12 @@ ORDER BY 1;`,
81788178
{3},
81798179
},
81808180
},
8181+
{
8182+
Query: "SELECT * FROM xy JOIN LATERAL (SELECT * FROM uv WHERE xy.x+1 = uv.u) uv2",
8183+
Expected: []sql.Row{
8184+
{0, 2, 1, 1}, {1, 0, 2, 2}, {2, 1, 3, 2},
8185+
},
8186+
},
81818187
{
81828188
Query: `
81838189
select * from mytable,

enginetest/queries/table_func_scripts.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,14 @@ var TableFunctionScriptTests = []ScriptTest{
150150
Query: "select x from sequence_table('x', 5) where exists (select y from sequence_table('y', 3) where x = y)",
151151
Expected: []sql.Row{{0}, {1}, {2}},
152152
},
153+
{
154+
Query: "select * from sequence_table('x', 3) l join lateral (select * from sequence_table('y', l.x)) r",
155+
Expected: []sql.Row{{1, 0}, {2, 0}, {2, 1}},
156+
},
157+
{
158+
Query: "select * from sequence_table('x', 3) l where exists (select * from sequence_table('y', l.x))",
159+
Expected: []sql.Row{{1}, {2}},
160+
},
153161
{
154162
Query: "select not_seq.x from sequence_table('x', 5) as seq",
155163
ExpectedErr: sql.ErrTableNotFound,

memory/lookup_squence_table.go

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ var _ sql.TableNode = LookupSequenceTable{}
1818
// LookupSequenceTable is a variation of IntSequenceTable that supports lookups and implements sql.TableNode
1919
type LookupSequenceTable struct {
2020
IntSequenceTable
21+
length int64
2122
}
2223

2324
func (s LookupSequenceTable) UnderlyingTable() sql.Table {
@@ -29,7 +30,15 @@ func (s LookupSequenceTable) NewInstance(ctx *sql.Context, db sql.Database, args
2930
if err != nil {
3031
return nil, err
3132
}
32-
return LookupSequenceTable{newIntSequenceTable.(IntSequenceTable)}, nil
33+
lenExp, ok := args[1].(*expression.Literal)
34+
if !ok {
35+
return nil, fmt.Errorf("sequence table expects arguments to be literal expressions")
36+
}
37+
length, _, err := types.Int64.Convert(ctx, lenExp.Value())
38+
if err != nil {
39+
return nil, err
40+
}
41+
return LookupSequenceTable{newIntSequenceTable.(IntSequenceTable), length.(int64)}, nil
3342
}
3443

3544
func (s LookupSequenceTable) String() string {
@@ -82,21 +91,21 @@ func (s LookupSequenceTable) Description() string {
8291

8392
// Partitions is a sql.Table interface function that returns a partition of the data. This data has a single partition.
8493
func (s LookupSequenceTable) Partitions(ctx *sql.Context) (sql.PartitionIter, error) {
85-
return sql.PartitionsToPartitionIter(&sequencePartition{min: 0, max: int64(s.Len) - 1}), nil
94+
return sql.PartitionsToPartitionIter(&sequencePartition{min: 0, max: s.length - 1}), nil
8695
}
8796

8897
// PartitionRows is a sql.Table interface function that takes a partition and returns all rows in that partition.
8998
// This table has a partition for just schema changes, one for just data changes, and one for both.
9099
func (s LookupSequenceTable) PartitionRows(ctx *sql.Context, partition sql.Partition) (sql.RowIter, error) {
91100
sp, ok := partition.(*sequencePartition)
92101
if !ok {
93-
return &SequenceTableFnRowIter{i: 0, n: s.Len}, nil
102+
return &SequenceTableFnRowIter{i: 0, n: s.length}, nil
94103
}
95104
min := int64(0)
96105
if sp.min > min {
97106
min = sp.min
98107
}
99-
max := int64(s.Len) - 1
108+
max := int64(s.length) - 1
100109
if sp.max < max {
101110
max = sp.max
102111
}

memory/sequence_table.go

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ var _ sql.ExecSourceRel = IntSequenceTable{}
1919
type IntSequenceTable struct {
2020
db sql.Database
2121
name string
22-
Len int64
22+
Len sql.Expression
2323
}
2424

2525
func (s IntSequenceTable) NewInstance(ctx *sql.Context, db sql.Database, args []sql.Expression) (sql.Node, error) {
@@ -34,15 +34,11 @@ func (s IntSequenceTable) NewInstance(ctx *sql.Context, db sql.Database, args []
3434
if !ok {
3535
return nil, fmt.Errorf("sequence table expects 1st argument to be column name")
3636
}
37-
lenExp, ok := args[1].(*expression.Literal)
38-
if !ok {
39-
return nil, fmt.Errorf("sequence table expects arguments to be literal expressions")
40-
}
41-
length, _, err := types.Int64.Convert(ctx, lenExp.Value())
42-
if !ok {
43-
return nil, fmt.Errorf("%w; sequence table expects 2nd argument to be a sequence length integer", err)
37+
lenExp := args[1]
38+
if !sql.IsNumberType(lenExp.Type()) {
39+
return nil, fmt.Errorf("sequence table expects length argument to be a number")
4440
}
45-
return IntSequenceTable{db: db, name: name, Len: length.(int64)}, nil
41+
return IntSequenceTable{db: db, name: name, Len: lenExp}, nil
4642
}
4743

4844
func (s IntSequenceTable) Resolved() bool {
@@ -85,8 +81,20 @@ func (s IntSequenceTable) Children() []sql.Node {
8581
return []sql.Node{}
8682
}
8783

88-
func (s IntSequenceTable) RowIter(_ *sql.Context, _ sql.Row) (sql.RowIter, error) {
89-
rowIter := &SequenceTableFnRowIter{i: 0, n: s.Len}
84+
func (s IntSequenceTable) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) {
85+
iterLen, err := s.Len.Eval(ctx, row)
86+
if err != nil {
87+
return nil, err
88+
}
89+
iterLenVal, ok, err := types.Int64.Convert(ctx, iterLen)
90+
if err != nil {
91+
return nil, err
92+
}
93+
if !ok {
94+
return nil, fmt.Errorf("sequence table expects integer argument")
95+
}
96+
97+
rowIter := &SequenceTableFnRowIter{i: 0, n: iterLenVal.(int64)}
9098
return rowIter, nil
9199
}
92100

@@ -105,11 +113,16 @@ func (IntSequenceTable) Collation() sql.CollationID {
105113
}
106114

107115
func (s IntSequenceTable) Expressions() []sql.Expression {
108-
return []sql.Expression{}
116+
return []sql.Expression{s.Len}
109117
}
110118

111119
func (s IntSequenceTable) WithExpressions(e ...sql.Expression) (sql.Node, error) {
112-
return s, nil
120+
if len(e) != 1 {
121+
return nil, sql.ErrInvalidChildrenNumber.New(s, len(e), 1)
122+
}
123+
newSequenceTable := s
124+
newSequenceTable.Len = e[0]
125+
return newSequenceTable, nil
113126
}
114127

115128
func (s IntSequenceTable) Database() sql.Database {

sql/analyzer/unnest_exists_subqueries.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,11 @@ func unnestExistSubqueries(ctx *sql.Context, scope *plan.Scope, a *Analyzer, fil
193193
ret = plan.NewAntiJoinIncludingNulls(ret, s.inner, cond).WithComment(comment)
194194
qFlags.Set(sql.QFlagInnerJoin)
195195
case plan.JoinTypeSemi:
196-
ret = plan.NewCrossJoin(ret, s.inner).WithComment(comment)
196+
if sq.Correlated().Empty() {
197+
ret = plan.NewCrossJoin(ret, s.inner).WithComment(comment)
198+
} else {
199+
ret = plan.NewLateralCrossJoin(ret, s.inner).WithComment(comment)
200+
}
197201
qFlags.Set(sql.QFlagCrossJoin)
198202
default:
199203
return filter, transform.SameTree, fmt.Errorf("hoistSelectExists failed on unexpected join type")

sql/memo/exec_builder.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ func (b *ExecBuilder) buildMergeJoin(j *MergeJoin, children ...sql.Node) (sql.No
299299

300300
func (b *ExecBuilder) buildLateralJoin(j *LateralJoin, children ...sql.Node) (sql.Node, error) {
301301
if len(j.Filter) == 0 {
302-
return plan.NewCrossJoin(children[0], children[1]), nil
302+
return plan.NewLateralCrossJoin(children[0], children[1]), nil
303303
}
304304
filters := b.buildFilterConjunction(j.Filter...)
305305
return plan.NewJoin(children[0], children[1], j.Op.AsLateral(), filters), nil

sql/plan/join.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,10 @@ func NewCrossJoin(left, right sql.Node) *JoinNode {
532532
return NewJoin(left, right, JoinTypeCross, nil)
533533
}
534534

535+
func NewLateralCrossJoin(left, right sql.Node) *JoinNode {
536+
return NewJoin(left, right, JoinTypeLateralCross, nil)
537+
}
538+
535539
// NaturalJoin is a join that automatically joins by all the columns with the
536540
// same name.
537541
// NaturalJoin is a placeholder node, it should be transformed into an INNER

sql/planbuilder/scope.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,8 @@ func (s *scope) resolveColumn(db, table, col string, checkParent, chooseFirst bo
147147
return scopeColumn{}, false
148148
}
149149

150-
if s.parent.activeSubquery != nil {
151-
s.parent.activeSubquery.addOutOfScope(c.id)
150+
if s.activeSubquery != nil {
151+
s.activeSubquery.addOutOfScope(c.id)
152152
}
153153
return c, true
154154
}

sql/rowexec/join_iters.go

Lines changed: 49 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -734,16 +734,19 @@ type lateralJoinIterator struct {
734734
secondaryNode sql.Node
735735
cond sql.Expression
736736
b sql.NodeExecBuilder
737-
parentRow sql.Row
738-
primaryRow sql.Row
739-
secondaryRow sql.Row
740-
rowSize int
741-
scopeLen int
742-
jType plan.JoinType
743-
foundMatch bool
737+
// primaryRow contains the parent row concatenated with the current row from the primary child,
738+
// and is used to build the secondary child iter.
739+
primaryRow sql.Row
740+
// secondaryRow contains the current row from the secondary child.
741+
secondaryRow sql.Row
742+
rowSize int
743+
scopeLen int
744+
parentLen int
745+
jType plan.JoinType
746+
foundMatch bool
744747
}
745748

746-
func newLateralJoinIter(ctx *sql.Context, b sql.NodeExecBuilder, j *plan.JoinNode, row sql.Row) (sql.RowIter, error) {
749+
func newLateralJoinIter(ctx *sql.Context, b sql.NodeExecBuilder, j *plan.JoinNode, parentRow sql.Row) (sql.RowIter, error) {
747750
var left, right string
748751
if leftTable, ok := j.Left().(sql.Nameable); ok {
749752
left = leftTable.Name()
@@ -761,73 +764,72 @@ func newLateralJoinIter(ctx *sql.Context, b sql.NodeExecBuilder, j *plan.JoinNod
761764
attribute.String("right", right),
762765
))
763766

764-
l, err := b.Build(ctx, j.Left(), row)
767+
l, err := b.Build(ctx, j.Left(), parentRow)
765768
if err != nil {
766769
span.End()
767770
return nil, err
768771
}
769772

773+
parentLen := len(parentRow)
774+
775+
primaryRow := make(sql.Row, parentLen+len(j.Left().Schema()))
776+
copy(primaryRow, parentRow)
777+
770778
return sql.NewSpanIter(span, &lateralJoinIterator{
771-
parentRow: row,
779+
primaryRow: primaryRow,
780+
parentLen: len(parentRow),
772781
primary: l,
773782
secondaryNode: j.Right(),
774783
cond: j.Filter,
775784
jType: j.Op,
776-
rowSize: len(row) + len(j.Left().Schema()) + len(j.Right().Schema()),
785+
rowSize: len(parentRow) + len(j.Left().Schema()) + len(j.Right().Schema()),
777786
scopeLen: j.ScopeLen,
778787
b: b,
779788
}), nil
780789
}
781790

782791
func (i *lateralJoinIterator) loadPrimary(ctx *sql.Context) error {
783-
if i.primaryRow == nil {
784-
lRow, err := i.primary.Next(ctx)
785-
if err != nil {
786-
return err
787-
}
788-
i.primaryRow = lRow
789-
i.foundMatch = false
792+
lRow, err := i.primary.Next(ctx)
793+
if err != nil {
794+
return err
790795
}
796+
copy(i.primaryRow[i.parentLen:], lRow)
797+
i.foundMatch = false
791798
return nil
792799
}
793800

794801
func (i *lateralJoinIterator) buildSecondary(ctx *sql.Context) error {
795-
if i.secondary == nil {
796-
prepended, _, err := transform.Node(i.secondaryNode, plan.PrependRowInPlan(i.primaryRow, true))
797-
if err != nil {
798-
return err
799-
}
800-
iter, err := i.b.Build(ctx, prepended, i.primaryRow)
801-
if err != nil {
802-
return err
803-
}
804-
i.secondary = iter
802+
prepended, _, err := transform.Node(i.secondaryNode, plan.PrependRowInPlan(i.primaryRow, true))
803+
if err != nil {
804+
return err
805+
}
806+
iter, err := i.b.Build(ctx, prepended, i.primaryRow)
807+
if err != nil {
808+
return err
805809
}
810+
i.secondary = iter
806811
return nil
807812
}
808813

809814
func (i *lateralJoinIterator) loadSecondary(ctx *sql.Context) error {
810-
if i.secondaryRow == nil {
811-
sRow, err := i.secondary.Next(ctx)
812-
if err != nil {
813-
return err
814-
}
815-
i.secondaryRow = sRow[len(i.primaryRow):]
815+
sRow, err := i.secondary.Next(ctx)
816+
if err != nil {
817+
return err
816818
}
819+
i.secondaryRow = sRow[len(i.primaryRow):]
817820
return nil
818821
}
819822

820823
func (i *lateralJoinIterator) buildRow(primaryRow, secondaryRow sql.Row) sql.Row {
821824
row := make(sql.Row, i.rowSize)
822-
copy(row, i.parentRow)
823-
copy(row[len(i.parentRow):], primaryRow)
824-
copy(row[len(i.parentRow)+len(primaryRow):], secondaryRow)
825+
copy(row, primaryRow)
826+
copy(row[len(primaryRow):], secondaryRow)
825827
return row
826828
}
827829

828830
func (i *lateralJoinIterator) removeParentRow(r sql.Row) sql.Row {
829-
copy(r[i.scopeLen:], r[len(i.parentRow):])
830-
r = r[:len(r)-len(i.parentRow)+i.scopeLen]
831+
copy(r[i.scopeLen:], r[i.parentLen:])
832+
r = r[:len(r)-i.parentLen+i.scopeLen]
831833
return r
832834
}
833835

@@ -836,18 +838,20 @@ func (i *lateralJoinIterator) reset(ctx *sql.Context) (err error) {
836838
err = i.secondary.Close(ctx)
837839
i.secondary = nil
838840
}
839-
i.primaryRow = nil
840841
i.secondaryRow = nil
841842
return
842843
}
843844

844845
func (i *lateralJoinIterator) Next(ctx *sql.Context) (sql.Row, error) {
845846
for {
846-
if err := i.loadPrimary(ctx); err != nil {
847-
return nil, err
848-
}
849-
if err := i.buildSecondary(ctx); err != nil {
850-
return nil, err
847+
// secondary being nil means we've exhausted all secondary rows for the current primary.
848+
if i.secondary == nil {
849+
if err := i.loadPrimary(ctx); err != nil {
850+
return nil, err
851+
}
852+
if err := i.buildSecondary(ctx); err != nil {
853+
return nil, err
854+
}
851855
}
852856
if err := i.loadSecondary(ctx); err != nil {
853857
if errors.Is(err, io.EOF) {
@@ -865,9 +869,7 @@ func (i *lateralJoinIterator) Next(ctx *sql.Context) (sql.Row, error) {
865869
}
866870
return nil, err
867871
}
868-
869872
row := i.buildRow(i.primaryRow, i.secondaryRow)
870-
i.secondaryRow = nil
871873
if i.cond != nil {
872874
if res, err := sql.EvaluateCondition(ctx, i.cond, row); err != nil {
873875
return nil, err

0 commit comments

Comments
 (0)