diff --git a/enginetest/queries/queries.go b/enginetest/queries/queries.go index 8afd96b32c..ce1853607a 100644 --- a/enginetest/queries/queries.go +++ b/enginetest/queries/queries.go @@ -9374,6 +9374,24 @@ from typestable`, Query: "select c1, c2 from one_pk where c2 = 1 group by c1", Expected: []sql.Row{{0, 1}}, }, + // https://github.com/dolthub/dolt/issues/9699 + // Correlated columns in subqueries are included in select dependencies + { + Query: "select any_value(pk), (select max(pk) from one_pk where pk < opk.pk) as x from one_pk opk", + Expected: []sql.Row{{0, nil}, {1, 0}, {2, 1}, {3, 2}}, + }, + { + Query: "SELECT any_value(pk), (SELECT max(pk) FROM one_pk WHERE pk < opk.pk) AS x FROM one_pk opk WHERE (SELECT max(pk) FROM one_pk WHERE pk < opk.pk) > 0;", + Expected: []sql.Row{{2, 1}, {3, 2}}, + }, + { + Query: "select any_value(pk), (select max(pk) from one_pk where pk < (opk.c1 - 10)) as x from one_pk opk", + Expected: []sql.Row{{0, nil}, {1, nil}, {2, 3}, {3, 3}}, + }, + { + Query: "select pk, (select max(pk) from one_pk where pk < opk.pk) as x from one_pk opk", + Expected: []sql.Row{{0, nil}, {1, 0}, {2, 1}, {3, 2}}, + }, } var KeylessQueries = []QueryTest{ diff --git a/sql/planbuilder/aggregates.go b/sql/planbuilder/aggregates.go index 940224b29f..cbb819f883 100644 --- a/sql/planbuilder/aggregates.go +++ b/sql/planbuilder/aggregates.go @@ -197,9 +197,7 @@ func (b *Builder) buildAggregation(fromScope, projScope *scope, groupingCols []s group := fromScope.groupBy outScope := group.outScope - // select columns: - // - aggs - // - extra columns needed by having, order by, select + // Select dependencies include aggregations and table columns needed for projections, having, and sort (order by) var selectDeps []sql.Expression var selectGfs []sql.Expression selectStr := make(map[string]bool) @@ -224,8 +222,8 @@ func (b *Builder) buildAggregation(fromScope, projScope *scope, groupingCols []s default: } - // projection dependencies -> table cols needed above - transform.InspectExpr(col.scalar, func(e sql.Expression) bool { + var findSelectDeps func(sql.Expression) bool + findSelectDeps = func(e sql.Expression) bool { switch e := e.(type) { case *expression.GetField: colName := strings.ToLower(e.String()) @@ -241,10 +239,18 @@ func (b *Builder) buildAggregation(fromScope, projScope *scope, groupingCols []s } else if isAliasDep && !inAlias { aliasDeps[exprStr] = false } + case *plan.Subquery: + e.Correlated().ForEach(func(colId sql.ColumnId) { + if correlated, found := projScope.parent.getCol(colId); found { + findSelectDeps(correlated.scalarGf()) + } + }) default: } return false - }) + } + + transform.InspectExpr(col.scalar, findSelectDeps) } for _, e := range fromScope.extraCols { // accessory cols used by ORDER_BY, HAVING diff --git a/sql/planbuilder/parse_test.go b/sql/planbuilder/parse_test.go index 299232be2c..ec4899cc96 100644 --- a/sql/planbuilder/parse_test.go +++ b/sql/planbuilder/parse_test.go @@ -1058,7 +1058,7 @@ Project │ └─ tableId: 2 │ ->(select u from uv where x = u)] └─ GroupBy - ├─ select: + ├─ select: xy.x:1!null ├─ group: Subquery │ ├─ cacheable: false │ ├─ alias-string: select u from uv where x = u @@ -2050,7 +2050,7 @@ Project │ └─ tableId: 0 │ ->a1:8] └─ Project - ├─ columns: [max(xy.x):4!null, Subquery + ├─ columns: [max(xy.x):4!null, xy.x:1!null, Subquery │ ├─ cacheable: false │ ├─ alias-string: select max(dt.a) from (select x as a) as dt (a) │ └─ Project @@ -2074,7 +2074,7 @@ Project │ └─ tableId: 0 │ ->a1:8] └─ GroupBy - ├─ select: MAX(xy.x:1!null) + ├─ select: MAX(xy.x:1!null), xy.x:1!null ├─ group: Subquery │ ├─ cacheable: false │ ├─ alias-string: select max(dt.a) from (select x as a) as dt (a) diff --git a/sql/planbuilder/scope.go b/sql/planbuilder/scope.go index f9ad2360fa..730960354c 100644 --- a/sql/planbuilder/scope.go +++ b/sql/planbuilder/scope.go @@ -149,6 +149,18 @@ func (s *scope) resolveColumn(db, table, col string, checkParent, chooseFirst bo return c, true } +// getCol gets a scopeColumn based on a columnId +func (s *scope) getCol(colId sql.ColumnId) (scopeColumn, bool) { + if s.colset.Contains(colId) { + for _, c := range s.cols { + if sql.ColumnId(c.id) == colId { + return c, true + } + } + } + return scopeColumn{}, false +} + func (s *scope) hasTable(table string) bool { _, ok := s.tables[strings.ToLower(table)] if ok {