From f42fc3e75ed5738e34f9edd14706e973ff67014c Mon Sep 17 00:00:00 2001 From: Nick Tobey Date: Mon, 10 Nov 2025 15:51:20 -0800 Subject: [PATCH 1/6] Allow IntSequenceTable to take a non-constant argument. --- memory/lookup_squence_table.go | 17 +++++++++++++---- memory/sequence_table.go | 30 +++++++++++++++++++----------- 2 files changed, 32 insertions(+), 15 deletions(-) diff --git a/memory/lookup_squence_table.go b/memory/lookup_squence_table.go index ce482662c5..1542fd5cb3 100644 --- a/memory/lookup_squence_table.go +++ b/memory/lookup_squence_table.go @@ -18,6 +18,7 @@ var _ sql.TableNode = LookupSequenceTable{} // LookupSequenceTable is a variation of IntSequenceTable that supports lookups and implements sql.TableNode type LookupSequenceTable struct { IntSequenceTable + length int64 } func (s LookupSequenceTable) UnderlyingTable() sql.Table { @@ -29,7 +30,15 @@ func (s LookupSequenceTable) NewInstance(ctx *sql.Context, db sql.Database, args if err != nil { return nil, err } - return LookupSequenceTable{newIntSequenceTable.(IntSequenceTable)}, nil + lenExp, ok := args[1].(*expression.Literal) + if !ok { + return nil, fmt.Errorf("sequence table expects arguments to be literal expressions") + } + length, _, err := types.Int64.Convert(ctx, lenExp.Value()) + if err != nil { + return nil, err + } + return LookupSequenceTable{newIntSequenceTable.(IntSequenceTable), length.(int64)}, nil } func (s LookupSequenceTable) String() string { @@ -82,7 +91,7 @@ func (s LookupSequenceTable) Description() string { // Partitions is a sql.Table interface function that returns a partition of the data. This data has a single partition. func (s LookupSequenceTable) Partitions(ctx *sql.Context) (sql.PartitionIter, error) { - return sql.PartitionsToPartitionIter(&sequencePartition{min: 0, max: int64(s.Len) - 1}), nil + return sql.PartitionsToPartitionIter(&sequencePartition{min: 0, max: s.length - 1}), nil } // PartitionRows is a sql.Table interface function that takes a partition and returns all rows in that partition. @@ -90,13 +99,13 @@ func (s LookupSequenceTable) Partitions(ctx *sql.Context) (sql.PartitionIter, er func (s LookupSequenceTable) PartitionRows(ctx *sql.Context, partition sql.Partition) (sql.RowIter, error) { sp, ok := partition.(*sequencePartition) if !ok { - return &SequenceTableFnRowIter{i: 0, n: s.Len}, nil + return &SequenceTableFnRowIter{i: 0, n: s.length}, nil } min := int64(0) if sp.min > min { min = sp.min } - max := int64(s.Len) - 1 + max := int64(s.length) - 1 if sp.max < max { max = sp.max } diff --git a/memory/sequence_table.go b/memory/sequence_table.go index c7ce06ab27..878e2f0d94 100644 --- a/memory/sequence_table.go +++ b/memory/sequence_table.go @@ -19,7 +19,7 @@ var _ sql.ExecSourceRel = IntSequenceTable{} type IntSequenceTable struct { db sql.Database name string - Len int64 + Len sql.Expression } func (s IntSequenceTable) NewInstance(ctx *sql.Context, db sql.Database, args []sql.Expression) (sql.Node, error) { @@ -34,15 +34,11 @@ func (s IntSequenceTable) NewInstance(ctx *sql.Context, db sql.Database, args [] if !ok { return nil, fmt.Errorf("sequence table expects 1st argument to be column name") } - lenExp, ok := args[1].(*expression.Literal) - if !ok { - return nil, fmt.Errorf("sequence table expects arguments to be literal expressions") + lenExp := args[1] + if !sql.IsNumberType(lenExp.Type()) { + return nil, fmt.Errorf("sequence table expects length argument to be a number") } - length, _, err := types.Int64.Convert(ctx, lenExp.Value()) - if !ok { - return nil, fmt.Errorf("%w; sequence table expects 2nd argument to be a sequence length integer", err) - } - return IntSequenceTable{db: db, name: name, Len: length.(int64)}, nil + return IntSequenceTable{db: db, name: name, Len: lenExp}, nil } func (s IntSequenceTable) Resolved() bool { @@ -85,8 +81,20 @@ func (s IntSequenceTable) Children() []sql.Node { return []sql.Node{} } -func (s IntSequenceTable) RowIter(_ *sql.Context, _ sql.Row) (sql.RowIter, error) { - rowIter := &SequenceTableFnRowIter{i: 0, n: s.Len} +func (s IntSequenceTable) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) { + iterLen, err := s.Len.Eval(ctx, row) + if err != nil { + return nil, err + } + iterLenVal, ok, err := types.Int64.Convert(ctx, iterLen) + if err != nil { + return nil, err + } + if !ok { + return nil, fmt.Errorf("sequence table expects integer argument") + } + + rowIter := &SequenceTableFnRowIter{i: 0, n: iterLenVal.(int64)} return rowIter, nil } From acf761aab08ca9fb2879f859a0e5df9fa904282d Mon Sep 17 00:00:00 2001 From: Nick Tobey Date: Mon, 10 Nov 2025 17:47:21 -0800 Subject: [PATCH 2/6] Add support for modifying expressions in IntSequenceTable --- memory/sequence_table.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/memory/sequence_table.go b/memory/sequence_table.go index 878e2f0d94..6dd0ee60ba 100644 --- a/memory/sequence_table.go +++ b/memory/sequence_table.go @@ -113,11 +113,16 @@ func (IntSequenceTable) Collation() sql.CollationID { } func (s IntSequenceTable) Expressions() []sql.Expression { - return []sql.Expression{} + return []sql.Expression{s.Len} } func (s IntSequenceTable) WithExpressions(e ...sql.Expression) (sql.Node, error) { - return s, nil + if len(e) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(e), 1) + } + newSequenceTable := s + newSequenceTable.Len = e[0] + return newSequenceTable, nil } func (s IntSequenceTable) Database() sql.Database { From 00ca308cbbcb3886688d18598091b0878636c67c Mon Sep 17 00:00:00 2001 From: Nick Tobey Date: Tue, 11 Nov 2025 12:03:58 -0800 Subject: [PATCH 3/6] Add additional tests for table functions in subqueries --- enginetest/queries/table_func_scripts.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/enginetest/queries/table_func_scripts.go b/enginetest/queries/table_func_scripts.go index 82793eff31..14294d7fb7 100644 --- a/enginetest/queries/table_func_scripts.go +++ b/enginetest/queries/table_func_scripts.go @@ -150,6 +150,14 @@ var TableFunctionScriptTests = []ScriptTest{ Query: "select x from sequence_table('x', 5) where exists (select y from sequence_table('y', 3) where x = y)", Expected: []sql.Row{{0}, {1}, {2}}, }, + { + Query: "select * from sequence_table('x', 3) l join lateral (select * from sequence_table('y', l.x)) r", + Expected: []sql.Row{{1, 0}, {2, 0}, {2, 1}}, + }, + { + Query: "select * from sequence_table('x', 3) l where exists (select * from sequence_table('y', l.x))", + Expected: []sql.Row{{1}, {2}}, + }, { Query: "select not_seq.x from sequence_table('x', 5) as seq", ExpectedErr: sql.ErrTableNotFound, From 8a5f5821120ca8b5306628cd18ad493f87a28413 Mon Sep 17 00:00:00 2001 From: Nick Tobey Date: Thu, 6 Nov 2025 16:23:09 -0800 Subject: [PATCH 4/6] Ensure that references to an outer subquery in the top scope of a subquery are correctly handled. --- enginetest/queries/queries.go | 6 ++++++ sql/planbuilder/scope.go | 4 ++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/enginetest/queries/queries.go b/enginetest/queries/queries.go index 7a989e04c0..36cf975d66 100644 --- a/enginetest/queries/queries.go +++ b/enginetest/queries/queries.go @@ -8178,6 +8178,12 @@ ORDER BY 1;`, {3}, }, }, + { + Query: "SELECT * FROM xy JOIN LATERAL (SELECT * FROM uv WHERE xy.x+1 = uv.u) uv2", + Expected: []sql.Row{ + {0, 2, 1, 1}, {1, 0, 2, 2}, {2, 1, 3, 2}, + }, + }, { Query: ` select * from mytable, diff --git a/sql/planbuilder/scope.go b/sql/planbuilder/scope.go index 80ee268303..75efead05d 100644 --- a/sql/planbuilder/scope.go +++ b/sql/planbuilder/scope.go @@ -147,8 +147,8 @@ func (s *scope) resolveColumn(db, table, col string, checkParent, chooseFirst bo return scopeColumn{}, false } - if s.parent.activeSubquery != nil { - s.parent.activeSubquery.addOutOfScope(c.id) + if s.activeSubquery != nil { + s.activeSubquery.addOutOfScope(c.id) } return c, true } From 0147ad3c3e522fa36552944d15adb1425cb22c1c Mon Sep 17 00:00:00 2001 From: Nick Tobey Date: Thu, 13 Nov 2025 16:11:21 -0800 Subject: [PATCH 5/6] Preserve lateralness when replanning join. --- sql/analyzer/unnest_exists_subqueries.go | 6 +++++- sql/memo/exec_builder.go | 2 +- sql/plan/join.go | 4 ++++ 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/sql/analyzer/unnest_exists_subqueries.go b/sql/analyzer/unnest_exists_subqueries.go index 166de18260..45ae9c6007 100644 --- a/sql/analyzer/unnest_exists_subqueries.go +++ b/sql/analyzer/unnest_exists_subqueries.go @@ -193,7 +193,11 @@ func unnestExistSubqueries(ctx *sql.Context, scope *plan.Scope, a *Analyzer, fil ret = plan.NewAntiJoinIncludingNulls(ret, s.inner, cond).WithComment(comment) qFlags.Set(sql.QFlagInnerJoin) case plan.JoinTypeSemi: - ret = plan.NewCrossJoin(ret, s.inner).WithComment(comment) + if sq.Correlated().Empty() { + ret = plan.NewCrossJoin(ret, s.inner).WithComment(comment) + } else { + ret = plan.NewLateralCrossJoin(ret, s.inner).WithComment(comment) + } qFlags.Set(sql.QFlagCrossJoin) default: return filter, transform.SameTree, fmt.Errorf("hoistSelectExists failed on unexpected join type") diff --git a/sql/memo/exec_builder.go b/sql/memo/exec_builder.go index 5007c66416..67addda622 100644 --- a/sql/memo/exec_builder.go +++ b/sql/memo/exec_builder.go @@ -299,7 +299,7 @@ func (b *ExecBuilder) buildMergeJoin(j *MergeJoin, children ...sql.Node) (sql.No func (b *ExecBuilder) buildLateralJoin(j *LateralJoin, children ...sql.Node) (sql.Node, error) { if len(j.Filter) == 0 { - return plan.NewCrossJoin(children[0], children[1]), nil + return plan.NewLateralCrossJoin(children[0], children[1]), nil } filters := b.buildFilterConjunction(j.Filter...) return plan.NewJoin(children[0], children[1], j.Op.AsLateral(), filters), nil diff --git a/sql/plan/join.go b/sql/plan/join.go index 600c752dff..e579fdd48d 100644 --- a/sql/plan/join.go +++ b/sql/plan/join.go @@ -532,6 +532,10 @@ func NewCrossJoin(left, right sql.Node) *JoinNode { return NewJoin(left, right, JoinTypeCross, nil) } +func NewLateralCrossJoin(left, right sql.Node) *JoinNode { + return NewJoin(left, right, JoinTypeLateralCross, nil) +} + // NaturalJoin is a join that automatically joins by all the columns with the // same name. // NaturalJoin is a placeholder node, it should be transformed into an INNER From 0e2b6da9d20228a9f3255b2b532f97eae9f34b80 Mon Sep 17 00:00:00 2001 From: Nick Tobey Date: Thu, 13 Nov 2025 17:26:04 -0800 Subject: [PATCH 6/6] In LateralJoinIter, provide parent rows to the secondary child rowIter. --- sql/rowexec/join_iters.go | 96 ++++++++++++++++++++------------------- 1 file changed, 49 insertions(+), 47 deletions(-) diff --git a/sql/rowexec/join_iters.go b/sql/rowexec/join_iters.go index 0e24a16dfe..1d46c86d8d 100644 --- a/sql/rowexec/join_iters.go +++ b/sql/rowexec/join_iters.go @@ -734,16 +734,19 @@ type lateralJoinIterator struct { secondaryNode sql.Node cond sql.Expression b sql.NodeExecBuilder - parentRow sql.Row - primaryRow sql.Row - secondaryRow sql.Row - rowSize int - scopeLen int - jType plan.JoinType - foundMatch bool + // primaryRow contains the parent row concatenated with the current row from the primary child, + // and is used to build the secondary child iter. + primaryRow sql.Row + // secondaryRow contains the current row from the secondary child. + secondaryRow sql.Row + rowSize int + scopeLen int + parentLen int + jType plan.JoinType + foundMatch bool } -func newLateralJoinIter(ctx *sql.Context, b sql.NodeExecBuilder, j *plan.JoinNode, row sql.Row) (sql.RowIter, error) { +func newLateralJoinIter(ctx *sql.Context, b sql.NodeExecBuilder, j *plan.JoinNode, parentRow sql.Row) (sql.RowIter, error) { var left, right string if leftTable, ok := j.Left().(sql.Nameable); ok { left = leftTable.Name() @@ -761,73 +764,72 @@ func newLateralJoinIter(ctx *sql.Context, b sql.NodeExecBuilder, j *plan.JoinNod attribute.String("right", right), )) - l, err := b.Build(ctx, j.Left(), row) + l, err := b.Build(ctx, j.Left(), parentRow) if err != nil { span.End() return nil, err } + parentLen := len(parentRow) + + primaryRow := make(sql.Row, parentLen+len(j.Left().Schema())) + copy(primaryRow, parentRow) + return sql.NewSpanIter(span, &lateralJoinIterator{ - parentRow: row, + primaryRow: primaryRow, + parentLen: len(parentRow), primary: l, secondaryNode: j.Right(), cond: j.Filter, jType: j.Op, - rowSize: len(row) + len(j.Left().Schema()) + len(j.Right().Schema()), + rowSize: len(parentRow) + len(j.Left().Schema()) + len(j.Right().Schema()), scopeLen: j.ScopeLen, b: b, }), nil } func (i *lateralJoinIterator) loadPrimary(ctx *sql.Context) error { - if i.primaryRow == nil { - lRow, err := i.primary.Next(ctx) - if err != nil { - return err - } - i.primaryRow = lRow - i.foundMatch = false + lRow, err := i.primary.Next(ctx) + if err != nil { + return err } + copy(i.primaryRow[i.parentLen:], lRow) + i.foundMatch = false return nil } func (i *lateralJoinIterator) buildSecondary(ctx *sql.Context) error { - if i.secondary == nil { - prepended, _, err := transform.Node(i.secondaryNode, plan.PrependRowInPlan(i.primaryRow, true)) - if err != nil { - return err - } - iter, err := i.b.Build(ctx, prepended, i.primaryRow) - if err != nil { - return err - } - i.secondary = iter + prepended, _, err := transform.Node(i.secondaryNode, plan.PrependRowInPlan(i.primaryRow, true)) + if err != nil { + return err + } + iter, err := i.b.Build(ctx, prepended, i.primaryRow) + if err != nil { + return err } + i.secondary = iter return nil } func (i *lateralJoinIterator) loadSecondary(ctx *sql.Context) error { - if i.secondaryRow == nil { - sRow, err := i.secondary.Next(ctx) - if err != nil { - return err - } - i.secondaryRow = sRow[len(i.primaryRow):] + sRow, err := i.secondary.Next(ctx) + if err != nil { + return err } + i.secondaryRow = sRow[len(i.primaryRow):] return nil } func (i *lateralJoinIterator) buildRow(primaryRow, secondaryRow sql.Row) sql.Row { row := make(sql.Row, i.rowSize) - copy(row, i.parentRow) - copy(row[len(i.parentRow):], primaryRow) - copy(row[len(i.parentRow)+len(primaryRow):], secondaryRow) + copy(row, primaryRow) + copy(row[len(primaryRow):], secondaryRow) return row } func (i *lateralJoinIterator) removeParentRow(r sql.Row) sql.Row { - copy(r[i.scopeLen:], r[len(i.parentRow):]) - r = r[:len(r)-len(i.parentRow)+i.scopeLen] + copy(r[i.scopeLen:], r[i.parentLen:]) + r = r[:len(r)-i.parentLen+i.scopeLen] return r } @@ -836,18 +838,20 @@ func (i *lateralJoinIterator) reset(ctx *sql.Context) (err error) { err = i.secondary.Close(ctx) i.secondary = nil } - i.primaryRow = nil i.secondaryRow = nil return } func (i *lateralJoinIterator) Next(ctx *sql.Context) (sql.Row, error) { for { - if err := i.loadPrimary(ctx); err != nil { - return nil, err - } - if err := i.buildSecondary(ctx); err != nil { - return nil, err + // secondary being nil means we've exhausted all secondary rows for the current primary. + if i.secondary == nil { + if err := i.loadPrimary(ctx); err != nil { + return nil, err + } + if err := i.buildSecondary(ctx); err != nil { + return nil, err + } } if err := i.loadSecondary(ctx); err != nil { if errors.Is(err, io.EOF) { @@ -865,9 +869,7 @@ func (i *lateralJoinIterator) Next(ctx *sql.Context) (sql.Row, error) { } return nil, err } - row := i.buildRow(i.primaryRow, i.secondaryRow) - i.secondaryRow = nil if i.cond != nil { if res, err := sql.EvaluateCondition(ctx, i.cond, row); err != nil { return nil, err