Skip to content

Commit fad7247

Browse files
committed
optbuilder: refactor routine output stmt finalization
This commit is a mechanical refactor to the logic that finalizes a routine's result type and last body statement. This change will make it easier for PL/pgSQL `RETURN NEXT` and `RETURN QUERY` statements to perform their own validation. Informs #105240 Release note: None
1 parent bf0922c commit fad7247

File tree

2 files changed

+85
-112
lines changed

2 files changed

+85
-112
lines changed

pkg/sql/opt/optbuilder/plpgsql.go

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1168,8 +1168,8 @@ func (b *plpgsqlBuilder) buildPLpgSQLStatements(stmts []ast.Statement, s *scope)
11681168
// statement, and then the following PL/pgSQL statements in the second.
11691169
doCon := b.makeContinuation("_stmt_do")
11701170
doCon.def.Volatility = volatility.Volatile
1171-
body, bodyProps := b.ob.buildPLpgSQLDoBody(t)
1172-
b.appendBodyStmt(&doCon, body, bodyProps)
1171+
bodyScope := b.ob.buildPLpgSQLDoBody(t)
1172+
b.appendBodyStmtFromScope(&doCon, bodyScope)
11731173
b.appendPlpgSQLStmts(&doCon, stmts[i+1:])
11741174
return b.callContinuation(&doCon, s)
11751175

@@ -2124,31 +2124,24 @@ func (b *plpgsqlBuilder) makeContinuationWithTyp(
21242124
return con
21252125
}
21262126

2127-
// appendBodyStmt adds the given body statement and its required properties to
2128-
// the definition of a continuation function. Only the last body statement will
2129-
// return results; all others will only be executed for their side effects
2130-
// (e.g. RAISE statement).
2127+
// appendBodyStmtFromScope adds the given body statement and its required
2128+
// properties from the given scope to the definition of a continuation function.
2129+
// Only the last body statement will return results; all others will only be
2130+
// executed for their side effects (e.g. RAISE statement).
21312131
//
2132-
// appendBodyStmt is separate from makeContinuation to allow recursive routine
2133-
// definitions, which need to push the continuation before it is finished. The
2134-
// separation also allows for appending multiple body statements.
2135-
func (b *plpgsqlBuilder) appendBodyStmt(
2136-
con *continuation, body memo.RelExpr, bodyProps *physical.Required,
2137-
) {
2132+
// appendBodyStmtFromScope is separate from makeContinuation to allow recursive
2133+
// routine definitions, which need to push the continuation before it is
2134+
// finished. The separation also allows for appending multiple body statements.
2135+
func (b *plpgsqlBuilder) appendBodyStmtFromScope(con *continuation, bodyScope *scope) {
21382136
// Set the volatility of the continuation routine to the least restrictive
21392137
// volatility level in the Relational properties of the body statements.
2140-
vol := body.Relational().VolatilitySet.ToVolatility()
2138+
bodyExpr := bodyScope.expr
2139+
vol := bodyExpr.Relational().VolatilitySet.ToVolatility()
21412140
if con.def.Volatility < vol {
21422141
con.def.Volatility = vol
21432142
}
2144-
con.def.Body = append(con.def.Body, body)
2145-
con.def.BodyProps = append(con.def.BodyProps, bodyProps)
2146-
}
2147-
2148-
// appendBodyStmtFromScope is similar to appendBodyStmt, but retrieves the body
2149-
// statement its required properties from the given scope for convenience.
2150-
func (b *plpgsqlBuilder) appendBodyStmtFromScope(con *continuation, bodyScope *scope) {
2151-
b.appendBodyStmt(con, bodyScope.expr, bodyScope.makePhysicalProps())
2143+
con.def.Body = append(con.def.Body, bodyExpr)
2144+
con.def.BodyProps = append(con.def.BodyProps, bodyScope.makePhysicalProps())
21522145
}
21532146

21542147
// appendPlpgSQLStmts builds the given PLpgSQL statements into a relational

pkg/sql/opt/optbuilder/routine.go

Lines changed: 71 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ import (
1111
"github.com/cockroachdb/cockroach/pkg/security/username"
1212
"github.com/cockroachdb/cockroach/pkg/sql/opt"
1313
"github.com/cockroachdb/cockroach/pkg/sql/opt/memo"
14-
"github.com/cockroachdb/cockroach/pkg/sql/opt/props"
1514
"github.com/cockroachdb/cockroach/pkg/sql/opt/props/physical"
1615
"github.com/cockroachdb/cockroach/pkg/sql/parser"
1716
"github.com/cockroachdb/cockroach/pkg/sql/parser/statements"
@@ -23,6 +22,7 @@ import (
2322
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
2423
"github.com/cockroachdb/cockroach/pkg/sql/sem/volatility"
2524
"github.com/cockroachdb/cockroach/pkg/sql/types"
25+
"github.com/cockroachdb/cockroach/pkg/util/buildutil"
2626
"github.com/cockroachdb/cockroach/pkg/util/errorutil/unimplemented"
2727
"github.com/cockroachdb/errors"
2828
)
@@ -405,17 +405,14 @@ func (b *Builder) buildRoutine(
405405

406406
for i := range stmts {
407407
stmtScope := b.buildStmtAtRootWithScope(stmts[i].AST, nil /* desiredTypes */, bodyScope)
408-
expr, physProps := stmtScope.expr, stmtScope.makePhysicalProps()
409408

410409
// The last statement produces the output of the UDF.
411410
if i == len(stmts)-1 {
412411
rTyp := b.finalizeRoutineReturnType(f, stmtScope, inScope, oldInsideDataSource)
413-
expr, physProps = b.finishBuildLastStmt(
414-
stmtScope, bodyScope, isSetReturning, oldInsideDataSource, rTyp,
415-
)
412+
stmtScope = b.finishRoutineReturnStmt(stmtScope, isSetReturning, oldInsideDataSource, rTyp)
416413
}
417-
body[i] = expr
418-
bodyProps[i] = physProps
414+
body[i] = stmtScope.expr
415+
bodyProps[i] = stmtScope.makePhysicalProps()
419416
}
420417

421418
if b.verboseTracing {
@@ -443,19 +440,15 @@ func (b *Builder) buildRoutine(
443440
class: param.Class,
444441
})
445442
}
446-
var expr memo.RelExpr
447-
var physProps *physical.Required
448443
options := basePLOptions().SetIsProcedure(isProc)
449444
plBuilder := newPLpgSQLBuilder(
450445
b, options, def.Name, stmt.AST.Label, colRefs, routineParams, f.ResolvedType(), outScope,
451446
)
452447
stmtScope := plBuilder.buildRootBlock(stmt.AST, bodyScope, routineParams)
453448
rTyp := b.finalizeRoutineReturnType(f, stmtScope, inScope, oldInsideDataSource)
454-
expr, physProps = b.finishBuildLastStmt(
455-
stmtScope, bodyScope, isSetReturning, oldInsideDataSource, rTyp,
456-
)
457-
body = []memo.RelExpr{expr}
458-
bodyProps = []*physical.Required{physProps}
449+
stmtScope = b.finishRoutineReturnStmt(stmtScope, isSetReturning, oldInsideDataSource, rTyp)
450+
body = []memo.RelExpr{stmtScope.expr}
451+
bodyProps = []*physical.Required{stmtScope.makePhysicalProps()}
459452
if b.verboseTracing {
460453
bodyStmts = []string{stmt.String()}
461454
}
@@ -486,44 +479,37 @@ func (b *Builder) buildRoutine(
486479
return routine
487480
}
488481

489-
// finishBuildLastStmt manages the columns returned by the last statement of a
490-
// routine. Depending on the context and return type of the routine, this may
491-
// mean expanding a tuple into multiple columns, or combining multiple columns
492-
// into a tuple.
493-
func (b *Builder) finishBuildLastStmt(
494-
stmtScope, bodyScope *scope, isSetReturning, insideDataSource bool, rTyp *types.T,
495-
) (expr memo.RelExpr, physProps *physical.Required) {
482+
// finishRoutineReturnStmt manages the output columns for a statement that will
483+
// be added to the result set of a routine. Depending on the context and return
484+
// type of the routine, this may mean expanding a tuple into multiple columns,
485+
// or combining multiple columns into a tuple.
486+
func (b *Builder) finishRoutineReturnStmt(
487+
stmtScope *scope, isSetReturning, insideDataSource bool, rTyp *types.T,
488+
) *scope {
496489
// NOTE: the result columns of the last statement may not reflect the return
497490
// type until after the call to maybeAddRoutineAssignmentCasts. Therefore, the
498491
// logic below must take care in distinguishing the resolved return type from
499492
// the result column type(s).
500-
expr, physProps = stmtScope.expr, stmtScope.makePhysicalProps()
501-
493+
//
502494
// Add a LIMIT 1 to the last statement if the UDF is not
503495
// set-returning. This is valid because any other rows after the
504496
// first can simply be ignored. The limit could be beneficial
505497
// because it could allow additional optimization.
506498
if !isSetReturning {
507499
b.buildLimit(&tree.Limit{Count: tree.NewDInt(1)}, b.allocScope(), stmtScope)
508-
expr = stmtScope.expr
509-
// The limit expression will maintain the desired ordering, if any,
510-
// so the physical props ordering can be cleared. The presentation
511-
// must remain.
512-
physProps.Ordering = props.OrderingChoice{}
513500
}
514501

515502
// Depending on the context in which the UDF was called, it may be necessary
516503
// to either combine multiple result columns into a tuple, or to expand a
517504
// tuple result column into multiple columns.
518-
cols := physProps.Presentation
519-
scopeCols := stmtScope.cols
520-
isSingleTupleResult := len(scopeCols) == 1 && scopeCols[0].typ.Family() == types.TupleFamily
505+
isSingleTupleResult := len(stmtScope.cols) == 1 &&
506+
stmtScope.cols[0].typ.Family() == types.TupleFamily
521507
if insideDataSource {
522508
// The UDF is a data source. If it returns a composite type and the last
523509
// statement returns a single tuple column, the elements of the column
524510
// should be expanded into individual columns.
525511
if rTyp.Family() == types.TupleFamily && isSingleTupleResult {
526-
expr, physProps = b.expandRoutineTupleIntoCols(cols[0].ID, bodyScope.push(), expr)
512+
stmtScope = b.expandRoutineTupleIntoCols(stmtScope)
527513
}
528514
} else {
529515
// Only a single column can be returned from a routine, unless it is a UDF
@@ -533,18 +519,16 @@ func (b *Builder) finishBuildLastStmt(
533519
// 2. The routine returns RECORD, and the (single) result column cannot
534520
// be coerced to the return type. Note that a procedure with OUT-params
535521
// always wraps the OUT-param types in a record.
536-
if len(cols) > 1 || (rTyp.Family() == types.TupleFamily && !scopeCols[0].typ.Equivalent(rTyp) &&
537-
!cast.ValidCast(scopeCols[0].typ, rTyp, cast.ContextAssignment)) {
538-
expr, physProps = b.combineRoutineColsIntoTuple(cols, bodyScope.push(), expr)
522+
if len(stmtScope.cols) > 1 ||
523+
(rTyp.Family() == types.TupleFamily && !stmtScope.cols[0].typ.Equivalent(rTyp) &&
524+
!cast.ValidCast(stmtScope.cols[0].typ, rTyp, cast.ContextAssignment)) {
525+
stmtScope = b.combineRoutineColsIntoTuple(stmtScope)
539526
}
540527
}
541528

542-
// We must preserve the presentation of columns as physical properties to
543-
// prevent the optimizer from pruning the output column(s). If necessary, we
544-
// add an assignment cast to the result column(s) so that its type matches the
545-
// function return type.
546-
cols = physProps.Presentation
547-
return b.maybeAddRoutineAssignmentCasts(cols, bodyScope, rTyp, expr, physProps, insideDataSource)
529+
// If necessary, we add an assignment cast to the result column(s) so that its
530+
// type matches the function return type.
531+
return b.maybeAddRoutineAssignmentCasts(stmtScope, rTyp, insideDataSource)
548532
}
549533

550534
// finalizeRoutineReturnType updates the routine's return type, taking into
@@ -599,51 +583,51 @@ func (b *Builder) finalizeRoutineReturnType(
599583

600584
// combineRoutineColsIntoTuple is a helper to combine individual result columns
601585
// into a single tuple column.
602-
func (b *Builder) combineRoutineColsIntoTuple(
603-
cols physical.Presentation, stmtScope *scope, inputExpr memo.RelExpr,
604-
) (memo.RelExpr, *physical.Required) {
605-
elems := make(memo.ScalarListExpr, len(cols))
606-
typContents := make([]*types.T, len(cols))
607-
for i := range cols {
608-
elems[i] = b.factory.ConstructVariable(cols[i].ID)
609-
typContents[i] = b.factory.Metadata().ColumnMeta(cols[i].ID).Type
586+
func (b *Builder) combineRoutineColsIntoTuple(stmtScope *scope) *scope {
587+
outScope := stmtScope.push()
588+
elems := make(memo.ScalarListExpr, len(stmtScope.cols))
589+
typContents := make([]*types.T, len(stmtScope.cols))
590+
for i := range stmtScope.cols {
591+
elems[i] = b.factory.ConstructVariable(stmtScope.cols[i].id)
592+
typContents[i] = stmtScope.cols[i].typ
610593
}
611594
colTyp := types.MakeTuple(typContents)
612595
tup := b.factory.ConstructTuple(elems, colTyp)
613-
col := b.synthesizeColumn(stmtScope, scopeColName(""), colTyp, nil /* expr */, tup)
614-
return b.constructProject(inputExpr, []scopeColumn{*col}), stmtScope.makePhysicalProps()
596+
b.synthesizeColumn(outScope, scopeColName(""), colTyp, nil /* expr */, tup)
597+
b.constructProjectForScope(stmtScope, outScope)
598+
return outScope
615599
}
616600

617601
// expandRoutineTupleIntoCols is a helper to expand the elements of a single
618602
// tuple result column into individual result columns.
619-
func (b *Builder) expandRoutineTupleIntoCols(
620-
tupleColID opt.ColumnID, stmtScope *scope, inputExpr memo.RelExpr,
621-
) (memo.RelExpr, *physical.Required) {
603+
func (b *Builder) expandRoutineTupleIntoCols(stmtScope *scope) *scope {
604+
// Assume that the input scope has a single tuple column.
605+
if buildutil.CrdbTestBuild {
606+
if len(stmtScope.cols) != 1 {
607+
panic(errors.AssertionFailedf("expected a single tuple column"))
608+
}
609+
}
610+
tupleColID := stmtScope.cols[0].id
611+
outScope := stmtScope.push()
622612
colTyp := b.factory.Metadata().ColumnMeta(tupleColID).Type
623-
elems := make([]scopeColumn, len(colTyp.TupleContents()))
624613
for i := range colTyp.TupleContents() {
625614
varExpr := b.factory.ConstructVariable(tupleColID)
626615
e := b.factory.ConstructColumnAccess(varExpr, memo.TupleOrdinal(i))
627-
col := b.synthesizeColumn(stmtScope, scopeColName(""), colTyp.TupleContents()[i], nil, e)
628-
elems[i] = *col
616+
b.synthesizeColumn(outScope, scopeColName(""), colTyp.TupleContents()[i], nil, e)
629617
}
630-
return b.constructProject(inputExpr, elems), stmtScope.makePhysicalProps()
618+
b.constructProjectForScope(stmtScope, outScope)
619+
return outScope
631620
}
632621

633622
// maybeAddRoutineAssignmentCasts checks whether the result columns of the last
634623
// statement in a routine match up with the return type. If not, it attempts to
635624
// assignment-cast the columns to the correct type.
636625
func (b *Builder) maybeAddRoutineAssignmentCasts(
637-
cols physical.Presentation,
638-
bodyScope *scope,
639-
rTyp *types.T,
640-
expr memo.RelExpr,
641-
physProps *physical.Required,
642-
insideDataSource bool,
643-
) (memo.RelExpr, *physical.Required) {
626+
stmtScope *scope, rTyp *types.T, insideDataSource bool,
627+
) *scope {
644628
if rTyp.Family() == types.VoidFamily {
645629
// Void routines don't return a result, so a cast is not necessary.
646-
return expr, physProps
630+
return stmtScope
647631
}
648632
var desiredTypes []*types.T
649633
if insideDataSource && rTyp.Family() == types.TupleFamily {
@@ -655,37 +639,35 @@ func (b *Builder) maybeAddRoutineAssignmentCasts(
655639
// type.
656640
desiredTypes = []*types.T{rTyp}
657641
}
658-
if len(desiredTypes) != len(cols) {
642+
if len(desiredTypes) != len(stmtScope.cols) {
659643
panic(errors.AssertionFailedf("expected types and cols to be the same length"))
660644
}
661645
needCast := false
662-
md := b.factory.Metadata()
663-
for i, col := range cols {
664-
colTyp, expectedTyp := md.ColumnMeta(col.ID).Type, desiredTypes[i]
665-
if !colTyp.Identical(expectedTyp) {
646+
for i, col := range stmtScope.cols {
647+
if !col.typ.Identical(desiredTypes[i]) {
666648
needCast = true
667649
break
668650
}
669651
}
670652
if !needCast {
671-
return expr, physProps
672-
}
673-
stmtScope := bodyScope.push()
674-
for i, col := range cols {
675-
colTyp, expectedTyp := md.ColumnMeta(col.ID).Type, desiredTypes[i]
676-
scalar := b.factory.ConstructVariable(cols[i].ID)
677-
if !colTyp.Identical(expectedTyp) {
678-
if !cast.ValidCast(colTyp, expectedTyp, cast.ContextAssignment) {
653+
return stmtScope
654+
}
655+
outScope := stmtScope.push()
656+
for i, col := range stmtScope.cols {
657+
scalar := b.factory.ConstructVariable(col.id)
658+
if !col.typ.Identical(desiredTypes[i]) {
659+
if !cast.ValidCast(col.typ, desiredTypes[i], cast.ContextAssignment) {
679660
panic(errors.AssertionFailedf(
680661
"invalid cast from %s to %s should have been caught earlier",
681-
colTyp.SQLStringForError(), expectedTyp.SQLStringForError(),
662+
col.typ.SQLStringForError(), desiredTypes[i].SQLStringForError(),
682663
))
683664
}
684-
scalar = b.factory.ConstructAssignmentCast(scalar, expectedTyp)
665+
scalar = b.factory.ConstructAssignmentCast(scalar, desiredTypes[i])
685666
}
686-
b.synthesizeColumn(stmtScope, scopeColName(""), expectedTyp, nil /* expr */, scalar)
667+
b.synthesizeColumn(outScope, scopeColName(""), desiredTypes[i], nil /* expr */, scalar)
687668
}
688-
return b.constructProject(expr, stmtScope.cols), stmtScope.makePhysicalProps()
669+
b.constructProjectForScope(stmtScope, outScope)
670+
return outScope
689671
}
690672

691673
// addDefaultArgs adds DEFAULT arguments to the list of user-supplied arguments
@@ -826,7 +808,7 @@ func (b *Builder) buildDo(do *tree.DoBlock, inScope *scope) *scope {
826808
if !ok {
827809
panic(errors.AssertionFailedf("expected a plpgsql block"))
828810
}
829-
body, bodyProps := b.buildPLpgSQLDoBody(doBlockImpl)
811+
bodyScope := b.buildPLpgSQLDoBody(doBlockImpl)
830812

831813
// Build a CALL expression that invokes the routine.
832814
outScope := inScope.push()
@@ -839,8 +821,8 @@ func (b *Builder) buildDo(do *tree.DoBlock, inScope *scope) *scope {
839821
Volatility: volatility.Volatile,
840822
RoutineType: tree.ProcedureRoutine,
841823
RoutineLang: tree.RoutineLangPLpgSQL,
842-
Body: []memo.RelExpr{body},
843-
BodyProps: []*physical.Required{bodyProps},
824+
Body: []memo.RelExpr{bodyScope.expr},
825+
BodyProps: []*physical.Required{bodyScope.makePhysicalProps()},
844826
BodyStmts: bodyStmts,
845827
},
846828
},
@@ -853,9 +835,7 @@ func (b *Builder) buildDo(do *tree.DoBlock, inScope *scope) *scope {
853835
}
854836

855837
// buildDoBody builds the body of the anonymous routine for a DO statement.
856-
func (b *Builder) buildPLpgSQLDoBody(
857-
do *plpgsqltree.DoBlock,
858-
) (body memo.RelExpr, bodyProps *physical.Required) {
838+
func (b *Builder) buildPLpgSQLDoBody(do *plpgsqltree.DoBlock) *scope {
859839
// Build an expression for each statement in the function body.
860840
options := basePLOptions().WithIsProcedure().WithIsDoBlock()
861841
plBuilder := newPLpgSQLBuilder(
@@ -866,7 +846,7 @@ func (b *Builder) buildPLpgSQLDoBody(
866846
// variables or columns from the calling context.
867847
bodyScope := b.allocScope()
868848
stmtScope := plBuilder.buildRootBlock(do.Block, bodyScope, nil /* routineParams */)
869-
return b.finishBuildLastStmt(
870-
stmtScope, bodyScope, false /* isSetReturning */, false /* insideDataSource */, types.Void,
849+
return b.finishRoutineReturnStmt(
850+
stmtScope, false /* isSetReturning */, false /* insideDataSource */, types.Void,
871851
)
872852
}

0 commit comments

Comments
 (0)