Skip to content

Commit 80ef5a5

Browse files
committed
ref(db): stop retrying on context is cancelled
When the context is cancelled, the query will fail and there is no point to retry.
1 parent 93f13ca commit 80ef5a5

File tree

2 files changed

+83
-0
lines changed

2 files changed

+83
-0
lines changed

db.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,10 @@ func (t *Transaction) PGX(ctx context.Context, transactions ...func(pgx.Tx) erro
101101
}()
102102
err = fn(tx)
103103
}()
104+
if errors.Is(err, context.Canceled) {
105+
err = &retry.StopError{Err: err}
106+
ctx = context.Background()
107+
}
104108
if err != nil {
105109
e := errors.Wrap(tx.Rollback(ctx), "rolling back transaction")
106110
return multierror.Append(err, e)
@@ -138,6 +142,10 @@ func (t *Transaction) DB(ctx context.Context, transactions ...func(Tx) error) er
138142
}()
139143
err = fn(tx)
140144
}()
145+
if errors.Is(err, context.Canceled) {
146+
err = &retry.StopError{Err: err}
147+
ctx = context.Background()
148+
}
141149
if err != nil {
142150
e := errors.Wrap(tx.Rollback(), "rolling back transaction")
143151
return multierror.Append(err, e)

db_test.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ func testTransactionPGX(t *testing.T) {
7373
t.Run("RetrySuccess", testTransactionPGXRetrySuccess)
7474
t.Run("MultipleFunctions", testTransactionPGXMultipleFunctions)
7575
t.Run("RealDatabase", testTransactionPGXRealDatabase)
76+
t.Run("ContextCancelled", testTransactionPGXContextCancelled)
7677
}
7778

7879
func testTransactionPGXNilDatabase(t *testing.T) {
@@ -510,6 +511,43 @@ func testTransactionPGXRealDatabase(t *testing.T) {
510511
assert.ElementsMatch(t, names, gotNames)
511512
}
512513

514+
func testTransactionPGXContextCancelled(t *testing.T) {
515+
t.Parallel()
516+
if testing.Short() {
517+
t.Skip("slow test")
518+
}
519+
ctx, cancel := context.WithCancel(context.Background())
520+
defer cancel()
521+
addr := getDB(t)
522+
config, err := pgxpool.ParseConfig(addr)
523+
require.NoError(t, err)
524+
db, err := pgxpool.ConnectConfig(ctx, config)
525+
require.NoError(t, err)
526+
527+
tr, err := dbtools.NewTransaction(db, dbtools.RetryCount(10))
528+
require.NoError(t, err)
529+
530+
calls := 0
531+
// we are not using the same context to make sure the query causes the
532+
// error.
533+
err = tr.PGX(ctx, func(tx pgx.Tx) error {
534+
calls++
535+
query := `CREATE TABLE dbtest (name VARCHAR(100))`
536+
_, err := tx.Exec(ctx, query)
537+
return err
538+
}, func(tx pgx.Tx) error {
539+
cancel()
540+
query := `INSERT INTO dbtest (name) VALUES ('a')`
541+
_, err := tx.Exec(ctx, query)
542+
return err
543+
}, func(pgx.Tx) error {
544+
t.Error("didn't expect to get this")
545+
return nil
546+
})
547+
assertInError(t, err, context.Canceled)
548+
assert.Equal(t, 1, calls)
549+
}
550+
513551
func testTransactionDB(t *testing.T) {
514552
t.Run("NilDatabase", testTransactionDBNilDatabase)
515553
t.Run("BeginError", testTransactionDBBeginError)
@@ -522,6 +560,7 @@ func testTransactionDB(t *testing.T) {
522560
t.Run("RetrySuccess", testTransactionDBRetrySuccess)
523561
t.Run("MultipleFunctions", testTransactionDBMultipleFunctions)
524562
t.Run("RealDatabase", testTransactionDBRealDatabase)
563+
t.Run("ContextCancelled", testTransactionDBContextCancelled)
525564
}
526565

527566
func testTransactionDBNilDatabase(t *testing.T) {
@@ -936,3 +975,39 @@ func testTransactionDBRealDatabase(t *testing.T) {
936975
require.NoError(t, err)
937976
assert.ElementsMatch(t, names, gotNames)
938977
}
978+
979+
func testTransactionDBContextCancelled(t *testing.T) {
980+
t.Parallel()
981+
if testing.Short() {
982+
t.Skip("slow test")
983+
}
984+
addr := getDB(t)
985+
config, err := pgx.ParseConfig(addr)
986+
require.NoError(t, err)
987+
db := stdlib.OpenDB(*config)
988+
ctx, cancel := context.WithCancel(context.Background())
989+
defer cancel()
990+
991+
tr, err := dbtools.NewTransaction(db, dbtools.RetryCount(100))
992+
require.NoError(t, err)
993+
994+
calls := 0
995+
// we are not using the same context to make sure the query causes the
996+
// error.
997+
err = tr.DB(context.Background(), func(tx dbtools.Tx) error {
998+
calls++
999+
query := `CREATE TABLE dbtest (name VARCHAR(100) NOT NULL)`
1000+
_, err := tx.ExecContext(ctx, query)
1001+
return err
1002+
}, func(tx dbtools.Tx) error {
1003+
cancel()
1004+
query := `INSERT INTO dbtest (name) VALUES ('a')`
1005+
_, err := tx.ExecContext(ctx, query)
1006+
return err
1007+
}, func(dbtools.Tx) error {
1008+
t.Error("didn't expect to get this")
1009+
return nil
1010+
})
1011+
assertInError(t, err, context.Canceled)
1012+
assert.Equal(t, 1, calls)
1013+
}

0 commit comments

Comments
 (0)