Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions enginetest/queries/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -9374,6 +9374,24 @@ from typestable`,
Query: "select c1, c2 from one_pk where c2 = 1 group by c1",
Expected: []sql.Row{{0, 1}},
},
// https://github.com/dolthub/dolt/issues/9699
// Correlated columns in subqueries are included in select dependencies
{
Query: "select any_value(pk), (select max(pk) from one_pk where pk < opk.pk) as x from one_pk opk",
Expected: []sql.Row{{0, nil}, {1, 0}, {2, 1}, {3, 2}},
},
{
Query: "SELECT any_value(pk), (SELECT max(pk) FROM one_pk WHERE pk < opk.pk) AS x FROM one_pk opk WHERE (SELECT max(pk) FROM one_pk WHERE pk < opk.pk) > 0;",
Expected: []sql.Row{{2, 1}, {3, 2}},
},
{
Query: "select any_value(pk), (select max(pk) from one_pk where pk < (opk.c1 - 10)) as x from one_pk opk",
Expected: []sql.Row{{0, nil}, {1, nil}, {2, 3}, {3, 3}},
},
{
Query: "select pk, (select max(pk) from one_pk where pk < opk.pk) as x from one_pk opk",
Expected: []sql.Row{{0, nil}, {1, 0}, {2, 1}, {3, 2}},
},
}

var KeylessQueries = []QueryTest{
Expand Down
18 changes: 12 additions & 6 deletions sql/planbuilder/aggregates.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,7 @@ func (b *Builder) buildAggregation(fromScope, projScope *scope, groupingCols []s

group := fromScope.groupBy
outScope := group.outScope
// select columns:
// - aggs
// - extra columns needed by having, order by, select
// Select dependencies include aggregations and table columns needed for projections, having, and sort (order by)
var selectDeps []sql.Expression
var selectGfs []sql.Expression
selectStr := make(map[string]bool)
Expand All @@ -224,8 +222,8 @@ func (b *Builder) buildAggregation(fromScope, projScope *scope, groupingCols []s
default:
}

// projection dependencies -> table cols needed above
transform.InspectExpr(col.scalar, func(e sql.Expression) bool {
var findSelectDeps func(sql.Expression) bool
findSelectDeps = func(e sql.Expression) bool {
switch e := e.(type) {
case *expression.GetField:
colName := strings.ToLower(e.String())
Expand All @@ -241,10 +239,18 @@ func (b *Builder) buildAggregation(fromScope, projScope *scope, groupingCols []s
} else if isAliasDep && !inAlias {
aliasDeps[exprStr] = false
}
case *plan.Subquery:
e.Correlated().ForEach(func(colId sql.ColumnId) {
if correlated, found := projScope.parent.getCol(colId); found {
findSelectDeps(correlated.scalarGf())
}
})
default:
}
return false
})
}

transform.InspectExpr(col.scalar, findSelectDeps)
}
for _, e := range fromScope.extraCols {
// accessory cols used by ORDER_BY, HAVING
Expand Down
6 changes: 3 additions & 3 deletions sql/planbuilder/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1058,7 +1058,7 @@ Project
│ └─ tableId: 2
│ ->(select u from uv where x = u)]
└─ GroupBy
├─ select:
├─ select: xy.x:1!null
├─ group: Subquery
│ ├─ cacheable: false
│ ├─ alias-string: select u from uv where x = u
Expand Down Expand Up @@ -2050,7 +2050,7 @@ Project
│ └─ tableId: 0
│ ->a1:8]
└─ Project
├─ columns: [max(xy.x):4!null, Subquery
├─ columns: [max(xy.x):4!null, xy.x:1!null, Subquery
│ ├─ cacheable: false
│ ├─ alias-string: select max(dt.a) from (select x as a) as dt (a)
│ └─ Project
Expand All @@ -2074,7 +2074,7 @@ Project
│ └─ tableId: 0
│ ->a1:8]
└─ GroupBy
├─ select: MAX(xy.x:1!null)
├─ select: MAX(xy.x:1!null), xy.x:1!null
├─ group: Subquery
│ ├─ cacheable: false
│ ├─ alias-string: select max(dt.a) from (select x as a) as dt (a)
Expand Down
12 changes: 12 additions & 0 deletions sql/planbuilder/scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,18 @@ func (s *scope) resolveColumn(db, table, col string, checkParent, chooseFirst bo
return c, true
}

// getCol gets a scopeColumn based on a columnId
func (s *scope) getCol(colId sql.ColumnId) (scopeColumn, bool) {
if s.colset.Contains(colId) {
for _, c := range s.cols {
if sql.ColumnId(c.id) == colId {
return c, true
}
}
}
return scopeColumn{}, false
}

func (s *scope) hasTable(table string) bool {
_, ok := s.tables[strings.ToLower(table)]
if ok {
Expand Down
Loading