Skip to content

Commit 6ba4161

Browse files
authored
Merge pull request #3224 from dolthub/angela/groupby
Validate expressions in `ORDER BY` clause during `GROUP BY` validation
2 parents 1e27be8 + 8341c91 commit 6ba4161

File tree

5 files changed

+120
-31
lines changed

5 files changed

+120
-31
lines changed

enginetest/enginetests.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -828,7 +828,7 @@ func TestOrderByGroupBy(t *testing.T, harness Harness) {
828828
// group by with any_value or non-strict mode is non-deterministic (unless there's only one value), so we must accept multiple possible results
829829
// group by with any_value()
830830

831-
_, rowIter, _, err = e.Query(ctx, "select any_value(id), team from members group by team order by id")
831+
_, rowIter, _, err = e.Query(ctx, "select any_value(id), team from members group by team")
832832
require.NoError(t, err)
833833
rowCount = 0
834834

@@ -867,6 +867,7 @@ func TestOrderByGroupBy(t *testing.T, harness Harness) {
867867
require.Equal(t, rowCount, 3)
868868

869869
AssertErr(t, e, harness, "select id, team from members group by team order by id", nil, analyzererrors.ErrValidationGroupBy)
870+
AssertErr(t, e, harness, "select any_value(id), team from members group by team order by id", nil, analyzererrors.ErrValidationGroupByOrderBy)
870871
})
871872
}
872873

enginetest/queries/order_by_group_by_queries.go

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -109,11 +109,8 @@ var OrderByGroupByScriptTests = []ScriptTest{
109109
},
110110
},
111111
{
112-
Query: "select binary s from t group by binary s order by s",
113-
Expected: []sql.Row{
114-
{[]uint8("abc")},
115-
{[]uint8("def")},
116-
},
112+
Query: "select binary s from t group by binary s order by s",
113+
ExpectedErr: analyzererrors.ErrValidationGroupByOrderBy,
117114
},
118115
},
119116
},
@@ -365,4 +362,44 @@ var OrderByGroupByScriptTests = []ScriptTest{
365362
},
366363
},
367364
},
365+
{
366+
Name: "valid group by order by queries",
367+
SetUpScript: []string{
368+
"create table t0(c0 int primary key, c1 int, c2 int, c3 int)",
369+
"insert into t0 values (3, 1, 3, 1), (4, 1, 7, 2), (5, 2, 9, 3),(6,2, 1, 3), (7,2, 2, 2),(8,3,2, 5)",
370+
},
371+
Assertions: []ScriptTestAssertion{
372+
{
373+
// group by primary key
374+
Query: "select c1 from t0 group by c0 order by c2",
375+
Expected: []sql.Row{{2}, {2}, {3}, {1}, {1}, {2}},
376+
},
377+
{
378+
// order by aggregate
379+
Query: "select c1 from t0 group by c1 order by min(c2)",
380+
Expected: []sql.Row{{2}, {3}, {1}},
381+
},
382+
{
383+
// order by alias for column in group by clause
384+
Query: "select c1 as col from t0 group by c1 order by col",
385+
Expected: []sql.Row{{1}, {2}, {3}},
386+
},
387+
{
388+
// order by alias for aggregate column
389+
Query: "select min(c0) as min, c1 from t0 group by c1 order by min",
390+
Expected: []sql.Row{{3, 1}, {5, 2}, {8, 3}},
391+
},
392+
{
393+
// order by multiple columns
394+
Query: "select c1 from t0 group by c1, c2, c3 order by c2, c3",
395+
Expected: []sql.Row{{2}, {2}, {3}, {1}, {1}, {2}},
396+
},
397+
{
398+
// order by functionally dependent column
399+
Dialect: "mysql",
400+
Query: "select c1 from t0 where c2 = 3 group by c1 order by c2",
401+
Expected: []sql.Row{{1}},
402+
},
403+
},
404+
},
368405
}

enginetest/queries/queries.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10423,6 +10423,10 @@ var ErrorQueries = []QueryErrorTest{
1042310423
Query: "select * from two_pk group by pk1 + 1, mod(pk2, 2)",
1042410424
ExpectedErr: analyzererrors.ErrValidationGroupBy,
1042510425
},
10426+
{
10427+
Query: `select s from mytable group by s order by i`,
10428+
ExpectedErr: analyzererrors.ErrValidationGroupByOrderBy,
10429+
},
1042610430
}
1042710431

1042810432
var BrokenErrorQueries = []QueryErrorTest{

sql/analyzer/analyzererrors/errors.go

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,33 +19,49 @@ import "gopkg.in/src-d/go-errors.v1"
1919
var (
2020
// ErrValidationResolved is returned when the plan can not be resolved.
2121
ErrValidationResolved = errors.NewKind("plan is not resolved because of node '%T'")
22+
2223
// ErrValidationOrderBy is returned when the order by contains aggregation
2324
// expressions.
2425
ErrValidationOrderBy = errors.NewKind("OrderBy does not support aggregation expressions")
25-
// ErrValidationGroupBy is returned when the aggregation expression does not
26-
// appear in the grouping columns.
26+
27+
// ErrValidationGroupBy is returned when a selected expression contains a nonaggregated column that does not appear
28+
// in the group by clause
2729
ErrValidationGroupBy = errors.NewKind(
28-
"Expression #%d of SELECT list is not in GROUP BY clause; " +
30+
"Expression #%d of SELECT list is not in GROUP BY clause and contains nonaggregated column '%s' which " +
31+
"is not functionally dependent on columns in GROUP BY clause; " +
32+
"this is incompatible with sql_mode=only_full_group_by",
33+
)
34+
35+
// ErrValidationGroupByOrderBy is returned when an order by expression contains a nonaggregated column that does not
36+
// appear in the group by clause
37+
ErrValidationGroupByOrderBy = errors.NewKind(
38+
"Expression #%d of ORDER BY clause is not in GROUP BY clause and contains nonaggregated column '%s' which " +
39+
"is not functionally dependent on columns in GROUP BY clause; " +
2940
"this is incompatible with sql_mode=only_full_group_by",
3041
)
42+
3143
// ErrValidationSchemaSource is returned when there is any column source
3244
// that does not match the table name.
3345
ErrValidationSchemaSource = errors.NewKind("one or more schema sources are empty")
46+
3447
// ErrUnknownIndexColumns is returned when there are columns in the expr
3548
// to index that are unknown in the table.
3649
ErrUnknownIndexColumns = errors.NewKind("unknown columns to index for table %q: %s")
50+
3751
// ErrCaseResultType is returned when one or more of the types of the values in
3852
// a case expression don't match.
3953
ErrCaseResultType = errors.NewKind(
4054
"expecting all case branches to return values of type %s, " +
4155
"but found value %q of type %s on %s",
4256
)
57+
4358
// ErrIntervalInvalidUse is returned when an interval expression is not
4459
// correctly used.
4560
ErrIntervalInvalidUse = errors.NewKind(
4661
"invalid use of an interval, which can only be used with DATE_ADD, " +
4762
"DATE_SUB and +/- operators to subtract from or add to a date",
4863
)
64+
4965
// ErrExplodeInvalidUse is returned when an EXPLODE function is used
5066
// outside a Project node.
5167
ErrExplodeInvalidUse = errors.NewKind(

sql/analyzer/validation_rules.go

Lines changed: 53 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,7 @@ func validateGroupBy(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop
252252

253253
var err error
254254
var project *plan.Project
255+
var orderBy *plan.Sort
255256
transform.Inspect(n, func(n sql.Node) bool {
256257
switch n := n.(type) {
257258
case *plan.GroupBy:
@@ -307,20 +308,36 @@ func validateGroupBy(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop
307308
return true
308309
}
309310

310-
selectExprs := getSelectExprs(project, n.SelectDeps, groupBys)
311+
selectExprs, orderByExprs := getSelectAndOrderByExprs(project, orderBy, n.SelectDeps, groupBys)
311312

312313
for i, expr := range selectExprs {
313314
if valid, col := expressionReferencesOnlyGroupBys(groupBys, expr, noGroupBy); !valid {
314315
if noGroupBy {
315316
err = sql.ErrNonAggregatedColumnWithoutGroupBy.New(i+1, col)
316317
} else {
317-
err = analyzererrors.ErrValidationGroupBy.New(i + 1)
318+
err = analyzererrors.ErrValidationGroupBy.New(i+1, col)
318319
}
319320
return false
320321
}
321322
}
323+
// According to MySQL documentation, we should still be validating ORDER BY expressions when there's not an
324+
// explicit GROUP BY in the query ("If a query has aggregate functions and no GROUP BY clause, it cannot
325+
// have nonaggregated columns in the select list, HAVING condition, or ORDER BY list with
326+
// ONLY_FULL_GROUP_BY enabled"). But when testing queries in MySQL, it doesn't seem like they actually
327+
// validate ORDER BY expressions in aggregate queries without an explicit GROUP BY
328+
if !noGroupBy {
329+
for i, expr := range orderByExprs {
330+
if valid, col := expressionReferencesOnlyGroupBys(groupBys, expr, noGroupBy); !valid {
331+
err = analyzererrors.ErrValidationGroupByOrderBy.New(i+1, col)
332+
return false
333+
}
334+
}
335+
}
322336
case *plan.Project:
323337
project = n
338+
orderBy = nil
339+
case *plan.Sort:
340+
orderBy = n
324341
}
325342
return true
326343
})
@@ -351,42 +368,56 @@ func getEqualsDependencies(expr sql.Expression) []sql.Expression {
351368

352369
// getSelectExprs transforms the projection expressions from a Project node such that it uses the appropriate select
353370
// dependency expressions.
354-
func getSelectExprs(project *plan.Project, selectDeps []sql.Expression, groupBys map[string]bool) []sql.Expression {
355-
if project == nil {
356-
return selectDeps
371+
func getSelectAndOrderByExprs(project *plan.Project, orderBy *plan.Sort, selectDeps []sql.Expression, groupBys map[string]bool) ([]sql.Expression, []sql.Expression) {
372+
if project == nil && orderBy == nil {
373+
return selectDeps, nil
357374
} else {
358375
sd := make(map[string]sql.Expression, len(selectDeps))
359376
for _, dep := range selectDeps {
360377
sd[strings.ToLower(dep.String())] = dep
361378
}
362379

363380
selectExprs := make([]sql.Expression, 0)
381+
orderByExprs := make([]sql.Expression, 0)
364382

365383
for _, expr := range project.Projections {
366384
if !project.AliasDeps[strings.ToLower(expr.String())] {
367-
resolvedExpr, _, _ := transform.Expr(expr, func(expr sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
368-
if groupBys[strings.ToLower(expr.String())] {
369-
return expr, transform.SameTree, nil
370-
}
371-
switch expr := expr.(type) {
372-
case *expression.Alias:
373-
if dep, ok := sd[strings.ToLower(expr.Child.String())]; ok {
374-
return dep, transform.NewTree, nil
375-
}
376-
case *expression.GetField:
377-
if dep, ok := sd[strings.ToLower(expr.String())]; ok {
378-
return dep, transform.NewTree, nil
379-
}
380-
}
381-
return expr, transform.SameTree, nil
382-
})
385+
resolvedExpr := resolveExpr(expr, sd, groupBys)
383386
selectExprs = append(selectExprs, resolvedExpr)
384387
}
385388
}
386-
return selectExprs
389+
390+
if orderBy != nil {
391+
for _, expr := range orderBy.Expressions() {
392+
resolvedExpr := resolveExpr(expr, sd, groupBys)
393+
orderByExprs = append(orderByExprs, resolvedExpr)
394+
}
395+
}
396+
397+
return selectExprs, orderByExprs
387398
}
388399
}
389400

401+
func resolveExpr(expr sql.Expression, selectDeps map[string]sql.Expression, groupBys map[string]bool) sql.Expression {
402+
resolvedExpr, _, _ := transform.Expr(expr, func(expr sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
403+
if groupBys[strings.ToLower(expr.String())] {
404+
return expr, transform.SameTree, nil
405+
}
406+
switch expr := expr.(type) {
407+
case *expression.Alias:
408+
if dep, ok := selectDeps[strings.ToLower(expr.Child.String())]; ok {
409+
return dep, transform.NewTree, nil
410+
}
411+
case *expression.GetField:
412+
if dep, ok := selectDeps[strings.ToLower(expr.String())]; ok {
413+
return dep, transform.NewTree, nil
414+
}
415+
}
416+
return expr, transform.SameTree, nil
417+
})
418+
return resolvedExpr
419+
}
420+
390421
// expressionReferencesOnlyGroupBys validates that an expression is dependent on only group by expressions
391422
func expressionReferencesOnlyGroupBys(groupBys map[string]bool, expr sql.Expression, noGroupBy bool) (bool, string) {
392423
var col string

0 commit comments

Comments
 (0)