Skip to content

Commit 5741bf6

Browse files
committed
plpgsql: fix handling of annotations for DO blocks
This commit should fix, for good, the handling of annotations for DO blocks, which 3b8ab38 recently attempted to fix (that commit is also partially reverted here). The general issue is that the PLpgSQL parser uses an ephemeral SQL parser instance to process SQL statements. However, when processing the stmts of a DO block, we need to preserve the annotations state across stmts. This is now achieved by extending the SQL parser infrastructure with an option to override the initial annotations index and to keep reusing the counter across multiple stmts. This allows us to correctly count the number of annotations needed for the whole DO block during the initial parsing, and then we use the allocated annotations container to set the items in the optbuilder. Release note (bug fix): Previously, a CockroachDB node could crash when executing DO stmts that contain user-defined types (possibly non-existent) in a non-default configuration (additional logging, such as that controlled via the `sql.log.all_statements.enabled` cluster setting, needed to be enabled). This bug was introduced in the 25.1 release and is now fixed.
1 parent e58fa7d commit 5741bf6

File tree

12 files changed

+284
-161
lines changed

12 files changed

+284
-161
lines changed

pkg/sql/catalog/redact/redact_test.go

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,18 @@ func TestRedactQueries(t *testing.T) {
3232
defer srv.Stopper().Stop(ctx)
3333
codec := srv.ApplicationLayer().Codec()
3434
tdb := sqlutils.MakeSQLRunner(db)
35-
tdb.Exec(t, "CREATE TABLE kv (k INT PRIMARY KEY, v STRING)")
36-
tdb.Exec(t, "CREATE VIEW view AS SELECT k, v FROM kv WHERE v <> 'constant literal'")
37-
tdb.Exec(t, "CREATE TABLE ctas AS SELECT k, v FROM kv WHERE v <> 'constant literal'")
35+
tdb.Exec(t, "CREATE TYPE my_enum AS ENUM ('foo', 'bar')")
36+
tdb.Exec(t, "CREATE TABLE kv (k INT PRIMARY KEY, v STRING, e my_enum)")
37+
tdb.Exec(t, "CREATE VIEW view AS SELECT k, v, e FROM kv WHERE v <> 'constant literal' AND e <> 'foo'")
38+
tdb.Exec(t, "CREATE TABLE ctas AS SELECT k, v, e FROM kv WHERE v <> 'constant literal' AND e <> 'foo'")
3839
tdb.Exec(t, `
3940
CREATE FUNCTION f1() RETURNS INT
4041
LANGUAGE SQL
4142
AS $$
4243
SELECT k FROM kv WHERE v != 'foo';
4344
SELECT k FROM kv WHERE v = 'bar';
45+
SELECT k FROM kv WHERE e != 'foo';
46+
SELECT k FROM kv WHERE e = 'bar';
4447
$$`)
4548
tdb.Exec(t, `
4649
CREATE FUNCTION f2() RETURNS INT
@@ -49,8 +52,9 @@ AS $$
4952
DECLARE
5053
x INT := 0;
5154
y TEXT := 'bar';
55+
z my_enum;
5256
BEGIN
53-
SELECT k FROM kv WHERE v != 'foo';
57+
SELECT k FROM kv WHERE v != 'foo' AND e != 'bar'::my_enum;
5458
RETURN x + 3;
5559
END;
5660
$$`)
@@ -61,7 +65,7 @@ $$`)
6165
)
6266
mut := tabledesc.NewBuilder(view.TableDesc()).BuildCreatedMutableTable()
6367
require.Empty(t, redact.Redact(mut.DescriptorProto()))
64-
require.Equal(t, `SELECT k, v FROM defaultdb.public.kv WHERE v != '_'`, mut.ViewQuery)
68+
require.Equal(t, `SELECT k, v, e FROM defaultdb.public.kv WHERE (v != '_') AND (e != '_')`, mut.ViewQuery)
6569
})
6670

6771
t.Run("create table as", func(t *testing.T) {
@@ -70,14 +74,14 @@ $$`)
7074
)
7175
mut := tabledesc.NewBuilder(ctas.TableDesc()).BuildCreatedMutableTable()
7276
require.Empty(t, redact.Redact(mut.DescriptorProto()))
73-
require.Equal(t, `SELECT k, v FROM defaultdb.public.kv WHERE v != '_'`, mut.CreateQuery)
77+
require.Equal(t, `SELECT k, v, e FROM defaultdb.public.kv WHERE (v != '_') AND (e != '_')`, mut.CreateQuery)
7478
})
7579

7680
t.Run("create function sql", func(t *testing.T) {
7781
fn := desctestutils.TestingGetFunctionDescriptor(kvDB, codec, "defaultdb", "public", "f1")
7882
mut := funcdesc.NewBuilder(fn.FuncDesc()).BuildCreatedMutableFunction()
7983
require.Empty(t, redact.Redact(mut.DescriptorProto()))
80-
require.Equal(t, `SELECT k FROM defaultdb.public.kv WHERE v != '_'; SELECT k FROM defaultdb.public.kv WHERE v = '_';`, mut.FunctionBody)
84+
require.Equal(t, `SELECT k FROM defaultdb.public.kv WHERE v != '_'; SELECT k FROM defaultdb.public.kv WHERE v = '_'; SELECT k FROM defaultdb.public.kv WHERE e != '_'; SELECT k FROM defaultdb.public.kv WHERE e = '_';`, mut.FunctionBody)
8185
})
8286

8387
t.Run("create function plpgsql", func(t *testing.T) {
@@ -87,8 +91,9 @@ $$`)
8791
require.Equal(t, `DECLARE
8892
x INT8 := _;
8993
y STRING := '_';
94+
z @100104;
9095
BEGIN
91-
SELECT k FROM defaultdb.public.kv WHERE v != '_';
96+
SELECT k FROM defaultdb.public.kv WHERE (v != '_') AND (e != '_':::@100104);
9297
RETURN x + _;
9398
END;
9499
`, mut.FunctionBody)

pkg/sql/logictest/testdata/logic_test/do

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,3 +286,23 @@ END
286286
$$;
287287

288288
subtest end
289+
290+
subtest regression_151848
291+
292+
statement error pgcode 42704 type \"bar\" does not exist
293+
DO $$
294+
DECLARE
295+
foo bar;
296+
BEGIN
297+
END;
298+
$$;
299+
300+
statement error pgcode 42704 type \"_\" does not exist
301+
DO $$
302+
DECLARE
303+
foo _[];
304+
BEGIN
305+
END;
306+
$$;
307+
308+
subtest end

pkg/sql/opt/optbuilder/routine.go

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -855,8 +855,19 @@ func (b *Builder) buildDo(do *tree.DoBlock, inScope *scope) *scope {
855855
// TODO(drewk): Enable memo reuse with DO statements.
856856
b.DisableMemoReuse = true
857857

858-
defer func(oldInsideFuncDep bool) { b.insideFuncDef = oldInsideFuncDep }(b.insideFuncDef)
858+
doBlockImpl, ok := do.Code.(*plpgsqltree.DoBlock)
859+
if !ok {
860+
panic(errors.AssertionFailedf("expected a plpgsql block"))
861+
}
862+
863+
defer func(oldInsideFuncDep bool, oldAnn tree.Annotations) {
864+
b.insideFuncDef = oldInsideFuncDep
865+
b.semaCtx.Annotations = oldAnn
866+
b.evalCtx.Annotations = &b.semaCtx.Annotations
867+
}(b.insideFuncDef, b.semaCtx.Annotations)
859868
b.insideFuncDef = true
869+
b.semaCtx.Annotations = doBlockImpl.Annotations
870+
b.evalCtx.Annotations = &b.semaCtx.Annotations
860871

861872
// Build the routine body.
862873
var bodyStmts []string
@@ -865,10 +876,6 @@ func (b *Builder) buildDo(do *tree.DoBlock, inScope *scope) *scope {
865876
fmtCtx.FormatNode(do.Code)
866877
bodyStmts = []string{fmtCtx.CloseAndGetString()}
867878
}
868-
doBlockImpl, ok := do.Code.(*plpgsqltree.DoBlock)
869-
if !ok {
870-
panic(errors.AssertionFailedf("expected a plpgsql block"))
871-
}
872879
bodyScope := b.buildPLpgSQLDoBody(doBlockImpl)
873880

874881
// Build a CALL expression that invokes the routine.

pkg/sql/parser/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ go_library(
3838
"//pkg/sql/sem/tree/treecmp", # keep
3939
"//pkg/sql/sem/tree/treewindow", # keep
4040
"//pkg/sql/types",
41+
"//pkg/util/buildutil",
4142
"//pkg/util/errorutil/unimplemented",
4243
"//pkg/util/vector", # keep
4344
"@com_github_cockroachdb_errors//:errors",

pkg/sql/parser/lexer.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,17 @@ type lexer struct {
3838
lastError error
3939
}
4040

41-
func (l *lexer) init(sql string, tokens []sqlSymType, nakedIntType *types.T) {
41+
// numAnnotations indicates the number of annotations that have already been
42+
// claimed.
43+
func (l *lexer) init(
44+
sql string, tokens []sqlSymType, nakedIntType *types.T, numAnnotations tree.AnnotationIdx,
45+
) {
4246
l.in = sql
4347
l.tokens = tokens
4448
l.lastPos = -1
4549
l.stmt = nil
4650
l.numPlaceholders = 0
47-
l.numAnnotations = 0
51+
l.numAnnotations = numAnnotations
4852
l.lastError = nil
4953

5054
l.nakedIntType = nakedIntType

pkg/sql/parser/lexer_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ func TestLexer(t *testing.T) {
3636
scanTokens = append(scanTokens, lval)
3737
}
3838
var l lexer
39-
l.init(d.sql, scanTokens, defaultNakedIntType)
39+
l.init(d.sql, scanTokens, defaultNakedIntType, 0 /* numAnnotations */)
4040
var lexTokens []int
4141
for {
4242
var lval sqlSymType

pkg/sql/parser/parse.go

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
"github.com/cockroachdb/cockroach/pkg/sql/scanner"
2626
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
2727
"github.com/cockroachdb/cockroach/pkg/sql/types"
28+
"github.com/cockroachdb/cockroach/pkg/util/buildutil"
2829
"github.com/cockroachdb/errors"
2930
)
3031

@@ -47,6 +48,7 @@ type Parser struct {
4748
type ParseOptions struct {
4849
intType *types.T
4950
retainComments bool
51+
numAnnotations *tree.AnnotationIdx
5052
}
5153

5254
var DefaultParseOptions = ParseOptions{
@@ -64,6 +66,18 @@ func (po ParseOptions) WithIntType(t *types.T) ParseOptions {
6466
return po
6567
}
6668

69+
// WithNumAnnotations overrides how annotations are handled. If this option is
70+
// used, then
71+
// - the given integer indicates the number of annotations already claimed, and
72+
// - if multiple stmts are parsed, then unique annotation indexes are used
73+
// across all stmts.
74+
//
75+
// When this option is not used, then each stmt is parsed independently.
76+
func (po ParseOptions) WithNumAnnotations(numAnnotations tree.AnnotationIdx) ParseOptions {
77+
po.numAnnotations = &numAnnotations
78+
return po
79+
}
80+
6781
// INT8 is the historical interpretation of INT. This should be left
6882
// alone in the future, since there are many sql fragments stored
6983
// in various descriptors. Any user input that was created after
@@ -156,14 +170,29 @@ func (p *Parser) parseWithDepth(
156170
p.scanner.RetainComments()
157171
}
158172
defer p.scanner.Cleanup()
173+
var numAnnotations tree.AnnotationIdx
174+
if options.numAnnotations != nil {
175+
numAnnotations = *options.numAnnotations
176+
}
159177
for {
160178
sql, tokens, done := p.scanOneStmt()
161-
stmt, err := p.parse(depth+1, sql, tokens, options.intType)
179+
stmt, err := p.parse(depth+1, sql, tokens, options.intType, numAnnotations)
162180
if err != nil {
163181
return nil, err
164182
}
165183
if stmt.AST != nil {
166184
stmts = append(stmts, stmt)
185+
if options.numAnnotations != nil {
186+
if buildutil.CrdbTestBuild && numAnnotations > stmt.NumAnnotations {
187+
return nil, errors.AssertionFailedf(
188+
"annotation index has regressed: numAnnotations=%d, stmt.NumAnnotations=%d ",
189+
numAnnotations, stmt.NumAnnotations,
190+
)
191+
}
192+
// If this stmt used any annotations, we need to advance the
193+
// number of annotations accordingly.
194+
numAnnotations = stmt.NumAnnotations
195+
}
167196
}
168197
if done {
169198
break
@@ -174,9 +203,13 @@ func (p *Parser) parseWithDepth(
174203

175204
// parse parses a statement from the given scanned tokens.
176205
func (p *Parser) parse(
177-
depth int, sql string, tokens []sqlSymType, nakedIntType *types.T,
206+
depth int,
207+
sql string,
208+
tokens []sqlSymType,
209+
nakedIntType *types.T,
210+
numAnnotations tree.AnnotationIdx,
178211
) (statements.Statement[tree.Statement], error) {
179-
p.lexer.init(sql, tokens, nakedIntType)
212+
p.lexer.init(sql, tokens, nakedIntType, numAnnotations)
180213
defer p.lexer.cleanup()
181214
if p.parserImpl.Parse(&p.lexer) != 0 {
182215
if p.lexer.lastError == nil {
@@ -359,27 +392,32 @@ func ParseTablePattern(sql string) (tree.TablePattern, error) {
359392
// the results are undefined if the string contains invalid SQL
360393
// syntax.
361394
func ParseExprs(exprs []string) (tree.Exprs, error) {
362-
return ParseExprsWithOptions(exprs, DefaultParseOptions)
395+
res, _, err := ParseExprsWithOptions(exprs, DefaultParseOptions)
396+
return res, err
363397
}
364398

365399
// ParseExprsWithOptions parses a comma-delimited sequence of SQL scalar
366400
// expressions with the provided options. The caller is responsible for
367401
// ensuring that the input is, in fact, a comma-delimited sequence of SQL
368402
// scalar expressions — the results are undefined if the string contains
369403
// invalid SQL syntax.
370-
func ParseExprsWithOptions(exprs []string, opts ParseOptions) (tree.Exprs, error) {
404+
//
405+
// It also returns the number of annotations used.
406+
func ParseExprsWithOptions(
407+
exprs []string, opts ParseOptions,
408+
) (tree.Exprs, tree.AnnotationIdx, error) {
371409
if len(exprs) == 0 {
372-
return tree.Exprs{}, nil
410+
return tree.Exprs{}, 0, nil
373411
}
374412
stmt, err := ParseOneWithOptions(fmt.Sprintf("SET ROW (%s)", strings.Join(exprs, ",")), opts)
375413
if err != nil {
376-
return nil, err
414+
return nil, 0, err
377415
}
378416
set, ok := stmt.AST.(*tree.SetVar)
379417
if !ok {
380-
return nil, errors.AssertionFailedf("expected a SET statement, but found %T", stmt)
418+
return nil, 0, errors.AssertionFailedf("expected a SET statement, but found %T", stmt)
381419
}
382-
return set.Values, nil
420+
return set.Values, stmt.NumAnnotations, nil
383421
}
384422

385423
// ParseExpr parses a SQL scalar expression. The caller is responsible

pkg/sql/parser/testdata/do

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,3 +406,45 @@ END IF;
406406
END;
407407
$$;
408408
-- identifiers removed
409+
410+
parse
411+
DO $$
412+
DECLARE
413+
foo bar;
414+
_ _[];
415+
BEGIN
416+
END;
417+
$$;
418+
----
419+
DO $$
420+
DECLARE
421+
foo bar;
422+
_ _[];
423+
BEGIN
424+
END;
425+
$$;
426+
-- normalized!
427+
DO $$
428+
DECLARE
429+
foo bar;
430+
_ _[];
431+
BEGIN
432+
END;
433+
$$;
434+
-- fully parenthesized
435+
DO $$
436+
DECLARE
437+
foo bar;
438+
_ _[];
439+
BEGIN
440+
END;
441+
$$;
442+
-- literals removed
443+
DO $$
444+
DECLARE
445+
_ _;
446+
_ _[];
447+
BEGIN
448+
END;
449+
$$;
450+
-- identifiers removed

0 commit comments

Comments
 (0)