Skip to content

Commit 639ea4f

Browse files
committed
plpgsql: fix handling of annotations for DO blocks
This commit should fix, for good, the handling of annotations for DO blocks, which 3b8ab38 recently attempted to fix. (That commit is also partially reverted.) The general issue is that in the PLpgSQL parser we use an ephemeral SQL parser instance to process SQL statements. However, when processing the stmts of a DO block, we need to preserve the annotations state from one stmt to the next. This is now achieved by extending the SQL parser infrastructure with an option to override the initial annotations index and to keep reusing the counter across multiple stmts. This allows us to correctly count the number of annotations needed for the whole DO block during the initial parsing; we then use the allocated annotations container to set the items in the optbuilder. Release note (bug fix): Previously, a CockroachDB node could crash when executing DO stmts that contained user-defined types (possibly non-existent) in a non-default configuration (additional logging, such as that controlled by the `sql.log.all_statements.enabled` cluster setting, needed to be enabled). This bug was introduced in the 25.1 release and is now fixed.
1 parent 35d4de5 commit 639ea4f

File tree

12 files changed

+284
-161
lines changed

12 files changed

+284
-161
lines changed

pkg/sql/catalog/redact/redact_test.go

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,18 @@ func TestRedactQueries(t *testing.T) {
3232
defer srv.Stopper().Stop(ctx)
3333
codec := srv.ApplicationLayer().Codec()
3434
tdb := sqlutils.MakeSQLRunner(db)
35-
tdb.Exec(t, "CREATE TABLE kv (k INT PRIMARY KEY, v STRING)")
36-
tdb.Exec(t, "CREATE VIEW view AS SELECT k, v FROM kv WHERE v <> 'constant literal'")
37-
tdb.Exec(t, "CREATE TABLE ctas AS SELECT k, v FROM kv WHERE v <> 'constant literal'")
35+
tdb.Exec(t, "CREATE TYPE my_enum AS ENUM ('foo', 'bar')")
36+
tdb.Exec(t, "CREATE TABLE kv (k INT PRIMARY KEY, v STRING, e my_enum)")
37+
tdb.Exec(t, "CREATE VIEW view AS SELECT k, v, e FROM kv WHERE v <> 'constant literal' AND e <> 'foo'")
38+
tdb.Exec(t, "CREATE TABLE ctas AS SELECT k, v, e FROM kv WHERE v <> 'constant literal' AND e <> 'foo'")
3839
tdb.Exec(t, `
3940
CREATE FUNCTION f1() RETURNS INT
4041
LANGUAGE SQL
4142
AS $$
4243
SELECT k FROM kv WHERE v != 'foo';
4344
SELECT k FROM kv WHERE v = 'bar';
45+
SELECT k FROM kv WHERE e != 'foo';
46+
SELECT k FROM kv WHERE e = 'bar';
4447
$$`)
4548
tdb.Exec(t, `
4649
CREATE FUNCTION f2() RETURNS INT
@@ -49,8 +52,9 @@ AS $$
4952
DECLARE
5053
x INT := 0;
5154
y TEXT := 'bar';
55+
z my_enum;
5256
BEGIN
53-
SELECT k FROM kv WHERE v != 'foo';
57+
SELECT k FROM kv WHERE v != 'foo' AND e != 'bar'::my_enum;
5458
RETURN x + 3;
5559
END;
5660
$$`)
@@ -61,7 +65,7 @@ $$`)
6165
)
6266
mut := tabledesc.NewBuilder(view.TableDesc()).BuildCreatedMutableTable()
6367
require.Empty(t, redact.Redact(mut.DescriptorProto()))
64-
require.Equal(t, `SELECT k, v FROM defaultdb.public.kv WHERE v != '_'`, mut.ViewQuery)
68+
require.Equal(t, `SELECT k, v, e FROM defaultdb.public.kv WHERE (v != '_') AND (e != '_')`, mut.ViewQuery)
6569
})
6670

6771
t.Run("create table as", func(t *testing.T) {
@@ -70,14 +74,14 @@ $$`)
7074
)
7175
mut := tabledesc.NewBuilder(ctas.TableDesc()).BuildCreatedMutableTable()
7276
require.Empty(t, redact.Redact(mut.DescriptorProto()))
73-
require.Equal(t, `SELECT k, v FROM defaultdb.public.kv WHERE v != '_'`, mut.CreateQuery)
77+
require.Equal(t, `SELECT k, v, e FROM defaultdb.public.kv WHERE (v != '_') AND (e != '_')`, mut.CreateQuery)
7478
})
7579

7680
t.Run("create function sql", func(t *testing.T) {
7781
fn := desctestutils.TestingGetFunctionDescriptor(kvDB, codec, "defaultdb", "public", "f1")
7882
mut := funcdesc.NewBuilder(fn.FuncDesc()).BuildCreatedMutableFunction()
7983
require.Empty(t, redact.Redact(mut.DescriptorProto()))
80-
require.Equal(t, `SELECT k FROM defaultdb.public.kv WHERE v != '_'; SELECT k FROM defaultdb.public.kv WHERE v = '_';`, mut.FunctionBody)
84+
require.Equal(t, `SELECT k FROM defaultdb.public.kv WHERE v != '_'; SELECT k FROM defaultdb.public.kv WHERE v = '_'; SELECT k FROM defaultdb.public.kv WHERE e != '_'; SELECT k FROM defaultdb.public.kv WHERE e = '_';`, mut.FunctionBody)
8185
})
8286

8387
t.Run("create function plpgsql", func(t *testing.T) {
@@ -87,8 +91,9 @@ $$`)
8791
require.Equal(t, `DECLARE
8892
x INT8 := _;
8993
y STRING := '_';
94+
z @100104;
9095
BEGIN
91-
SELECT k FROM defaultdb.public.kv WHERE v != '_';
96+
SELECT k FROM defaultdb.public.kv WHERE (v != '_') AND (e != '_':::@100104);
9297
RETURN x + _;
9398
END;
9499
`, mut.FunctionBody)

pkg/sql/logictest/testdata/logic_test/do

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,3 +285,23 @@ END
285285
$$;
286286

287287
subtest end
288+
289+
subtest regression_151848
290+
291+
statement error pgcode 42704 type \"bar\" does not exist
292+
DO $$
293+
DECLARE
294+
foo bar;
295+
BEGIN
296+
END;
297+
$$;
298+
299+
statement error pgcode 42704 type \"_\" does not exist
300+
DO $$
301+
DECLARE
302+
foo _[];
303+
BEGIN
304+
END;
305+
$$;
306+
307+
subtest end

pkg/sql/opt/optbuilder/routine.go

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -840,8 +840,19 @@ func (b *Builder) buildDo(do *tree.DoBlock, inScope *scope) *scope {
840840
// TODO(drewk): Enable memo reuse with DO statements.
841841
b.DisableMemoReuse = true
842842

843-
defer func(oldInsideFuncDep bool) { b.insideFuncDef = oldInsideFuncDep }(b.insideFuncDef)
843+
doBlockImpl, ok := do.Code.(*plpgsqltree.DoBlock)
844+
if !ok {
845+
panic(errors.AssertionFailedf("expected a plpgsql block"))
846+
}
847+
848+
defer func(oldInsideFuncDep bool, oldAnn tree.Annotations) {
849+
b.insideFuncDef = oldInsideFuncDep
850+
b.semaCtx.Annotations = oldAnn
851+
b.evalCtx.Annotations = &b.semaCtx.Annotations
852+
}(b.insideFuncDef, b.semaCtx.Annotations)
844853
b.insideFuncDef = true
854+
b.semaCtx.Annotations = doBlockImpl.Annotations
855+
b.evalCtx.Annotations = &b.semaCtx.Annotations
845856

846857
// Build the routine body.
847858
var bodyStmts []string
@@ -850,10 +861,6 @@ func (b *Builder) buildDo(do *tree.DoBlock, inScope *scope) *scope {
850861
fmtCtx.FormatNode(do.Code)
851862
bodyStmts = []string{fmtCtx.CloseAndGetString()}
852863
}
853-
doBlockImpl, ok := do.Code.(*plpgsqltree.DoBlock)
854-
if !ok {
855-
panic(errors.AssertionFailedf("expected a plpgsql block"))
856-
}
857864
bodyScope := b.buildPLpgSQLDoBody(doBlockImpl)
858865

859866
// Build a CALL expression that invokes the routine.

pkg/sql/parser/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ go_library(
3737
"//pkg/sql/sem/tree/treecmp", # keep
3838
"//pkg/sql/sem/tree/treewindow", # keep
3939
"//pkg/sql/types",
40+
"//pkg/util/buildutil",
4041
"//pkg/util/errorutil/unimplemented",
4142
"//pkg/util/vector", # keep
4243
"@com_github_cockroachdb_errors//:errors",

pkg/sql/parser/lexer.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,17 @@ type lexer struct {
3838
lastError error
3939
}
4040

41-
func (l *lexer) init(sql string, tokens []sqlSymType, nakedIntType *types.T) {
41+
// numAnnotations indicates the number of annotations that have already been
42+
// claimed.
43+
func (l *lexer) init(
44+
sql string, tokens []sqlSymType, nakedIntType *types.T, numAnnotations tree.AnnotationIdx,
45+
) {
4246
l.in = sql
4347
l.tokens = tokens
4448
l.lastPos = -1
4549
l.stmt = nil
4650
l.numPlaceholders = 0
47-
l.numAnnotations = 0
51+
l.numAnnotations = numAnnotations
4852
l.lastError = nil
4953

5054
l.nakedIntType = nakedIntType

pkg/sql/parser/lexer_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ func TestLexer(t *testing.T) {
3636
scanTokens = append(scanTokens, lval)
3737
}
3838
var l lexer
39-
l.init(d.sql, scanTokens, defaultNakedIntType)
39+
l.init(d.sql, scanTokens, defaultNakedIntType, 0 /* numAnnotations */)
4040
var lexTokens []int
4141
for {
4242
var lval sqlSymType

pkg/sql/parser/parse.go

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
"github.com/cockroachdb/cockroach/pkg/sql/scanner"
2626
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
2727
"github.com/cockroachdb/cockroach/pkg/sql/types"
28+
"github.com/cockroachdb/cockroach/pkg/util/buildutil"
2829
"github.com/cockroachdb/errors"
2930
)
3031

@@ -47,6 +48,7 @@ type Parser struct {
4748
type ParseOptions struct {
4849
intType *types.T
4950
retainComments bool
51+
numAnnotations *tree.AnnotationIdx
5052
}
5153

5254
var DefaultParseOptions = ParseOptions{
@@ -64,6 +66,18 @@ func (po ParseOptions) WithIntType(t *types.T) ParseOptions {
6466
return po
6567
}
6668

69+
// WithNumAnnotations overrides how annotations are handled. If this option is
70+
// used, then
71+
// - the given integer indicates the number of annotations already claimed, and
72+
// - if multiple stmts are parsed, then unique annotation indexes are used
73+
// across all stmts.
74+
//
75+
// When this option is not used, then each stmt is parsed independently.
76+
func (po ParseOptions) WithNumAnnotations(numAnnotations tree.AnnotationIdx) ParseOptions {
77+
po.numAnnotations = &numAnnotations
78+
return po
79+
}
80+
6781
// INT8 is the historical interpretation of INT. This should be left
6882
// alone in the future, since there are many sql fragments stored
6983
// in various descriptors. Any user input that was created after
@@ -168,14 +182,29 @@ func (p *Parser) parseWithDepth(
168182
p.scanner.RetainComments()
169183
}
170184
defer p.scanner.Cleanup()
185+
var numAnnotations tree.AnnotationIdx
186+
if options.numAnnotations != nil {
187+
numAnnotations = *options.numAnnotations
188+
}
171189
for {
172190
sql, tokens, done := p.scanOneStmt()
173-
stmt, err := p.parse(depth+1, sql, tokens, options.intType)
191+
stmt, err := p.parse(depth+1, sql, tokens, options.intType, numAnnotations)
174192
if err != nil {
175193
return nil, err
176194
}
177195
if stmt.AST != nil {
178196
stmts = append(stmts, stmt)
197+
if options.numAnnotations != nil {
198+
if buildutil.CrdbTestBuild && numAnnotations > stmt.NumAnnotations {
199+
return nil, errors.AssertionFailedf(
200+
"annotation index has regressed: numAnnotations=%d, stmt.NumAnnotations=%d ",
201+
numAnnotations, stmt.NumAnnotations,
202+
)
203+
}
204+
// If this stmt used any annotations, we need to advance the
205+
// number of annotations accordingly.
206+
numAnnotations = stmt.NumAnnotations
207+
}
179208
}
180209
if done {
181210
break
@@ -186,9 +215,13 @@ func (p *Parser) parseWithDepth(
186215

187216
// parse parses a statement from the given scanned tokens.
188217
func (p *Parser) parse(
189-
depth int, sql string, tokens []sqlSymType, nakedIntType *types.T,
218+
depth int,
219+
sql string,
220+
tokens []sqlSymType,
221+
nakedIntType *types.T,
222+
numAnnotations tree.AnnotationIdx,
190223
) (statements.Statement[tree.Statement], error) {
191-
p.lexer.init(sql, tokens, nakedIntType)
224+
p.lexer.init(sql, tokens, nakedIntType, numAnnotations)
192225
defer p.lexer.cleanup()
193226
if p.parserImpl.Parse(&p.lexer) != 0 {
194227
if p.lexer.lastError == nil {
@@ -371,27 +404,32 @@ func ParseTablePattern(sql string) (tree.TablePattern, error) {
371404
// the results are undefined if the string contains invalid SQL
372405
// syntax.
373406
func ParseExprs(exprs []string) (tree.Exprs, error) {
374-
return ParseExprsWithOptions(exprs, DefaultParseOptions)
407+
res, _, err := ParseExprsWithOptions(exprs, DefaultParseOptions)
408+
return res, err
375409
}
376410

377411
// ParseExprsWithOptions parses a comma-delimited sequence of SQL scalar
378412
// expressions with the provided options. The caller is responsible for
379413
// ensuring that the input is, in fact, a comma-delimited sequence of SQL
380414
// scalar expressions — the results are undefined if the string contains
381415
// invalid SQL syntax.
382-
func ParseExprsWithOptions(exprs []string, opts ParseOptions) (tree.Exprs, error) {
416+
//
417+
// It also returns the number of annotations used.
418+
func ParseExprsWithOptions(
419+
exprs []string, opts ParseOptions,
420+
) (tree.Exprs, tree.AnnotationIdx, error) {
383421
if len(exprs) == 0 {
384-
return tree.Exprs{}, nil
422+
return tree.Exprs{}, 0, nil
385423
}
386424
stmt, err := ParseOneWithOptions(fmt.Sprintf("SET ROW (%s)", strings.Join(exprs, ",")), opts)
387425
if err != nil {
388-
return nil, err
426+
return nil, 0, err
389427
}
390428
set, ok := stmt.AST.(*tree.SetVar)
391429
if !ok {
392-
return nil, errors.AssertionFailedf("expected a SET statement, but found %T", stmt)
430+
return nil, 0, errors.AssertionFailedf("expected a SET statement, but found %T", stmt)
393431
}
394-
return set.Values, nil
432+
return set.Values, stmt.NumAnnotations, nil
395433
}
396434

397435
// ParseExpr parses a SQL scalar expression. The caller is responsible

pkg/sql/parser/testdata/do

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,3 +406,45 @@ END IF;
406406
END;
407407
$$;
408408
-- identifiers removed
409+
410+
parse
411+
DO $$
412+
DECLARE
413+
foo bar;
414+
_ _[];
415+
BEGIN
416+
END;
417+
$$;
418+
----
419+
DO $$
420+
DECLARE
421+
foo bar;
422+
_ _[];
423+
BEGIN
424+
END;
425+
$$;
426+
-- normalized!
427+
DO $$
428+
DECLARE
429+
foo bar;
430+
_ _[];
431+
BEGIN
432+
END;
433+
$$;
434+
-- fully parenthesized
435+
DO $$
436+
DECLARE
437+
foo bar;
438+
_ _[];
439+
BEGIN
440+
END;
441+
$$;
442+
-- literals removed
443+
DO $$
444+
DECLARE
445+
_ _;
446+
_ _[];
447+
BEGIN
448+
END;
449+
$$;
450+
-- identifiers removed

0 commit comments

Comments
 (0)