Skip to content

Commit 0602bf9

Browse files
authored
Merge pull request #153913 from yuzefovich/blathers/backport-release-25.3-151849
release-25.3: plpgsql: fix handling of annotations for DO blocks
2 parents 35d4de5 + 2d46197 commit 0602bf9

File tree

13 files changed

+292
-161
lines changed

13 files changed

+292
-161
lines changed

pkg/sql/catalog/redact/redact_test.go

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,18 @@ func TestRedactQueries(t *testing.T) {
3232
defer srv.Stopper().Stop(ctx)
3333
codec := srv.ApplicationLayer().Codec()
3434
tdb := sqlutils.MakeSQLRunner(db)
35-
tdb.Exec(t, "CREATE TABLE kv (k INT PRIMARY KEY, v STRING)")
36-
tdb.Exec(t, "CREATE VIEW view AS SELECT k, v FROM kv WHERE v <> 'constant literal'")
37-
tdb.Exec(t, "CREATE TABLE ctas AS SELECT k, v FROM kv WHERE v <> 'constant literal'")
35+
tdb.Exec(t, "CREATE TYPE my_enum AS ENUM ('foo', 'bar')")
36+
tdb.Exec(t, "CREATE TABLE kv (k INT PRIMARY KEY, v STRING, e my_enum)")
37+
tdb.Exec(t, "CREATE VIEW view AS SELECT k, v, e FROM kv WHERE v <> 'constant literal' AND e <> 'foo'")
38+
tdb.Exec(t, "CREATE TABLE ctas AS SELECT k, v, e FROM kv WHERE v <> 'constant literal' AND e <> 'foo'")
3839
tdb.Exec(t, `
3940
CREATE FUNCTION f1() RETURNS INT
4041
LANGUAGE SQL
4142
AS $$
4243
SELECT k FROM kv WHERE v != 'foo';
4344
SELECT k FROM kv WHERE v = 'bar';
45+
SELECT k FROM kv WHERE e != 'foo';
46+
SELECT k FROM kv WHERE e = 'bar';
4447
$$`)
4548
tdb.Exec(t, `
4649
CREATE FUNCTION f2() RETURNS INT
@@ -49,8 +52,9 @@ AS $$
4952
DECLARE
5053
x INT := 0;
5154
y TEXT := 'bar';
55+
z my_enum;
5256
BEGIN
53-
SELECT k FROM kv WHERE v != 'foo';
57+
SELECT k FROM kv WHERE v != 'foo' AND e != 'bar'::my_enum;
5458
RETURN x + 3;
5559
END;
5660
$$`)
@@ -61,7 +65,7 @@ $$`)
6165
)
6266
mut := tabledesc.NewBuilder(view.TableDesc()).BuildCreatedMutableTable()
6367
require.Empty(t, redact.Redact(mut.DescriptorProto()))
64-
require.Equal(t, `SELECT k, v FROM defaultdb.public.kv WHERE v != '_'`, mut.ViewQuery)
68+
require.Equal(t, `SELECT k, v, e FROM defaultdb.public.kv WHERE (v != '_') AND (e != '_')`, mut.ViewQuery)
6569
})
6670

6771
t.Run("create table as", func(t *testing.T) {
@@ -70,14 +74,14 @@ $$`)
7074
)
7175
mut := tabledesc.NewBuilder(ctas.TableDesc()).BuildCreatedMutableTable()
7276
require.Empty(t, redact.Redact(mut.DescriptorProto()))
73-
require.Equal(t, `SELECT k, v FROM defaultdb.public.kv WHERE v != '_'`, mut.CreateQuery)
77+
require.Equal(t, `SELECT k, v, e FROM defaultdb.public.kv WHERE (v != '_') AND (e != '_')`, mut.CreateQuery)
7478
})
7579

7680
t.Run("create function sql", func(t *testing.T) {
7781
fn := desctestutils.TestingGetFunctionDescriptor(kvDB, codec, "defaultdb", "public", "f1")
7882
mut := funcdesc.NewBuilder(fn.FuncDesc()).BuildCreatedMutableFunction()
7983
require.Empty(t, redact.Redact(mut.DescriptorProto()))
80-
require.Equal(t, `SELECT k FROM defaultdb.public.kv WHERE v != '_'; SELECT k FROM defaultdb.public.kv WHERE v = '_';`, mut.FunctionBody)
84+
require.Equal(t, `SELECT k FROM defaultdb.public.kv WHERE v != '_'; SELECT k FROM defaultdb.public.kv WHERE v = '_'; SELECT k FROM defaultdb.public.kv WHERE e != '_'; SELECT k FROM defaultdb.public.kv WHERE e = '_';`, mut.FunctionBody)
8185
})
8286

8387
t.Run("create function plpgsql", func(t *testing.T) {
@@ -87,8 +91,9 @@ $$`)
8791
require.Equal(t, `DECLARE
8892
x INT8 := _;
8993
y STRING := '_';
94+
z @100104;
9095
BEGIN
91-
SELECT k FROM defaultdb.public.kv WHERE v != '_';
96+
SELECT k FROM defaultdb.public.kv WHERE (v != '_') AND (e != '_':::@100104);
9297
RETURN x + _;
9398
END;
9499
`, mut.FunctionBody)

pkg/sql/logictest/testdata/logic_test/do

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,3 +285,23 @@ END
285285
$$;
286286

287287
subtest end
288+
289+
subtest regression_151848
290+
291+
statement error pgcode 42704 type \"bar\" does not exist
292+
DO $$
293+
DECLARE
294+
foo bar;
295+
BEGIN
296+
END;
297+
$$;
298+
299+
statement error pgcode 42704 type \"_\" does not exist
300+
DO $$
301+
DECLARE
302+
foo _[];
303+
BEGIN
304+
END;
305+
$$;
306+
307+
subtest end

pkg/sql/opt/optbuilder/routine.go

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -840,8 +840,19 @@ func (b *Builder) buildDo(do *tree.DoBlock, inScope *scope) *scope {
840840
// TODO(drewk): Enable memo reuse with DO statements.
841841
b.DisableMemoReuse = true
842842

843-
defer func(oldInsideFuncDep bool) { b.insideFuncDef = oldInsideFuncDep }(b.insideFuncDef)
843+
doBlockImpl, ok := do.Code.(*plpgsqltree.DoBlock)
844+
if !ok {
845+
panic(errors.AssertionFailedf("expected a plpgsql block"))
846+
}
847+
848+
defer func(oldInsideFuncDep bool, oldAnn tree.Annotations) {
849+
b.insideFuncDef = oldInsideFuncDep
850+
b.semaCtx.Annotations = oldAnn
851+
b.evalCtx.Annotations = &b.semaCtx.Annotations
852+
}(b.insideFuncDef, b.semaCtx.Annotations)
844853
b.insideFuncDef = true
854+
b.semaCtx.Annotations = doBlockImpl.Annotations
855+
b.evalCtx.Annotations = &b.semaCtx.Annotations
845856

846857
// Build the routine body.
847858
var bodyStmts []string
@@ -850,10 +861,6 @@ func (b *Builder) buildDo(do *tree.DoBlock, inScope *scope) *scope {
850861
fmtCtx.FormatNode(do.Code)
851862
bodyStmts = []string{fmtCtx.CloseAndGetString()}
852863
}
853-
doBlockImpl, ok := do.Code.(*plpgsqltree.DoBlock)
854-
if !ok {
855-
panic(errors.AssertionFailedf("expected a plpgsql block"))
856-
}
857864
bodyScope := b.buildPLpgSQLDoBody(doBlockImpl)
858865

859866
// Build a CALL expression that invokes the routine.

pkg/sql/parser/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ go_library(
3737
"//pkg/sql/sem/tree/treecmp", # keep
3838
"//pkg/sql/sem/tree/treewindow", # keep
3939
"//pkg/sql/types",
40+
"//pkg/util/buildutil",
4041
"//pkg/util/errorutil/unimplemented",
4142
"//pkg/util/vector", # keep
4243
"@com_github_cockroachdb_errors//:errors",

pkg/sql/parser/lexer.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,17 @@ type lexer struct {
3838
lastError error
3939
}
4040

41-
func (l *lexer) init(sql string, tokens []sqlSymType, nakedIntType *types.T) {
41+
// numAnnotations indicates the number of annotations that have already been
42+
// claimed.
43+
func (l *lexer) init(
44+
sql string, tokens []sqlSymType, nakedIntType *types.T, numAnnotations tree.AnnotationIdx,
45+
) {
4246
l.in = sql
4347
l.tokens = tokens
4448
l.lastPos = -1
4549
l.stmt = nil
4650
l.numPlaceholders = 0
47-
l.numAnnotations = 0
51+
l.numAnnotations = numAnnotations
4852
l.lastError = nil
4953

5054
l.nakedIntType = nakedIntType

pkg/sql/parser/lexer_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ func TestLexer(t *testing.T) {
3636
scanTokens = append(scanTokens, lval)
3737
}
3838
var l lexer
39-
l.init(d.sql, scanTokens, defaultNakedIntType)
39+
l.init(d.sql, scanTokens, defaultNakedIntType, 0 /* numAnnotations */)
4040
var lexTokens []int
4141
for {
4242
var lval sqlSymType

pkg/sql/parser/parse.go

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
"github.com/cockroachdb/cockroach/pkg/sql/scanner"
2626
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
2727
"github.com/cockroachdb/cockroach/pkg/sql/types"
28+
"github.com/cockroachdb/cockroach/pkg/util/buildutil"
2829
"github.com/cockroachdb/errors"
2930
)
3031

@@ -47,6 +48,7 @@ type Parser struct {
4748
type ParseOptions struct {
4849
intType *types.T
4950
retainComments bool
51+
numAnnotations *tree.AnnotationIdx
5052
}
5153

5254
var DefaultParseOptions = ParseOptions{
@@ -64,6 +66,18 @@ func (po ParseOptions) WithIntType(t *types.T) ParseOptions {
6466
return po
6567
}
6668

69+
// WithNumAnnotations overrides how annotations are handled. If this option is
70+
// used, then
71+
// - the given integer indicates the number of annotations already claimed, and
72+
// - if multiple stmts are parsed, then unique annotation indexes are used
73+
// across all stmts.
74+
//
75+
// When this option is not used, then each stmt is parsed indepedently.
76+
func (po ParseOptions) WithNumAnnotations(numAnnotations tree.AnnotationIdx) ParseOptions {
77+
po.numAnnotations = &numAnnotations
78+
return po
79+
}
80+
6781
// INT8 is the historical interpretation of INT. This should be left
6882
// alone in the future, since there are many sql fragments stored
6983
// in various descriptors. Any user input that was created after
@@ -168,14 +182,29 @@ func (p *Parser) parseWithDepth(
168182
p.scanner.RetainComments()
169183
}
170184
defer p.scanner.Cleanup()
185+
var numAnnotations tree.AnnotationIdx
186+
if options.numAnnotations != nil {
187+
numAnnotations = *options.numAnnotations
188+
}
171189
for {
172190
sql, tokens, done := p.scanOneStmt()
173-
stmt, err := p.parse(depth+1, sql, tokens, options.intType)
191+
stmt, err := p.parse(depth+1, sql, tokens, options.intType, numAnnotations)
174192
if err != nil {
175193
return nil, err
176194
}
177195
if stmt.AST != nil {
178196
stmts = append(stmts, stmt)
197+
if options.numAnnotations != nil {
198+
if buildutil.CrdbTestBuild && numAnnotations > stmt.NumAnnotations {
199+
return nil, errors.AssertionFailedf(
200+
"annotation index has regressed: numAnnotations=%d, stmt.NumAnnotations=%d ",
201+
numAnnotations, stmt.NumAnnotations,
202+
)
203+
}
204+
// If this stmt used any annotations, we need to advance the
205+
// number of annotations accordingly.
206+
numAnnotations = stmt.NumAnnotations
207+
}
179208
}
180209
if done {
181210
break
@@ -186,9 +215,13 @@ func (p *Parser) parseWithDepth(
186215

187216
// parse parses a statement from the given scanned tokens.
188217
func (p *Parser) parse(
189-
depth int, sql string, tokens []sqlSymType, nakedIntType *types.T,
218+
depth int,
219+
sql string,
220+
tokens []sqlSymType,
221+
nakedIntType *types.T,
222+
numAnnotations tree.AnnotationIdx,
190223
) (statements.Statement[tree.Statement], error) {
191-
p.lexer.init(sql, tokens, nakedIntType)
224+
p.lexer.init(sql, tokens, nakedIntType, numAnnotations)
192225
defer p.lexer.cleanup()
193226
if p.parserImpl.Parse(&p.lexer) != 0 {
194227
if p.lexer.lastError == nil {
@@ -371,27 +404,32 @@ func ParseTablePattern(sql string) (tree.TablePattern, error) {
371404
// the results are undefined if the string contains invalid SQL
372405
// syntax.
373406
func ParseExprs(exprs []string) (tree.Exprs, error) {
374-
return ParseExprsWithOptions(exprs, DefaultParseOptions)
407+
res, _, err := ParseExprsWithOptions(exprs, DefaultParseOptions)
408+
return res, err
375409
}
376410

377411
// ParseExprsWithOptions parses a comma-delimited sequence of SQL scalar
378412
// expressions with the provided options. The caller is responsible for
379413
// ensuring that the input is, in fact, a comma-delimited sequence of SQL
380414
// scalar expressions — the results are undefined if the string contains
381415
// invalid SQL syntax.
382-
func ParseExprsWithOptions(exprs []string, opts ParseOptions) (tree.Exprs, error) {
416+
//
417+
// It also returns the number of annotations used.
418+
func ParseExprsWithOptions(
419+
exprs []string, opts ParseOptions,
420+
) (tree.Exprs, tree.AnnotationIdx, error) {
383421
if len(exprs) == 0 {
384-
return tree.Exprs{}, nil
422+
return tree.Exprs{}, 0, nil
385423
}
386424
stmt, err := ParseOneWithOptions(fmt.Sprintf("SET ROW (%s)", strings.Join(exprs, ",")), opts)
387425
if err != nil {
388-
return nil, err
426+
return nil, 0, err
389427
}
390428
set, ok := stmt.AST.(*tree.SetVar)
391429
if !ok {
392-
return nil, errors.AssertionFailedf("expected a SET statement, but found %T", stmt)
430+
return nil, 0, errors.AssertionFailedf("expected a SET statement, but found %T", stmt)
393431
}
394-
return set.Values, nil
432+
return set.Values, stmt.NumAnnotations, nil
395433
}
396434

397435
// ParseExpr parses a SQL scalar expression. The caller is responsible

pkg/sql/parser/testdata/do

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,3 +406,45 @@ END IF;
406406
END;
407407
$$;
408408
-- identifiers removed
409+
410+
parse
411+
DO $$
412+
DECLARE
413+
foo bar;
414+
_ _[];
415+
BEGIN
416+
END;
417+
$$;
418+
----
419+
DO $$
420+
DECLARE
421+
foo bar;
422+
_ _[];
423+
BEGIN
424+
END;
425+
$$;
426+
-- normalized!
427+
DO $$
428+
DECLARE
429+
foo bar;
430+
_ _[];
431+
BEGIN
432+
END;
433+
$$;
434+
-- fully parenthesized
435+
DO $$
436+
DECLARE
437+
foo bar;
438+
_ _[];
439+
BEGIN
440+
END;
441+
$$;
442+
-- literals removed
443+
DO $$
444+
DECLARE
445+
_ _;
446+
_ _[];
447+
BEGIN
448+
END;
449+
$$;
450+
-- identifiers removed

0 commit comments

Comments
 (0)