@@ -29,6 +29,13 @@ type portal struct {
2929 described bool // true if Describe was called on this portal
3030}
3131
32+ // Transaction status constants for PostgreSQL wire protocol
33+ const (
34+ txStatusIdle = 'I' // Not in a transaction
35+ txStatusTransaction = 'T' // In a transaction
36+ txStatusError = 'E' // In a failed transaction
37+ )
38+
3239type clientConn struct {
3340 server * Server
3441 conn net.Conn
@@ -40,6 +47,7 @@ type clientConn struct {
4047 pid int32
4148 stmts map [string ]* preparedStmt // prepared statements by name
4249 portals map [string ]* portal // portals by name
50+ txStatus byte // current transaction status ('I', 'T', or 'E')
4351}
4452
4553func (c * clientConn ) serve () error {
@@ -48,6 +56,7 @@ func (c *clientConn) serve() error {
4856 c .pid = int32 (os .Getpid ())
4957 c .stmts = make (map [string ]* preparedStmt )
5058 c .portals = make (map [string ]* portal )
59+ c .txStatus = txStatusIdle
5160
5261 // Handle startup
5362 if err := c .handleStartup (); err != nil {
@@ -75,7 +84,7 @@ func (c *clientConn) serve() error {
7584 c .sendInitialParams ()
7685
7786 // Send ready for query
78- if err := writeReadyForQuery (c .writer , 'I' ); err != nil {
87+ if err := writeReadyForQuery (c .writer , c . txStatus ); err != nil {
7988 return err
8089 }
8190 c .writer .Flush ()
@@ -235,7 +244,7 @@ func (c *clientConn) messageLoop() error {
235244
236245 case msgSync :
237246 // Extended query protocol - Sync
238- if err := writeReadyForQuery (c .writer , 'I' ); err != nil {
247+ if err := writeReadyForQuery (c .writer , c . txStatus ); err != nil {
239248 return err
240249 }
241250 c .writer .Flush ()
@@ -262,7 +271,7 @@ func (c *clientConn) handleQuery(body []byte) error {
262271
263272 if query == "" {
264273 writeEmptyQueryResponse (c .writer )
265- writeReadyForQuery (c .writer , 'I' )
274+ writeReadyForQuery (c .writer , c . txStatus )
266275 c .writer .Flush ()
267276 return nil
268277 }
@@ -286,14 +295,16 @@ func (c *clientConn) handleQuery(body []byte) error {
286295 result , err := c .db .Exec (query )
287296 if err != nil {
288297 c .sendError ("ERROR" , "42000" , err .Error ())
289- writeReadyForQuery (c .writer , 'I' )
298+ c .setTxError ()
299+ writeReadyForQuery (c .writer , c .txStatus )
290300 c .writer .Flush ()
291301 return nil
292302 }
293303
304+ c .updateTxStatus (cmdType )
294305 tag := c .buildCommandTag (cmdType , result )
295306 writeCommandComplete (c .writer , tag )
296- writeReadyForQuery (c .writer , 'I' )
307+ writeReadyForQuery (c .writer , c . txStatus )
297308 c .writer .Flush ()
298309 return nil
299310 }
@@ -302,7 +313,8 @@ func (c *clientConn) handleQuery(body []byte) error {
302313 rows , err := c .db .Query (query )
303314 if err != nil {
304315 c .sendError ("ERROR" , "42000" , err .Error ())
305- writeReadyForQuery (c .writer , 'I' )
316+ c .setTxError ()
317+ writeReadyForQuery (c .writer , c .txStatus )
306318 c .writer .Flush ()
307319 return nil
308320 }
@@ -312,15 +324,17 @@ func (c *clientConn) handleQuery(body []byte) error {
312324 cols , err := rows .Columns ()
313325 if err != nil {
314326 c .sendError ("ERROR" , "42000" , err .Error ())
315- writeReadyForQuery (c .writer , 'I' )
327+ c .setTxError ()
328+ writeReadyForQuery (c .writer , c .txStatus )
316329 c .writer .Flush ()
317330 return nil
318331 }
319332
320333 colTypes , err := rows .ColumnTypes ()
321334 if err != nil {
322335 c .sendError ("ERROR" , "42000" , err .Error ())
323- writeReadyForQuery (c .writer , 'I' )
336+ c .setTxError ()
337+ writeReadyForQuery (c .writer , c .txStatus )
324338 c .writer .Flush ()
325339 return nil
326340 }
@@ -350,10 +364,10 @@ func (c *clientConn) handleQuery(body []byte) error {
350364 rowCount ++
351365 }
352366
353- // Send command complete
367+ // Send command complete (SELECT doesn't change transaction status)
354368 tag := fmt .Sprintf ("SELECT %d" , rowCount )
355369 writeCommandComplete (c .writer , tag )
356- writeReadyForQuery (c .writer , 'I' )
370+ writeReadyForQuery (c .writer , c . txStatus )
357371 c .writer .Flush ()
358372
359373 return nil
@@ -434,6 +448,26 @@ func (c *clientConn) getCommandType(upperQuery string) string {
434448 }
435449}
436450
451+ // updateTxStatus updates the transaction status based on the executed command.
452+ // This is called after a successful command execution.
453+ func (c * clientConn ) updateTxStatus (cmdType string ) {
454+ switch cmdType {
455+ case "BEGIN" :
456+ c .txStatus = txStatusTransaction
457+ case "COMMIT" , "ROLLBACK" :
458+ c .txStatus = txStatusIdle
459+ }
460+ // For other commands, keep the current status
461+ }
462+
463+ // setTxError marks the transaction as failed if we're in a transaction.
464+ // This should be called when a query fails within a transaction.
465+ func (c * clientConn ) setTxError () {
466+ if c .txStatus == txStatusTransaction {
467+ c .txStatus = txStatusError
468+ }
469+ }
470+
437471func (c * clientConn ) buildCommandTag (cmdType string , result sql.Result ) string {
438472 switch cmdType {
439473 case "INSERT" :
@@ -475,14 +509,15 @@ func (c *clientConn) handleCopy(query, upperQuery string) error {
475509 result , err := c .db .Exec (query )
476510 if err != nil {
477511 c .sendError ("ERROR" , "42000" , err .Error ())
478- writeReadyForQuery (c .writer , 'I' )
512+ c .setTxError ()
513+ writeReadyForQuery (c .writer , c .txStatus )
479514 c .writer .Flush ()
480515 return nil
481516 }
482517
483518 rowsAffected , _ := result .RowsAffected ()
484519 writeCommandComplete (c .writer , fmt .Sprintf ("COPY %d" , rowsAffected ))
485- writeReadyForQuery (c .writer , 'I' )
520+ writeReadyForQuery (c .writer , c . txStatus )
486521 c .writer .Flush ()
487522 return nil
488523}
@@ -492,7 +527,8 @@ func (c *clientConn) handleCopyOut(query, upperQuery string) error {
492527 matches := copyToStdoutRegex .FindStringSubmatch (query )
493528 if len (matches ) < 2 {
494529 c .sendError ("ERROR" , "42601" , "Invalid COPY TO STDOUT syntax" )
495- writeReadyForQuery (c .writer , 'I' )
530+ c .setTxError ()
531+ writeReadyForQuery (c .writer , c .txStatus )
496532 c .writer .Flush ()
497533 return nil
498534 }
@@ -518,7 +554,8 @@ func (c *clientConn) handleCopyOut(query, upperQuery string) error {
518554 rows , err := c .db .Query (selectQuery )
519555 if err != nil {
520556 c .sendError ("ERROR" , "42000" , err .Error ())
521- writeReadyForQuery (c .writer , 'I' )
557+ c .setTxError ()
558+ writeReadyForQuery (c .writer , c .txStatus )
522559 c .writer .Flush ()
523560 return nil
524561 }
@@ -527,7 +564,8 @@ func (c *clientConn) handleCopyOut(query, upperQuery string) error {
527564 cols , err := rows .Columns ()
528565 if err != nil {
529566 c .sendError ("ERROR" , "42000" , err .Error ())
530- writeReadyForQuery (c .writer , 'I' )
567+ c .setTxError ()
568+ writeReadyForQuery (c .writer , c .txStatus )
531569 c .writer .Flush ()
532570 return nil
533571 }
@@ -578,7 +616,7 @@ func (c *clientConn) handleCopyOut(query, upperQuery string) error {
578616 }
579617
580618 writeCommandComplete (c .writer , fmt .Sprintf ("COPY %d" , rowCount ))
581- writeReadyForQuery (c .writer , 'I' )
619+ writeReadyForQuery (c .writer , c . txStatus )
582620 c .writer .Flush ()
583621 return nil
584622}
@@ -588,7 +626,8 @@ func (c *clientConn) handleCopyIn(query, upperQuery string) error {
588626 matches := copyFromStdinRegex .FindStringSubmatch (query )
589627 if len (matches ) < 2 {
590628 c .sendError ("ERROR" , "42601" , "Invalid COPY FROM STDIN syntax" )
591- writeReadyForQuery (c .writer , 'I' )
629+ c .setTxError ()
630+ writeReadyForQuery (c .writer , c .txStatus )
592631 c .writer .Flush ()
593632 return nil
594633 }
@@ -613,7 +652,8 @@ func (c *clientConn) handleCopyIn(query, upperQuery string) error {
613652 testRows , err := c .db .Query (colQuery )
614653 if err != nil {
615654 c .sendError ("ERROR" , "42P01" , fmt .Sprintf ("relation \" %s\" does not exist" , tableName ))
616- writeReadyForQuery (c .writer , 'I' )
655+ c .setTxError ()
656+ writeReadyForQuery (c .writer , c .txStatus )
617657 c .writer .Flush ()
618658 return nil
619659 }
@@ -679,29 +719,32 @@ func (c *clientConn) handleCopyIn(query, upperQuery string) error {
679719
680720 if _ , err := c .db .Exec (insertSQL , args ... ); err != nil {
681721 c .sendError ("ERROR" , "22P02" , fmt .Sprintf ("invalid input: %v" , err ))
682- writeReadyForQuery (c .writer , 'I' )
722+ c .setTxError ()
723+ writeReadyForQuery (c .writer , c .txStatus )
683724 c .writer .Flush ()
684725 return nil
685726 }
686727 rowCount ++
687728 }
688729
689730 writeCommandComplete (c .writer , fmt .Sprintf ("COPY %d" , rowCount ))
690- writeReadyForQuery (c .writer , 'I' )
731+ writeReadyForQuery (c .writer , c . txStatus )
691732 c .writer .Flush ()
692733 return nil
693734
694735 case msgCopyFail :
695736 // Client cancelled COPY
696737 errMsg := string (bytes .TrimRight (body , "\x00 " ))
697738 c .sendError ("ERROR" , "57014" , fmt .Sprintf ("COPY failed: %s" , errMsg ))
698- writeReadyForQuery (c .writer , 'I' )
739+ c .setTxError ()
740+ writeReadyForQuery (c .writer , c .txStatus )
699741 c .writer .Flush ()
700742 return nil
701743
702744 default :
703745 c .sendError ("ERROR" , "08P01" , fmt .Sprintf ("unexpected message type during COPY: %c" , msgType ))
704- writeReadyForQuery (c .writer , 'I' )
746+ c .setTxError ()
747+ writeReadyForQuery (c .writer , c .txStatus )
705748 c .writer .Flush ()
706749 return nil
707750 }
@@ -1189,8 +1232,10 @@ func (c *clientConn) handleExecute(body []byte) {
11891232 result , err := c .db .Exec (p .stmt .convertedQuery , args ... )
11901233 if err != nil {
11911234 c .sendError ("ERROR" , "42000" , err .Error ())
1235+ c .setTxError ()
11921236 return
11931237 }
1238+ c .updateTxStatus (cmdType )
11941239 tag := c .buildCommandTag (cmdType , result )
11951240 writeCommandComplete (c .writer , tag )
11961241 return
@@ -1200,13 +1245,15 @@ func (c *clientConn) handleExecute(body []byte) {
12001245 rows , err := c .db .Query (p .stmt .convertedQuery , args ... )
12011246 if err != nil {
12021247 c .sendError ("ERROR" , "42000" , err .Error ())
1248+ c .setTxError ()
12031249 return
12041250 }
12051251 defer rows .Close ()
12061252
12071253 cols , err := rows .Columns ()
12081254 if err != nil {
12091255 c .sendError ("ERROR" , "42000" , err .Error ())
1256+ c .setTxError ()
12101257 return
12111258 }
12121259
@@ -1244,6 +1291,7 @@ func (c *clientConn) handleExecute(body []byte) {
12441291
12451292 if err := rows .Scan (valuePtrs ... ); err != nil {
12461293 c .sendError ("ERROR" , "42000" , err .Error ())
1294+ c .setTxError ()
12471295 return
12481296 }
12491297
0 commit comments