@@ -249,41 +249,59 @@ func validateGroupBy(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop
249249 }
250250
251251 var err error
252- // var parent sql.Node
252+ var parent sql.Node
253253 transform .Inspect (n , func (n sql.Node ) bool {
254- // defer func() {
255- // parent = n
256- // }()
254+ defer func () {
255+ parent = n
256+ }()
257257
258258 gb , ok := n .(* plan.GroupBy )
259259 if ! ok {
260260 return true
261261 }
262262
263- // switch parent.(type) {
264- // case *plan.Having, *plan.Project, *plan.Sort:
265- // // TODO: these shouldn't be skipped; you can group by primary key without problem b/c only one value
266- // // https://dev.mysql.com/doc/refman/8.0/en/group-by-handling.html#:~:text=The%20query%20is%20valid%20if%20name%20is%20a%20primary%20key
267- // return true
268- // }
263+ switch parent .(type ) {
264+ case * plan.Having , * plan.Project , * plan.Sort :
265+ // TODO: these shouldn't be skipped; you can group by primary key without problem b/c only one value
266+ // https://dev.mysql.com/doc/refman/8.0/en/group-by-handling.html#:~:text=The%20query%20is%20valid%20if%20name%20is%20a%20primary%20key
267+ return true
268+ }
269269
270270 // Allow the parser use the GroupBy node to eval the aggregation functions
271271 // for sql statements that don't make use of the GROUP BY expression.
272272 if len (gb .GroupByExprs ) == 0 {
273273 return true
274274 }
275275
276- var groupBys []string
276+ primaryKeys := make (map [string ]bool )
277+ for _ , col := range gb .Child .Schema () {
278+ if col .PrimaryKey {
279+ primaryKeys [strings .ToLower (col .Name )] = true
280+ }
281+ }
282+
283+ groupBys := make (map [string ]bool )
284+ groupByAliases := make (map [string ]bool )
285+ groupByPrimaryKeys := 0
277286 for _ , expr := range gb .GroupByExprs {
278- groupBys = append (groupBys , expr .String ())
287+ exprStr := strings .ToLower (expr .String ())
288+ groupBys [exprStr ] = true
289+ if primaryKeys [exprStr ] {
290+ groupByPrimaryKeys ++
291+ }
292+ if _ , ok := expr .(sql.Aggregation ); ok {
293+ groupByAliases [exprStr ] = true
294+ }
295+ }
296+
297+ if len (primaryKeys ) != 0 && groupByPrimaryKeys == len (primaryKeys ) {
298+ return true
279299 }
280300
281301 for _ , expr := range gb .SelectedExprs {
282- if _ , ok := expr .(sql.Aggregation ); ! ok {
283- if ! expressionReferencesOnlyGroupBys (groupBys , expr ) {
284- err = analyzererrors .ErrValidationGroupBy .New (expr .String ())
285- return false
286- }
302+ if ! expressionReferencesOnlyGroupBys (groupBys , groupByAliases , expr ) {
303+ err = analyzererrors .ErrValidationGroupBy .New (expr .String ())
304+ return false
287305 }
288306 }
289307 return true
@@ -292,22 +310,15 @@ func validateGroupBy(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop
292310 return n , transform .SameTree , err
293311}
294312
295- func expressionReferencesOnlyGroupBys (groupBys [] string , expr sql.Expression ) bool {
313+ func expressionReferencesOnlyGroupBys (groupBys , groupByAliases map [ string ] bool , expr sql.Expression ) bool {
296314 valid := true
297315 sql .Inspect (expr , func (expr sql.Expression ) bool {
316+ exprStr := strings .ToLower (expr .String ())
298317 switch expr := expr .(type ) {
299318 case nil , sql.Aggregation , * expression.Literal :
300319 return false
301- case * expression.Alias , sql.FunctionExpression :
302- if stringContains (groupBys , expr .String ()) {
303- return false
304- }
305- return true
306- // cc: https://dev.mysql.com/doc/refman/8.0/en/group-by-handling.html
307- // Each part of the SelectExpr must refer to the aggregated columns in some way
308- // TODO: this isn't complete, it's overly restrictive. Dependant columns are fine to reference.
309320 default :
310- if stringContains ( groupBys , expr . String ()) {
321+ if groupBys [ exprStr ] {
311322 return false
312323 }
313324
0 commit comments