@@ -73,6 +73,7 @@ func testTransactionPGX(t *testing.T) {
7373 t .Run ("RetrySuccess" , testTransactionPGXRetrySuccess )
7474 t .Run ("MultipleFunctions" , testTransactionPGXMultipleFunctions )
7575 t .Run ("RealDatabase" , testTransactionPGXRealDatabase )
76+ t .Run ("ContextCancelled" , testTransactionPGXContextCancelled )
7677}
7778
7879func testTransactionPGXNilDatabase (t * testing.T ) {
@@ -510,6 +511,43 @@ func testTransactionPGXRealDatabase(t *testing.T) {
510511 assert .ElementsMatch (t , names , gotNames )
511512}
512513
514+ func testTransactionPGXContextCancelled (t * testing.T ) {
515+ t .Parallel ()
516+ if testing .Short () {
517+ t .Skip ("slow test" )
518+ }
519+ ctx , cancel := context .WithCancel (context .Background ())
520+ defer cancel ()
521+ addr := getDB (t )
522+ config , err := pgxpool .ParseConfig (addr )
523+ require .NoError (t , err )
524+ db , err := pgxpool .ConnectConfig (ctx , config )
525+ require .NoError (t , err )
526+
527+ tr , err := dbtools .NewTransaction (db , dbtools .RetryCount (10 ))
528+ require .NoError (t , err )
529+
530+ calls := 0
531+ // we are not using the same context to make sure the query causes the
532+ // error.
533+ err = tr .PGX (ctx , func (tx pgx.Tx ) error {
534+ calls ++
535+ query := `CREATE TABLE dbtest (name VARCHAR(100))`
536+ _ , err := tx .Exec (ctx , query )
537+ return err
538+ }, func (tx pgx.Tx ) error {
539+ cancel ()
540+ query := `INSERT INTO dbtest (name) VALUES ('a')`
541+ _ , err := tx .Exec (ctx , query )
542+ return err
543+ }, func (pgx.Tx ) error {
544+ t .Error ("didn't expect to get this" )
545+ return nil
546+ })
547+ assertInError (t , err , context .Canceled )
548+ assert .Equal (t , 1 , calls )
549+ }
550+
513551func testTransactionDB (t * testing.T ) {
514552 t .Run ("NilDatabase" , testTransactionDBNilDatabase )
515553 t .Run ("BeginError" , testTransactionDBBeginError )
@@ -522,6 +560,7 @@ func testTransactionDB(t *testing.T) {
522560 t .Run ("RetrySuccess" , testTransactionDBRetrySuccess )
523561 t .Run ("MultipleFunctions" , testTransactionDBMultipleFunctions )
524562 t .Run ("RealDatabase" , testTransactionDBRealDatabase )
563+ t .Run ("ContextCancelled" , testTransactionDBContextCancelled )
525564}
526565
527566func testTransactionDBNilDatabase (t * testing.T ) {
@@ -936,3 +975,39 @@ func testTransactionDBRealDatabase(t *testing.T) {
936975 require .NoError (t , err )
937976 assert .ElementsMatch (t , names , gotNames )
938977}
978+
979+ func testTransactionDBContextCancelled (t * testing.T ) {
980+ t .Parallel ()
981+ if testing .Short () {
982+ t .Skip ("slow test" )
983+ }
984+ addr := getDB (t )
985+ config , err := pgx .ParseConfig (addr )
986+ require .NoError (t , err )
987+ db := stdlib .OpenDB (* config )
988+ ctx , cancel := context .WithCancel (context .Background ())
989+ defer cancel ()
990+
991+ tr , err := dbtools .NewTransaction (db , dbtools .RetryCount (100 ))
992+ require .NoError (t , err )
993+
994+ calls := 0
995+ // we are not using the same context to make sure the query causes the
996+ // error.
997+ err = tr .DB (context .Background (), func (tx dbtools.Tx ) error {
998+ calls ++
999+ query := `CREATE TABLE dbtest (name VARCHAR(100) NOT NULL)`
1000+ _ , err := tx .ExecContext (ctx , query )
1001+ return err
1002+ }, func (tx dbtools.Tx ) error {
1003+ cancel ()
1004+ query := `INSERT INTO dbtest (name) VALUES ('a')`
1005+ _ , err := tx .ExecContext (ctx , query )
1006+ return err
1007+ }, func (dbtools.Tx ) error {
1008+ t .Error ("didn't expect to get this" )
1009+ return nil
1010+ })
1011+ assertInError (t , err , context .Canceled )
1012+ assert .Equal (t , 1 , calls )
1013+ }
0 commit comments