Skip to content

Commit 3610cac

Browse files
authored
Merge pull request #1710 from dolthub/nicktobey/no-panic
Detect invalid uses of * and window functions in queries.
2 parents 0cd7a98 + bc44772 commit 3610cac

File tree

20 files changed

+318
-212
lines changed

20 files changed

+318
-212
lines changed

enginetest/enginetests.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ import (
3838
"github.com/dolthub/go-mysql-server/server"
3939
"github.com/dolthub/go-mysql-server/sql"
4040
"github.com/dolthub/go-mysql-server/sql/analyzer"
41+
"github.com/dolthub/go-mysql-server/sql/analyzer/analyzererrors"
4142
"github.com/dolthub/go-mysql-server/sql/expression"
4243
"github.com/dolthub/go-mysql-server/sql/expression/function/aggregation"
4344
"github.com/dolthub/go-mysql-server/sql/mysql_db"
@@ -299,7 +300,7 @@ func TestReadOnlyDatabases(t *testing.T, harness ReadOnlyDatabaseHarness) {
299300
} {
300301
for _, tt := range querySet {
301302
t.Run(tt.WriteQuery, func(t *testing.T) {
302-
AssertErrWithBindings(t, engine, harness, tt.WriteQuery, tt.Bindings, analyzer.ErrReadOnlyDatabase)
303+
AssertErrWithBindings(t, engine, harness, tt.WriteQuery, tt.Bindings, analyzererrors.ErrReadOnlyDatabase)
303304
})
304305
}
305306
}

enginetest/queries/order_by_group_by_queries.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ package queries
1616

1717
import (
1818
"github.com/dolthub/go-mysql-server/sql"
19-
"github.com/dolthub/go-mysql-server/sql/analyzer"
19+
"github.com/dolthub/go-mysql-server/sql/analyzer/analyzererrors"
2020
)
2121

2222
var OrderByGroupByScriptTests = []ScriptTest{
@@ -245,7 +245,7 @@ var OrderByGroupByScriptTests = []ScriptTest{
245245
},
246246
{
247247
Query: "select id, team from members group by team",
248-
ExpectedErr: analyzer.ErrValidationGroupBy,
248+
ExpectedErr: analyzererrors.ErrValidationGroupBy,
249249
},
250250
},
251251
},

enginetest/queries/queries.go

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import (
2222
"gopkg.in/src-d/go-errors.v1"
2323

2424
"github.com/dolthub/go-mysql-server/sql"
25-
"github.com/dolthub/go-mysql-server/sql/analyzer"
25+
"github.com/dolthub/go-mysql-server/sql/analyzer/analyzererrors"
2626
"github.com/dolthub/go-mysql-server/sql/expression"
2727
"github.com/dolthub/go-mysql-server/sql/plan"
2828
"github.com/dolthub/go-mysql-server/sql/types"
@@ -7493,6 +7493,13 @@ SELECT * FROM my_cte;`,
74937493
},
74947494
},
74957495
},
7496+
// Regression test for https://github.com/dolthub/dolt/issues/5656
7497+
{
7498+
Query: "select count((select * from (select pk from one_pk limit 1) as sq)) from one_pk;",
7499+
Expected: []sql.Row{
7500+
{4},
7501+
},
7502+
},
74967503
}
74977504

74987505
var KeylessQueries = []QueryTest{
@@ -8355,11 +8362,11 @@ var ErrorQueries = []QueryErrorTest{
83558362
// TODO: The following two queries should work. See https://github.com/dolthub/go-mysql-server/issues/542.
83568363
{
83578364
Query: "SELECT SUM(i), i FROM mytable GROUP BY i ORDER BY 1+SUM(i) ASC",
8358-
ExpectedErr: analyzer.ErrAggregationUnsupported,
8365+
ExpectedErr: analyzererrors.ErrAggregationUnsupported,
83598366
},
83608367
{
83618368
Query: "SELECT SUM(i) as sum, i FROM mytable GROUP BY i ORDER BY 1+SUM(i) ASC",
8362-
ExpectedErr: analyzer.ErrAggregationUnsupported,
8369+
ExpectedErr: analyzererrors.ErrAggregationUnsupported,
83638370
},
83648371
{
83658372
Query: "select ((1, 2)) from dual",
@@ -8550,6 +8557,10 @@ var ErrorQueries = []QueryErrorTest{
85508557
Query: "drop table myview;",
85518558
ExpectedErr: sql.ErrUnknownTable,
85528559
},
8560+
{
8561+
Query: "select SUM(*) from dual;",
8562+
ExpectedErr: analyzererrors.ErrStarUnsupported,
8563+
},
85538564
}
85548565

85558566
var BrokenErrorQueries = []QueryErrorTest{

enginetest/queries/script_queries.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import (
2020
"gopkg.in/src-d/go-errors.v1"
2121

2222
"github.com/dolthub/go-mysql-server/sql"
23-
"github.com/dolthub/go-mysql-server/sql/analyzer"
23+
"github.com/dolthub/go-mysql-server/sql/analyzer/analyzererrors"
2424
"github.com/dolthub/go-mysql-server/sql/plan"
2525
"github.com/dolthub/go-mysql-server/sql/types"
2626
)
@@ -1396,15 +1396,15 @@ var ScriptTests = []ScriptTest{
13961396
},
13971397
{
13981398
Query: "SELECT col0, col1 FROM tab1 GROUP by col0;",
1399-
ExpectedErr: analyzer.ErrValidationGroupBy,
1399+
ExpectedErr: analyzererrors.ErrValidationGroupBy,
14001400
},
14011401
{
14021402
Query: "SELECT col0, floor(col1) FROM tab1 GROUP by col0;",
1403-
ExpectedErr: analyzer.ErrValidationGroupBy,
1403+
ExpectedErr: analyzererrors.ErrValidationGroupBy,
14041404
},
14051405
{
14061406
Query: "SELECT floor(cor0.col1) * ceil(cor0.col0) AS col2 FROM tab1 AS cor0 GROUP BY cor0.col0",
1407-
ExpectedErr: analyzer.ErrValidationGroupBy,
1407+
ExpectedErr: analyzererrors.ErrValidationGroupBy,
14081408
},
14091409
},
14101410
},

enginetest/queries/update_queries.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ package queries
1717
import (
1818
"github.com/dolthub/vitess/go/mysql"
1919

20+
"github.com/dolthub/go-mysql-server/sql/analyzer/analyzererrors"
2021
"github.com/dolthub/go-mysql-server/sql/types"
2122

2223
"github.com/dolthub/go-mysql-server/sql"
@@ -730,6 +731,18 @@ var UpdateErrorTests = []QueryErrorTest{
730731
Query: `UPDATE people set height_inches = null where height_inches < 100`,
731732
ExpectedErr: sql.ErrInsertIntoNonNullableProvidedNull,
732733
},
734+
{
735+
Query: `UPDATE people SET height_inches = IF(SUM(height_inches) % 2 = 0, 42, height_inches)`,
736+
ExpectedErr: analyzererrors.ErrAggregationUnsupported,
737+
},
738+
{
739+
Query: `UPDATE people SET height_inches = IF(SUM(*) % 2 = 0, 42, height_inches)`,
740+
ExpectedErr: analyzererrors.ErrStarUnsupported,
741+
},
742+
{
743+
Query: `UPDATE people SET height_inches = IF(ROW_NUMBER() OVER() % 2 = 0, 42, height_inches)`,
744+
ExpectedErr: analyzererrors.ErrWindowUnsupported,
745+
},
733746
}
734747

735748
var UpdateErrorScripts = []ScriptTest{
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
// Copyright 2020-2023 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package analyzererrors
16+
17+
import "gopkg.in/src-d/go-errors.v1"
18+
19+
var (
20+
// ErrValidationResolved is returned when the plan can not be resolved.
21+
ErrValidationResolved = errors.NewKind("plan is not resolved because of node '%T'")
22+
// ErrValidationOrderBy is returned when the order by contains aggregation
23+
// expressions.
24+
ErrValidationOrderBy = errors.NewKind("OrderBy does not support aggregation expressions")
25+
// ErrValidationGroupBy is returned when the aggregation expression does not
26+
// appear in the grouping columns.
27+
ErrValidationGroupBy = errors.NewKind("expression '%v' doesn't appear in the group by expressions")
28+
// ErrValidationSchemaSource is returned when there is any column source
29+
// that does not match the table name.
30+
ErrValidationSchemaSource = errors.NewKind("one or more schema sources are empty")
31+
// ErrUnknownIndexColumns is returned when there are columns in the expr
32+
// to index that are unknown in the table.
33+
ErrUnknownIndexColumns = errors.NewKind("unknown columns to index for table %q: %s")
34+
// ErrCaseResultType is returned when one or more of the types of the values in
35+
// a case expression don't match.
36+
ErrCaseResultType = errors.NewKind(
37+
"expecting all case branches to return values of type %s, " +
38+
"but found value %q of type %s on %s",
39+
)
40+
// ErrIntervalInvalidUse is returned when an interval expression is not
41+
// correctly used.
42+
ErrIntervalInvalidUse = errors.NewKind(
43+
"invalid use of an interval, which can only be used with DATE_ADD, " +
44+
"DATE_SUB and +/- operators to subtract from or add to a date",
45+
)
46+
// ErrExplodeInvalidUse is returned when an EXPLODE function is used
47+
// outside a Project node.
48+
ErrExplodeInvalidUse = errors.NewKind(
49+
"using EXPLODE is not supported outside a Project node",
50+
)
51+
52+
// ErrSubqueryFieldIndex is returned when an expression subquery references a field outside the range of the rows it
53+
// works on.
54+
ErrSubqueryFieldIndex = errors.NewKind(
55+
"subquery field index out of range for expression %s: only %d columns available",
56+
)
57+
58+
// ErrUnionSchemasMatch is returned when both sides of a UNION do not
59+
// have the same schema.
60+
ErrUnionSchemasMatch = errors.NewKind(
61+
"the schema of the left side of union does not match the right side, expected %s to match %s",
62+
)
63+
64+
// ErrReadOnlyDatabase is returned when a write is attempted to a ReadOnlyDatabse.
65+
ErrReadOnlyDatabase = errors.NewKind("Database %s is read-only.")
66+
67+
// ErrAggregationUnsupported is returned when the analyzer has failed
68+
// to push down an Aggregation in an expression to a GroupBy node.
69+
ErrAggregationUnsupported = errors.NewKind(
70+
"an aggregation remained in the expression '%s' after analysis, outside of a node capable of evaluating it; this query is currently unsupported.",
71+
)
72+
73+
ErrWindowUnsupported = errors.NewKind(
74+
"a window function '%s' is in a context where it cannot be evaluated.",
75+
)
76+
77+
ErrStarUnsupported = errors.NewKind(
78+
"a '*' is in a context where it is not allowed.",
79+
)
80+
)

sql/analyzer/resolve_subqueries.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package analyzer
1616

1717
import (
1818
"github.com/dolthub/go-mysql-server/sql"
19+
"github.com/dolthub/go-mysql-server/sql/analyzer/analyzererrors"
1920
"github.com/dolthub/go-mysql-server/sql/expression"
2021
"github.com/dolthub/go-mysql-server/sql/plan"
2122
"github.com/dolthub/go-mysql-server/sql/transform"
@@ -72,7 +73,7 @@ func finalizeSubqueriesHelper(ctx *sql.Context, a *Analyzer, node sql.Node, scop
7273
if sq, ok := e.(*plan.Subquery); ok {
7374
newSq, same2, err := analyzeSubqueryExpression(ctx, a, node, sq, scope, sel, true)
7475
if err != nil {
75-
if ErrValidationResolved.Is(err) {
76+
if analyzererrors.ErrValidationResolved.Is(err) {
7677
// if a parent is unresolved, we want to dig deeper to find the unresolved
7778
// child dependency
7879
_, _, err := finalizeSubqueriesHelper(ctx, a, sq.Query, scope.newScopeFromSubqueryExpression(node), sel)
@@ -162,7 +163,7 @@ func analyzeSubqueryExpression(ctx *sql.Context, a *Analyzer, n sql.Node, sq *pl
162163
if err != nil {
163164
// We ignore certain errors during non-final passes of the analyzer, deferring them to later analysis passes.
164165
// Specifically, if the subquery isn't resolved or a column can't be found in the scope node, wait until a later pass.
165-
if !finalize && (ErrValidationResolved.Is(err) || sql.ErrTableColumnNotFound.Is(err) || sql.ErrColumnNotFound.Is(err)) {
166+
if !finalize && (analyzererrors.ErrValidationResolved.Is(err) || sql.ErrTableColumnNotFound.Is(err) || sql.ErrColumnNotFound.Is(err)) {
166167
// keep the work we have and defer remainder of analysis of this subquery until a later pass
167168
return sq.WithQuery(analyzed), transform.NewTree, nil
168169
}

sql/analyzer/rule_ids.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ type RuleId int
77
const (
88
// once before
99
applyDefaultSelectLimitId RuleId = iota // applyDefaultSelectLimit
10-
validateOffsetAndLimitId //validateOffsetAndLimit
10+
validateOffsetAndLimitId // validateOffsetAndLimit
11+
validateStarExpressionsId // validateStarExpressions
1112
validateCreateTableId // validateCreateTable
1213
validateExprSemId // validateExprSem
1314
resolveVariablesId // resolveVariables
@@ -28,7 +29,7 @@ const (
2829
validateDropConstraintId // validateDropConstraint
2930
loadCheckConstraintsId // loadCheckConstraints
3031
assignCatalogId // assignCatalog
31-
resolveAnalyzeTablesId //resolveAnalyzeTables
32+
resolveAnalyzeTablesId // resolveAnalyzeTables
3233
resolveCreateSelectId // resolveCreateSelect
3334
resolveSubqueriesId // resolveSubqueries
3435
setViewTargetSchemaId // setViewTargetSchema

0 commit comments

Comments
 (0)