Skip to content

Commit 0787112

Browse files
committed
add join filters to group by expressions
1 parent 7fbc47e commit 0787112

File tree

3 files changed

+38
-17
lines changed

3 files changed

+38
-17
lines changed

enginetest/join_planning_tests.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,7 @@ order by 1;`,
512512
},
513513
{
514514
// group by doesn't transform
515-
q: "select * from xy where y-1 in (select u from uv group by v having v = 2 order by 1) order by 1;",
515+
q: "select * from xy where y-1 in (select u from uv group by u having u = 2 order by 1) order by 1;",
516516
types: []plan.JoinType{plan.JoinTypeSemi},
517517
exp: []sql.Row{{3, 3}},
518518
},

sql/analyzer/validation_rules.go

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,9 @@ func validateGroupBy(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop
251251
var err error
252252
var parent sql.Node
253253
checkParent := false
254-
var projectParent *plan.Project
254+
var project *plan.Project
255+
var having *plan.Having
256+
var filter *plan.Filter
255257
transform.Inspect(n, func(n sql.Node) bool {
256258
defer func() {
257259
parent = n
@@ -282,28 +284,45 @@ func validateGroupBy(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop
282284
}
283285
}
284286

285-
// TODO set groupBys to equalsExpr
286287
groupBys := make(map[string]bool)
287288
groupByPrimaryKeys := 0
288-
for _, expr := range n.GroupByExprs {
289+
isJoin := false
290+
exprs := make([]sql.Expression, 0)
291+
exprs = append(exprs, n.GroupByExprs...)
292+
if having != nil {
293+
if eq, ok := having.Cond.(*expression.Equals); ok {
294+
exprs = append(exprs, eq.Children()...)
295+
}
296+
}
297+
if filter != nil {
298+
}
299+
if join, ok := n.Child.(*plan.JoinNode); ok {
300+
isJoin = true
301+
if eq, ok := join.Filter.(*expression.Equals); ok {
302+
exprs = append(exprs, eq.Children()...)
303+
}
304+
}
305+
for _, expr := range exprs {
289306
sql.Inspect(expr, func(expr sql.Expression) bool {
290307
exprStr := strings.ToLower(expr.String())
291-
groupBys[exprStr] = true
292-
if primaryKeys[exprStr] {
308+
if primaryKeys[exprStr] && !groupBys[exprStr] {
293309
groupByPrimaryKeys++
294310
}
311+
groupBys[exprStr] = true
295312

296313
_, isAlias := expr.(*expression.Alias)
297314
return isAlias
298315
})
299316
}
300317

301318
// TODO: also allow grouping by unique non-nullable columns
302-
if len(primaryKeys) != 0 && groupByPrimaryKeys == len(primaryKeys) {
319+
// TODO: There's currently no way to tell whether or not a primary key column is part of a multi-column
320+
// primary key.
321+
if len(primaryKeys) != 0 && (groupByPrimaryKeys == len(primaryKeys) || (isJoin && groupByPrimaryKeys > 0)) {
303322
return true
304323
}
305324

306-
selectExprs := getSelectExprs(projectParent, n.SelectDeps, groupBys)
325+
selectExprs := getSelectExprs(project, n.SelectDeps, groupBys)
307326

308327
for _, expr := range selectExprs {
309328
if !expressionReferencesOnlyGroupBys(groupBys, expr) {
@@ -314,10 +333,11 @@ func validateGroupBy(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop
314333
}
315334
}
316335
case *plan.Project:
317-
projectParent = n
318-
case *plan.Filter, *plan.Having:
319-
// TODO inspect for equals and add GetField (if direct child) to equalsExprs
320-
// make sure filter hasn't been pushed down yet
336+
project = n
337+
case *plan.Filter:
338+
filter = n
339+
case *plan.Having:
340+
having = n
321341
}
322342
return true
323343
})

sql/plan/subquery.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -502,11 +502,12 @@ func (s *Subquery) IsNullable() bool {
502502
}
503503

504504
func (s *Subquery) String() string {
505-
pr := sql.NewTreePrinter()
506-
_ = pr.WriteNode("Subquery")
507-
children := []string{fmt.Sprintf("cacheable: %t", s.canCacheResults()), s.Query.String()}
508-
_ = pr.WriteChildren(children...)
509-
return pr.String()
505+
//pr := sql.NewTreePrinter()
506+
//_ = pr.WriteNode("Subquery")
507+
//children := []string{fmt.Sprintf("cacheable: %t", s.canCacheResults()), s.Query.String()}
508+
//_ = pr.WriteChildren(children...)
509+
//return pr.String()
510+
return fmt.Sprintf("Subquery(%s)", s.QueryString)
510511
}
511512

512513
func (s *Subquery) DebugString() string {

0 commit comments

Comments
 (0)