Skip to content

Commit d3085c5

Browse files
committed
pgwire: fix JDBC driver compatibility for procedures with COMMIT statements
Previously, when using the extended protocol (Parse/Bind/Describe/Execute/Sync) to call a PL/pgSQL procedure containing COMMIT statements, CockroachDB would send extra RowDescription messages after the COMMIT, causing the JDBC driver to throw NoSuchElementException due to unexpected message sequences. This change fixes the message flow to match JDBC driver expectations when procedures execute COMMIT statements internally. The fix ensures that the proper sequence of PostgreSQL wire protocol messages is sent, preventing the driver from entering an inconsistent state. Added comprehensive pgtest coverage for both simple and extended protocol procedure calls with various COMMIT patterns to prevent regressions. Fixes #158771 Release Notes (Bug Fix): Fixed compatibility issue with JDBC driver when calling PL/pgSQL procedures containing COMMIT statements via prepared statements. The driver would previously throw NoSuchElementException due to unexpected PostgreSQL wire protocol message sequences.
1 parent e6ed2ba commit d3085c5

File tree

15 files changed

+224
-26
lines changed

15 files changed

+224
-26
lines changed

pkg/sql/conn_executor.go

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1712,6 +1712,8 @@ type connExecutor struct {
17121712
// will be used to synthesize an ExecStmt command to avoid attempting to
17131713
// execute the portal twice.
17141714
resumeStmt statements.Statement[tree.Statement]
1715+
1716+
callRowDescSent bool
17151717
}
17161718

17171719
// shouldExecuteOnTxnRestart indicates that ex.onTxnRestart will be
@@ -2801,6 +2803,7 @@ func (ex *connExecutor) execCmd() (retErr error) {
28012803
ex.extraTxnState.storedProcTxnState.resumeProc = nil
28022804
ex.extraTxnState.storedProcTxnState.resumeStmt = statements.Statement[tree.Statement]{}
28032805
ex.extraTxnState.storedProcTxnState.txnModes = nil
2806+
ex.extraTxnState.storedProcTxnState.callRowDescSent = false
28042807
}
28052808

28062809
if err := ex.updateTxnRewindPosMaybe(ctx, cmd, pos, advInfo); err != nil {
@@ -3502,6 +3505,11 @@ func stmtHasNoData(stmt tree.Statement, resultColumns colinfo.ResultColumns) boo
35023505
return true
35033506
}
35043507
if stmt.StatementReturnType() == tree.Rows {
3508+
// If the procedure doesn't contain output parameters, write a NoData
3509+
// message.
3510+
if stmt.StatementTag() == tree.CallStmtTag {
3511+
return len(resultColumns) == 0
3512+
}
35053513
return false
35063514
}
35073515
// The statement may not always return rows (e.g. EXECUTE), but if it does,
@@ -4384,8 +4392,21 @@ func (ex *connExecutor) initStatementResult(
43844392
// ANALYZE), then the columns will be set later.
43854393
if ex.planner.instrumentation.outputMode == unmodifiedOutput &&
43864394
ast.StatementReturnType() == tree.Rows {
4395+
_, isCallStmt := ast.(*tree.Call)
4396+
// Only write RowDescription message if the procedure has output parameters.
4397+
skipWriteRowDesc := isCallStmt && len(cols) == 0
4398+
4399+
// For CALL statements, check if we already sent RowDescription to prevent
4400+
// duplicate messages when procedures contain COMMIT/ROLLBACK statements.
4401+
if isCallStmt && !skipWriteRowDesc {
4402+
if ex.extraTxnState.storedProcTxnState.callRowDescSent {
4403+
skipWriteRowDesc = true
4404+
} else {
4405+
ex.extraTxnState.storedProcTxnState.callRowDescSent = true
4406+
}
4407+
}
43874408
// Note that this call is necessary even if cols is nil.
4388-
res.SetColumns(ctx, cols)
4409+
res.SetColumns(ctx, cols, skipWriteRowDesc)
43894410
}
43904411
return nil
43914412
}

pkg/sql/conn_executor_exec.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3781,7 +3781,7 @@ func (ex *connExecutor) runObserverStatement(
37813781
func (ex *connExecutor) runShowSyntax(
37823782
ctx context.Context, stmt string, res RestrictedCommandResult,
37833783
) error {
3784-
res.SetColumns(ctx, colinfo.ShowSyntaxColumns)
3784+
res.SetColumns(ctx, colinfo.ShowSyntaxColumns, false /* skipRowDescription */)
37853785
var commErr error
37863786
parser.RunShowSyntax(ctx, stmt,
37873787
func(ctx context.Context, field, msg string) {
@@ -3800,7 +3800,7 @@ func (ex *connExecutor) runShowSyntax(
38003800
func (ex *connExecutor) runShowTransactionState(
38013801
ctx context.Context, res RestrictedCommandResult,
38023802
) error {
3803-
res.SetColumns(ctx, colinfo.ResultColumns{{Name: "TRANSACTION STATUS", Typ: types.String}})
3803+
res.SetColumns(ctx, colinfo.ResultColumns{{Name: "TRANSACTION STATUS", Typ: types.String}}, false)
38043804

38053805
state := fmt.Sprintf("%s", ex.machine.CurState())
38063806
return res.AddRow(ctx, tree.Datums{tree.NewDString(state)})
@@ -3863,7 +3863,7 @@ func (ex *connExecutor) runShowTransferState(
38633863
for i := 0; i < len(colNames); i++ {
38643864
cols[i] = colinfo.ResultColumn{Name: colNames[i], Typ: types.String}
38653865
}
3866-
res.SetColumns(ctx, cols)
3866+
res.SetColumns(ctx, cols, false /* skipRowDescription */)
38673867

38683868
var sessionState, sessionRevivalToken tree.Datum
38693869
var row tree.Datums
@@ -3897,7 +3897,7 @@ func (ex *connExecutor) runShowTransferState(
38973897
func (ex *connExecutor) runShowCompletions(
38983898
ctx context.Context, n *tree.ShowCompletions, res RestrictedCommandResult,
38993899
) error {
3900-
res.SetColumns(ctx, colinfo.ShowCompletionsColumns)
3900+
res.SetColumns(ctx, colinfo.ShowCompletionsColumns, false)
39013901
log.Dev.Warningf(ctx, "COMPLETION GENERATOR FOR: %+v", *n)
39023902
sd := ex.planner.SessionData()
39033903
override := sessiondata.InternalExecutorOverride{
@@ -3964,7 +3964,7 @@ func (ex *connExecutor) runShowLastQueryStatistics(
39643964
for i, n := range stmt.Columns {
39653965
resColumns[i] = colinfo.ResultColumn{Name: string(n), Typ: types.String}
39663966
}
3967-
res.SetColumns(ctx, resColumns)
3967+
res.SetColumns(ctx, resColumns, false /* skipRowDescription */)
39683968

39693969
phaseTimes := ex.statsCollector.PreviousPhaseTimes()
39703970

pkg/sql/conn_executor_prepare.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -665,7 +665,7 @@ func (ex *connExecutor) execDescribe(
665665
}
666666
// Sending a nil formatCodes is equivalent to sending all text format
667667
// codes.
668-
res.SetPortalOutput(ctx, cursor.Rows.Types(), nil /* formatCodes */)
668+
res.SetPortalOutput(ctx, cursor.Rows.Types(), nil /* formatCodes */, false /* skipRowDescription */)
669669
return nil, nil
670670
}
671671

@@ -676,7 +676,16 @@ func (ex *connExecutor) execDescribe(
676676
if stmtHasNoData(ast, portal.Stmt.Columns) {
677677
res.SetNoDataRowDescription()
678678
} else {
679-
res.SetPortalOutput(ctx, portal.Stmt.Columns, portal.OutFormats)
679+
var isCallStmt bool
680+
var skipRowDescription bool
681+
if ast != nil {
682+
_, isCallStmt = ast.(*tree.Call)
683+
skipRowDescription = isCallStmt && ex.extraTxnState.storedProcTxnState.callRowDescSent
684+
}
685+
res.SetPortalOutput(ctx, portal.Stmt.Columns, portal.OutFormats, skipRowDescription)
686+
if isCallStmt && !skipRowDescription {
687+
ex.extraTxnState.storedProcTxnState.callRowDescSent = true
688+
}
680689
}
681690
default:
682691
return retErr(pgerror.Newf(

pkg/sql/conn_executor_savepoints.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ func (ex *connExecutor) runShowSavepointState(
410410
res.SetColumns(ctx, colinfo.ResultColumns{
411411
{Name: "savepoint_name", Typ: types.String},
412412
{Name: "is_initial_savepoint", Typ: types.Bool},
413-
})
413+
}, false /* skipRowDescription */)
414414

415415
for _, entry := range ex.extraTxnState.savepoints {
416416
if err := res.AddRow(ctx, tree.Datums{

pkg/sql/conn_executor_show_commit_timestamp.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,6 @@ func (ex *connExecutor) execShowCommitTimestampInNoTxnState(
141141
func writeShowCommitTimestampRow(
142142
ctx context.Context, res RestrictedCommandResult, ts hlc.Timestamp,
143143
) error {
144-
res.SetColumns(ctx, colinfo.ShowCommitTimestampColumns)
144+
res.SetColumns(ctx, colinfo.ShowCommitTimestampColumns, false /* skipRowDescription */)
145145
return res.AddRow(ctx, tree.Datums{eval.TimestampToDecimalDatum(ts)})
146146
}

pkg/sql/conn_io.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -867,7 +867,7 @@ type RestrictedCommandResult interface {
867867
// can be nil.
868868
//
869869
// This needs to be called (once) before AddRow.
870-
SetColumns(context.Context, colinfo.ResultColumns)
870+
SetColumns(ctx context.Context, cols colinfo.ResultColumns, skipWriteRowDesc bool)
871871

872872
// ResetStmtType allows a client to change the statement type of the current
873873
// result, from the original one set when the result was created trough
@@ -948,7 +948,7 @@ type DescribeResult interface {
948948
SetPrepStmtOutput(context.Context, colinfo.ResultColumns)
949949
// SetPortalOutput tells the client about the results schema and formatting of
950950
// a portal.
951-
SetPortalOutput(context.Context, colinfo.ResultColumns, []pgwirebase.FormatCode)
951+
SetPortalOutput(ctx context.Context, cols colinfo.ResultColumns, fmtCode []pgwirebase.FormatCode, skipRowDescription bool)
952952
}
953953

954954
// ParseResult represents the result of a Parse command.
@@ -1114,7 +1114,9 @@ func (r *streamingCommandResult) RevokePortalPausability() error {
11141114
}
11151115

11161116
// SetColumns is part of the RestrictedCommandResult interface.
1117-
func (r *streamingCommandResult) SetColumns(ctx context.Context, cols colinfo.ResultColumns) {
1117+
func (r *streamingCommandResult) SetColumns(
1118+
ctx context.Context, cols colinfo.ResultColumns, skipWriteRowDesc bool,
1119+
) {
11181120
// The interface allows for cols to be nil, yet the iterator result expects
11191121
// non-nil value to indicate that it was the column metadata.
11201122
if cols == nil {
@@ -1259,7 +1261,7 @@ func (r *streamingCommandResult) SetPrepStmtOutput(context.Context, colinfo.Resu
12591261

12601262
// SetPortalOutput is part of the DescribeResult interface.
12611263
func (r *streamingCommandResult) SetPortalOutput(
1262-
context.Context, colinfo.ResultColumns, []pgwirebase.FormatCode,
1264+
context.Context, colinfo.ResultColumns, []pgwirebase.FormatCode, bool,
12631265
) {
12641266
}
12651267

pkg/sql/explain_bundle.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ func setExplainBundleResult(
6060
warnings []string,
6161
) error {
6262
res.ResetStmtType(&tree.ExplainAnalyze{})
63-
res.SetColumns(ctx, colinfo.ExplainPlanColumns)
63+
res.SetColumns(ctx, colinfo.ExplainPlanColumns, false /* skipRowDescription */)
6464

6565
var text []string
6666
if bundle.collectionErr != nil {

pkg/sql/instrumentation.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -976,7 +976,7 @@ func (ih *instrumentationHelper) setExplainAnalyzeResult(
976976
trace tracingpb.Recording,
977977
) (commErr error) {
978978
res.ResetStmtType(&tree.ExplainAnalyze{})
979-
res.SetColumns(ctx, colinfo.ExplainPlanColumns)
979+
res.SetColumns(ctx, colinfo.ExplainPlanColumns, false /* skipRowDescription */)
980980

981981
if res.Err() != nil {
982982
// Can't add rows if there was an error.

pkg/sql/isession/command_result.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,9 @@ func (i *internalCommandResult) SendNotice(
7373
return nil
7474
}
7575

76-
func (i *internalCommandResult) SetColumns(ctx context.Context, cols colinfo.ResultColumns) {
76+
func (i *internalCommandResult) SetColumns(
77+
ctx context.Context, cols colinfo.ResultColumns, skipWriteRowDesc bool,
78+
) {
7779
// We don't need this because the datums include type information.
7880
}
7981

@@ -166,7 +168,10 @@ func (i *internalCommandResult) SendCopyOut(
166168
}
167169

168170
func (i *internalCommandResult) SetPortalOutput(
169-
ctx context.Context, cols colinfo.ResultColumns, formatCodes []pgwirebase.FormatCode,
171+
ctx context.Context,
172+
cols colinfo.ResultColumns,
173+
fmtCode []pgwirebase.FormatCode,
174+
skipRowDescription bool,
170175
) {
171176
i.SetError(errors.AssertionFailedf("SetPortalOutput is not supported by the internal session"))
172177
}

pkg/sql/opt/optbuilder/plpgsql.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1272,7 +1272,7 @@ func (b *plpgsqlBuilder) buildPLpgSQLStatements(stmts []ast.Statement, s *scope)
12721272
))
12731273
}
12741274
}
1275-
b.checkDuplicateTargets(target, "CALL")
1275+
b.checkDuplicateTargets(target, tree.CallStmtTag)
12761276
if len(target) == 0 {
12771277
// When there is no INTO target, build the nested procedure call into a
12781278
// body statement that is only executed for its side effects.

0 commit comments

Comments
 (0)