@@ -251,7 +251,9 @@ func validateGroupBy(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop
251251 var err error
252252 var parent sql.Node
253253 checkParent := false
254- var projectParent * plan.Project
254+ var project * plan.Project
255+ var having * plan.Having
256+ var filter * plan.Filter
255257 transform .Inspect (n , func (n sql.Node ) bool {
256258 defer func () {
257259 parent = n
@@ -282,28 +284,45 @@ func validateGroupBy(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop
282284 }
283285 }
284286
285- // TODO set groupBys to equalsExpr
286287 groupBys := make (map [string ]bool )
287288 groupByPrimaryKeys := 0
288- for _ , expr := range n .GroupByExprs {
289+ isJoin := false
290+ exprs := make ([]sql.Expression , 0 )
291+ exprs = append (exprs , n .GroupByExprs ... )
292+ if having != nil {
293+ if eq , ok := having .Cond .(* expression.Equals ); ok {
294+ exprs = append (exprs , eq .Children ()... )
295+ }
296+ }
297+ if filter != nil {
298+ }
299+ if join , ok := n .Child .(* plan.JoinNode ); ok {
300+ isJoin = true
301+ if eq , ok := join .Filter .(* expression.Equals ); ok {
302+ exprs = append (exprs , eq .Children ()... )
303+ }
304+ }
305+ for _ , expr := range exprs {
289306 sql .Inspect (expr , func (expr sql.Expression ) bool {
290307 exprStr := strings .ToLower (expr .String ())
291- groupBys [exprStr ] = true
292- if primaryKeys [exprStr ] {
308+ if primaryKeys [exprStr ] && ! groupBys [exprStr ] {
293309 groupByPrimaryKeys ++
294310 }
311+ groupBys [exprStr ] = true
295312
296313 _ , isAlias := expr .(* expression.Alias )
297314 return isAlias
298315 })
299316 }
300317
301318 // TODO: also allow grouping by unique non-nullable columns
302- if len (primaryKeys ) != 0 && groupByPrimaryKeys == len (primaryKeys ) {
319+ // TODO: There's currently no way to tell whether or not a primary key column is part of a multi-column
320+ // primary key.
321+ if len (primaryKeys ) != 0 && (groupByPrimaryKeys == len (primaryKeys ) || (isJoin && groupByPrimaryKeys > 0 )) {
303322 return true
304323 }
305324
306- selectExprs := getSelectExprs (projectParent , n .SelectDeps , groupBys )
325+ selectExprs := getSelectExprs (project , n .SelectDeps , groupBys )
307326
308327 for _ , expr := range selectExprs {
309328 if ! expressionReferencesOnlyGroupBys (groupBys , expr ) {
@@ -314,10 +333,11 @@ func validateGroupBy(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop
314333 }
315334 }
316335 case * plan.Project :
317- projectParent = n
318- case * plan.Filter , * plan.Having :
319- // TODO inspect for equals and add GetField (if direct child) to equalsExprs
320- // make sure filter hasn't been pushed down yet
336+ project = n
337+ case * plan.Filter :
338+ filter = n
339+ case * plan.Having :
340+ having = n
321341 }
322342 return true
323343 })
0 commit comments