Skip to content

Commit db4b0d4

Browse files
authored
Merge pull request #21 from PostHog/fix/transaction-status-tracking
Track transaction status in PostgreSQL wire protocol
2 parents 9266156 + d09ca12 commit db4b0d4

File tree

2 files changed

+130
-22
lines changed

2 files changed

+130
-22
lines changed

server/conn.go

Lines changed: 70 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,13 @@ type portal struct {
2929
described bool // true if Describe was called on this portal
3030
}
3131

32+
// Transaction status constants for PostgreSQL wire protocol
33+
const (
34+
txStatusIdle = 'I' // Not in a transaction
35+
txStatusTransaction = 'T' // In a transaction
36+
txStatusError = 'E' // In a failed transaction
37+
)
38+
3239
type clientConn struct {
3340
server *Server
3441
conn net.Conn
@@ -40,6 +47,7 @@ type clientConn struct {
4047
pid int32
4148
stmts map[string]*preparedStmt // prepared statements by name
4249
portals map[string]*portal // portals by name
50+
txStatus byte // current transaction status ('I', 'T', or 'E')
4351
}
4452

4553
func (c *clientConn) serve() error {
@@ -48,6 +56,7 @@ func (c *clientConn) serve() error {
4856
c.pid = int32(os.Getpid())
4957
c.stmts = make(map[string]*preparedStmt)
5058
c.portals = make(map[string]*portal)
59+
c.txStatus = txStatusIdle
5160

5261
// Handle startup
5362
if err := c.handleStartup(); err != nil {
@@ -75,7 +84,7 @@ func (c *clientConn) serve() error {
7584
c.sendInitialParams()
7685

7786
// Send ready for query
78-
if err := writeReadyForQuery(c.writer, 'I'); err != nil {
87+
if err := writeReadyForQuery(c.writer, c.txStatus); err != nil {
7988
return err
8089
}
8190
c.writer.Flush()
@@ -235,7 +244,7 @@ func (c *clientConn) messageLoop() error {
235244

236245
case msgSync:
237246
// Extended query protocol - Sync
238-
if err := writeReadyForQuery(c.writer, 'I'); err != nil {
247+
if err := writeReadyForQuery(c.writer, c.txStatus); err != nil {
239248
return err
240249
}
241250
c.writer.Flush()
@@ -262,7 +271,7 @@ func (c *clientConn) handleQuery(body []byte) error {
262271

263272
if query == "" {
264273
writeEmptyQueryResponse(c.writer)
265-
writeReadyForQuery(c.writer, 'I')
274+
writeReadyForQuery(c.writer, c.txStatus)
266275
c.writer.Flush()
267276
return nil
268277
}
@@ -286,14 +295,16 @@ func (c *clientConn) handleQuery(body []byte) error {
286295
result, err := c.db.Exec(query)
287296
if err != nil {
288297
c.sendError("ERROR", "42000", err.Error())
289-
writeReadyForQuery(c.writer, 'I')
298+
c.setTxError()
299+
writeReadyForQuery(c.writer, c.txStatus)
290300
c.writer.Flush()
291301
return nil
292302
}
293303

304+
c.updateTxStatus(cmdType)
294305
tag := c.buildCommandTag(cmdType, result)
295306
writeCommandComplete(c.writer, tag)
296-
writeReadyForQuery(c.writer, 'I')
307+
writeReadyForQuery(c.writer, c.txStatus)
297308
c.writer.Flush()
298309
return nil
299310
}
@@ -302,7 +313,8 @@ func (c *clientConn) handleQuery(body []byte) error {
302313
rows, err := c.db.Query(query)
303314
if err != nil {
304315
c.sendError("ERROR", "42000", err.Error())
305-
writeReadyForQuery(c.writer, 'I')
316+
c.setTxError()
317+
writeReadyForQuery(c.writer, c.txStatus)
306318
c.writer.Flush()
307319
return nil
308320
}
@@ -312,15 +324,17 @@ func (c *clientConn) handleQuery(body []byte) error {
312324
cols, err := rows.Columns()
313325
if err != nil {
314326
c.sendError("ERROR", "42000", err.Error())
315-
writeReadyForQuery(c.writer, 'I')
327+
c.setTxError()
328+
writeReadyForQuery(c.writer, c.txStatus)
316329
c.writer.Flush()
317330
return nil
318331
}
319332

320333
colTypes, err := rows.ColumnTypes()
321334
if err != nil {
322335
c.sendError("ERROR", "42000", err.Error())
323-
writeReadyForQuery(c.writer, 'I')
336+
c.setTxError()
337+
writeReadyForQuery(c.writer, c.txStatus)
324338
c.writer.Flush()
325339
return nil
326340
}
@@ -350,10 +364,10 @@ func (c *clientConn) handleQuery(body []byte) error {
350364
rowCount++
351365
}
352366

353-
// Send command complete
367+
// Send command complete (SELECT doesn't change transaction status)
354368
tag := fmt.Sprintf("SELECT %d", rowCount)
355369
writeCommandComplete(c.writer, tag)
356-
writeReadyForQuery(c.writer, 'I')
370+
writeReadyForQuery(c.writer, c.txStatus)
357371
c.writer.Flush()
358372

359373
return nil
@@ -434,6 +448,26 @@ func (c *clientConn) getCommandType(upperQuery string) string {
434448
}
435449
}
436450

451+
// updateTxStatus updates the transaction status based on the executed command.
452+
// This is called after a successful command execution.
453+
func (c *clientConn) updateTxStatus(cmdType string) {
454+
switch cmdType {
455+
case "BEGIN":
456+
c.txStatus = txStatusTransaction
457+
case "COMMIT", "ROLLBACK":
458+
c.txStatus = txStatusIdle
459+
}
460+
// For other commands, keep the current status
461+
}
462+
463+
// setTxError marks the transaction as failed if we're in a transaction.
464+
// This should be called when a query fails within a transaction.
465+
func (c *clientConn) setTxError() {
466+
if c.txStatus == txStatusTransaction {
467+
c.txStatus = txStatusError
468+
}
469+
}
470+
437471
func (c *clientConn) buildCommandTag(cmdType string, result sql.Result) string {
438472
switch cmdType {
439473
case "INSERT":
@@ -475,14 +509,15 @@ func (c *clientConn) handleCopy(query, upperQuery string) error {
475509
result, err := c.db.Exec(query)
476510
if err != nil {
477511
c.sendError("ERROR", "42000", err.Error())
478-
writeReadyForQuery(c.writer, 'I')
512+
c.setTxError()
513+
writeReadyForQuery(c.writer, c.txStatus)
479514
c.writer.Flush()
480515
return nil
481516
}
482517

483518
rowsAffected, _ := result.RowsAffected()
484519
writeCommandComplete(c.writer, fmt.Sprintf("COPY %d", rowsAffected))
485-
writeReadyForQuery(c.writer, 'I')
520+
writeReadyForQuery(c.writer, c.txStatus)
486521
c.writer.Flush()
487522
return nil
488523
}
@@ -492,7 +527,8 @@ func (c *clientConn) handleCopyOut(query, upperQuery string) error {
492527
matches := copyToStdoutRegex.FindStringSubmatch(query)
493528
if len(matches) < 2 {
494529
c.sendError("ERROR", "42601", "Invalid COPY TO STDOUT syntax")
495-
writeReadyForQuery(c.writer, 'I')
530+
c.setTxError()
531+
writeReadyForQuery(c.writer, c.txStatus)
496532
c.writer.Flush()
497533
return nil
498534
}
@@ -518,7 +554,8 @@ func (c *clientConn) handleCopyOut(query, upperQuery string) error {
518554
rows, err := c.db.Query(selectQuery)
519555
if err != nil {
520556
c.sendError("ERROR", "42000", err.Error())
521-
writeReadyForQuery(c.writer, 'I')
557+
c.setTxError()
558+
writeReadyForQuery(c.writer, c.txStatus)
522559
c.writer.Flush()
523560
return nil
524561
}
@@ -527,7 +564,8 @@ func (c *clientConn) handleCopyOut(query, upperQuery string) error {
527564
cols, err := rows.Columns()
528565
if err != nil {
529566
c.sendError("ERROR", "42000", err.Error())
530-
writeReadyForQuery(c.writer, 'I')
567+
c.setTxError()
568+
writeReadyForQuery(c.writer, c.txStatus)
531569
c.writer.Flush()
532570
return nil
533571
}
@@ -578,7 +616,7 @@ func (c *clientConn) handleCopyOut(query, upperQuery string) error {
578616
}
579617

580618
writeCommandComplete(c.writer, fmt.Sprintf("COPY %d", rowCount))
581-
writeReadyForQuery(c.writer, 'I')
619+
writeReadyForQuery(c.writer, c.txStatus)
582620
c.writer.Flush()
583621
return nil
584622
}
@@ -588,7 +626,8 @@ func (c *clientConn) handleCopyIn(query, upperQuery string) error {
588626
matches := copyFromStdinRegex.FindStringSubmatch(query)
589627
if len(matches) < 2 {
590628
c.sendError("ERROR", "42601", "Invalid COPY FROM STDIN syntax")
591-
writeReadyForQuery(c.writer, 'I')
629+
c.setTxError()
630+
writeReadyForQuery(c.writer, c.txStatus)
592631
c.writer.Flush()
593632
return nil
594633
}
@@ -613,7 +652,8 @@ func (c *clientConn) handleCopyIn(query, upperQuery string) error {
613652
testRows, err := c.db.Query(colQuery)
614653
if err != nil {
615654
c.sendError("ERROR", "42P01", fmt.Sprintf("relation \"%s\" does not exist", tableName))
616-
writeReadyForQuery(c.writer, 'I')
655+
c.setTxError()
656+
writeReadyForQuery(c.writer, c.txStatus)
617657
c.writer.Flush()
618658
return nil
619659
}
@@ -679,29 +719,32 @@ func (c *clientConn) handleCopyIn(query, upperQuery string) error {
679719

680720
if _, err := c.db.Exec(insertSQL, args...); err != nil {
681721
c.sendError("ERROR", "22P02", fmt.Sprintf("invalid input: %v", err))
682-
writeReadyForQuery(c.writer, 'I')
722+
c.setTxError()
723+
writeReadyForQuery(c.writer, c.txStatus)
683724
c.writer.Flush()
684725
return nil
685726
}
686727
rowCount++
687728
}
688729

689730
writeCommandComplete(c.writer, fmt.Sprintf("COPY %d", rowCount))
690-
writeReadyForQuery(c.writer, 'I')
731+
writeReadyForQuery(c.writer, c.txStatus)
691732
c.writer.Flush()
692733
return nil
693734

694735
case msgCopyFail:
695736
// Client cancelled COPY
696737
errMsg := string(bytes.TrimRight(body, "\x00"))
697738
c.sendError("ERROR", "57014", fmt.Sprintf("COPY failed: %s", errMsg))
698-
writeReadyForQuery(c.writer, 'I')
739+
c.setTxError()
740+
writeReadyForQuery(c.writer, c.txStatus)
699741
c.writer.Flush()
700742
return nil
701743

702744
default:
703745
c.sendError("ERROR", "08P01", fmt.Sprintf("unexpected message type during COPY: %c", msgType))
704-
writeReadyForQuery(c.writer, 'I')
746+
c.setTxError()
747+
writeReadyForQuery(c.writer, c.txStatus)
705748
c.writer.Flush()
706749
return nil
707750
}
@@ -1189,8 +1232,10 @@ func (c *clientConn) handleExecute(body []byte) {
11891232
result, err := c.db.Exec(p.stmt.convertedQuery, args...)
11901233
if err != nil {
11911234
c.sendError("ERROR", "42000", err.Error())
1235+
c.setTxError()
11921236
return
11931237
}
1238+
c.updateTxStatus(cmdType)
11941239
tag := c.buildCommandTag(cmdType, result)
11951240
writeCommandComplete(c.writer, tag)
11961241
return
@@ -1200,13 +1245,15 @@ func (c *clientConn) handleExecute(body []byte) {
12001245
rows, err := c.db.Query(p.stmt.convertedQuery, args...)
12011246
if err != nil {
12021247
c.sendError("ERROR", "42000", err.Error())
1248+
c.setTxError()
12031249
return
12041250
}
12051251
defer rows.Close()
12061252

12071253
cols, err := rows.Columns()
12081254
if err != nil {
12091255
c.sendError("ERROR", "42000", err.Error())
1256+
c.setTxError()
12101257
return
12111258
}
12121259

@@ -1244,6 +1291,7 @@ func (c *clientConn) handleExecute(body []byte) {
12441291

12451292
if err := rows.Scan(valuePtrs...); err != nil {
12461293
c.sendError("ERROR", "42000", err.Error())
1294+
c.setTxError()
12471295
return
12481296
}
12491297

server/conn_test.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,3 +272,63 @@ func TestRedactConnectionString(t *testing.T) {
272272
})
273273
}
274274
}
275+
276+
func TestTransactionStatusTracking(t *testing.T) {
277+
c := &clientConn{txStatus: txStatusIdle}
278+
279+
// Initially should be idle
280+
if c.txStatus != txStatusIdle {
281+
t.Errorf("initial txStatus = %c, want %c", c.txStatus, txStatusIdle)
282+
}
283+
284+
// BEGIN should set to transaction
285+
c.updateTxStatus("BEGIN")
286+
if c.txStatus != txStatusTransaction {
287+
t.Errorf("after BEGIN txStatus = %c, want %c", c.txStatus, txStatusTransaction)
288+
}
289+
290+
// SELECT should not change status
291+
c.updateTxStatus("SELECT")
292+
if c.txStatus != txStatusTransaction {
293+
t.Errorf("after SELECT txStatus = %c, want %c", c.txStatus, txStatusTransaction)
294+
}
295+
296+
// COMMIT should set back to idle
297+
c.updateTxStatus("COMMIT")
298+
if c.txStatus != txStatusIdle {
299+
t.Errorf("after COMMIT txStatus = %c, want %c", c.txStatus, txStatusIdle)
300+
}
301+
302+
// Test ROLLBACK path
303+
c.updateTxStatus("BEGIN")
304+
if c.txStatus != txStatusTransaction {
305+
t.Errorf("after second BEGIN txStatus = %c, want %c", c.txStatus, txStatusTransaction)
306+
}
307+
c.updateTxStatus("ROLLBACK")
308+
if c.txStatus != txStatusIdle {
309+
t.Errorf("after ROLLBACK txStatus = %c, want %c", c.txStatus, txStatusIdle)
310+
}
311+
}
312+
313+
func TestTransactionErrorStatus(t *testing.T) {
314+
c := &clientConn{txStatus: txStatusIdle}
315+
316+
// Error outside transaction should not change status
317+
c.setTxError()
318+
if c.txStatus != txStatusIdle {
319+
t.Errorf("error outside transaction txStatus = %c, want %c", c.txStatus, txStatusIdle)
320+
}
321+
322+
// Error inside transaction should set to error
323+
c.updateTxStatus("BEGIN")
324+
c.setTxError()
325+
if c.txStatus != txStatusError {
326+
t.Errorf("error inside transaction txStatus = %c, want %c", c.txStatus, txStatusError)
327+
}
328+
329+
// ROLLBACK should recover from error state
330+
c.updateTxStatus("ROLLBACK")
331+
if c.txStatus != txStatusIdle {
332+
t.Errorf("after ROLLBACK from error txStatus = %c, want %c", c.txStatus, txStatusIdle)
333+
}
334+
}

0 commit comments

Comments
 (0)