diff --git a/crdb/common.go b/crdb/common.go index 991f8c6..29063ce 100644 --- a/crdb/common.go +++ b/crdb/common.go @@ -14,7 +14,10 @@ package crdb -import "context" +import ( + "context" + "time" +) // Tx abstracts the operations needed by ExecuteInTx so that different // frameworks (e.g. go's sql package, pgx, gorm) can be used with ExecuteInTx. @@ -60,8 +63,10 @@ func ExecuteInTx(ctx context.Context, tx Tx, fn func() error) (err error) { return err } - maxRetries := numRetriesFromContext(ctx) - retryCount := 0 + // establish the retry policy + retryPolicy := getRetryPolicy(ctx) + // set up the retry policy state + retryFunc := retryPolicy.NewRetry() for { releaseFailed := false err = fn() @@ -82,13 +87,48 @@ func ExecuteInTx(ctx context.Context, tx Tx, fn func() error) (err error) { return err } - if rollbackErr := tx.Exec(ctx, "ROLLBACK TO SAVEPOINT cockroach_restart"); rollbackErr != nil { - return newTxnRestartError(rollbackErr, err) + // We have a retryable error. Check the retry policy. + delay, retryErr := retryFunc(err) + // Check if the context has been cancelled + if ctxErr := ctx.Err(); ctxErr != nil { + return ctxErr + } + if delay > 0 && retryErr == nil { + // When backoff is needed, we don't want to hold locks while waiting for a backoff, + // so restart the entire transaction: + // - tx.Exec(ctx, "ROLLBACK") sends SQL to the server: + // it doesn't call tx.Rollback() (which would close the Go sql.Tx object) + // - The underlying connection remains open: the *sql.Tx wrapper maintains the database connection. + // Only the server-side transaction is rolled back. + // - tx.Exec(ctx, "BEGIN") starts a new server-side transaction on the same connection wrapped by the + // same *sql.Tx object + // - The defer handles cleanup - It calls tx.Rollback() (the Go method) only on errors, + // which closes the Go object and returns the connection to the pool + if restartErr := tx.Exec(ctx, "ROLLBACK"); restartErr != nil { + return newTxnRestartError(restartErr, err, "ROLLBACK") + } + if restartErr := tx.Exec(ctx, "BEGIN"); restartErr != nil { + return newTxnRestartError(restartErr, err, "BEGIN") + } + if restartErr := tx.Exec(ctx, "SAVEPOINT cockroach_restart"); restartErr != nil { + return newTxnRestartError(restartErr, err, "SAVEPOINT cockroach_restart") + } + } else { + if rollbackErr := tx.Exec(ctx, "ROLLBACK TO SAVEPOINT cockroach_restart"); rollbackErr != nil { + return newTxnRestartError(rollbackErr, err, "ROLLBACK TO SAVEPOINT cockroach_restart") + } + } + + if retryErr != nil { + return retryErr } - retryCount++ - if maxRetries > 0 && retryCount > maxRetries { - return newMaxRetriesExceededError(err, maxRetries) + if delay > 0 { + select { + case <-time.After(delay): + case <-ctx.Done(): + return ctx.Err() + } } } } diff --git a/crdb/error.go b/crdb/error.go index be18fcd..c907698 100644 --- a/crdb/error.go +++ b/crdb/error.go @@ -66,13 +66,13 @@ type TxnRestartError struct { msg string } -func newTxnRestartError(err error, retryErr error) *TxnRestartError { - const msgPattern = "restarting txn failed. ROLLBACK TO SAVEPOINT " + +func newTxnRestartError(err error, retryErr error, op string) *TxnRestartError { + const msgPattern = "restarting txn failed. %s " + "encountered error: %s. Original error: %s." return &TxnRestartError{ txError: txError{cause: err}, retryCause: retryErr, - msg: fmt.Sprintf(msgPattern, err, retryErr), + msg: fmt.Sprintf(msgPattern, op, err, retryErr), } } diff --git a/crdb/retry.go b/crdb/retry.go new file mode 100644 index 0000000..51f0a7f --- /dev/null +++ b/crdb/retry.go @@ -0,0 +1,284 @@ +// Copyright 2025 The Cockroach Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the License. + +package crdb + +import ( + "time" +) + +// RetryFunc owns the state for a transaction retry operation. Usually, this is +// just the retry count. RetryFunc is not assumed to be safe for concurrent use. +// +// The function is called after each retryable error to determine whether to +// retry and how long to wait. It receives the retryable error that triggered +// the retry attempt. +// +// Return values: +// - duration: The delay to wait before the next retry attempt. If 0, retry +// immediately without delay. +// - error: If non-nil, stops retrying and returns this error to the caller +// (typically a MaxRetriesExceededError). If nil, the retry will proceed +// after the specified duration. +// +// Example behavior: +// - (100ms, nil): Wait 100ms, then retry +// - (0, nil): Retry immediately +// - (0, err): Stop retrying, return err to caller +type RetryFunc func(err error) (time.Duration, error) + +// RetryPolicy constructs a new instance of a RetryFunc for each transaction +// it is used with. Instances of RetryPolicy can likely be immutable and +// should be safe for concurrent calls to NewRetry. +type RetryPolicy interface { + NewRetry() RetryFunc +} + +const ( + // NoRetries is a sentinel value for LimitBackoffRetryPolicy.RetryLimit + // indicating that no retries should be attempted. When a policy has + // RetryLimit set to NoRetries, the transaction will be attempted only + // once, and any retryable error will immediately return a + // MaxRetriesExceededError. + // + // Use WithNoRetries(ctx) to create a context with this behavior. + NoRetries = -1 + + // UnlimitedRetries indicates that retries should continue indefinitely + // until the transaction succeeds or a non-retryable error occurs. This + // is represented by setting RetryLimit to 0. + // + // Use WithMaxRetries(ctx, 0) to create a context with unlimited retries, + // though this is generally not recommended in production as it can lead + // to infinite retry loops. + UnlimitedRetries = 0 +) + +// LimitBackoffRetryPolicy implements RetryPolicy with a configurable retry limit +// and optional constant delay between retries. +// +// The RetryLimit field controls retry behavior: +// - Positive value (e.g., 10): Retry up to that many times before failing +// - UnlimitedRetries (0): Retry indefinitely until success or non-retryable error +// - NoRetries (-1) or any negative value: Do not retry; fail immediately on first retryable error +// +// If Delay is greater than zero, the policy will wait for the specified duration +// between retry attempts. +// +// Example usage with limited retries and no delay: +// +// policy := &LimitBackoffRetryPolicy{ +// RetryLimit: 10, +// Delay: 0, +// } +// ctx := crdb.WithRetryPolicy(context.Background(), policy) +// err := crdb.ExecuteTx(ctx, db, nil, func(tx *sql.Tx) error { +// // transaction logic +// }) +// +// Example usage with fixed delay between retries: +// +// policy := &LimitBackoffRetryPolicy{ +// RetryLimit: 5, +// Delay: 100 * time.Millisecond, +// } +// ctx := crdb.WithRetryPolicy(context.Background(), policy) +// +// Example usage with unlimited retries: +// +// policy := &LimitBackoffRetryPolicy{ +// RetryLimit: UnlimitedRetries, // or 0 +// Delay: 50 * time.Millisecond, +// } +// +// Note: Convenience functions are available: +// - WithMaxRetries(ctx, n) creates a LimitBackoffRetryPolicy with RetryLimit=n and Delay=0 +// - WithNoRetries(ctx) creates a LimitBackoffRetryPolicy with RetryLimit=NoRetries +type LimitBackoffRetryPolicy struct { + // RetryLimit controls the retry behavior: + // - Positive value: Maximum number of retries before returning MaxRetriesExceededError + // - UnlimitedRetries (0): Retry indefinitely + // - NoRetries (-1) or any negative value: Do not retry, fail immediately + RetryLimit int + + // Delay is the fixed duration to wait between retry attempts. If 0, + // retries happen immediately without delay. + Delay time.Duration +} + +// NewRetry implements RetryPolicy. +func (l *LimitBackoffRetryPolicy) NewRetry() RetryFunc { + tryCount := 0 + return func(err error) (time.Duration, error) { + tryCount++ + // Any negative value (including NoRetries) means fail immediately + if l.RetryLimit < UnlimitedRetries { + return 0, newMaxRetriesExceededError(err, 0) + } + // UnlimitedRetries (0) means retry indefinitely, so skip the limit check + // Any positive value enforces the retry limit + if l.RetryLimit > UnlimitedRetries && tryCount > l.RetryLimit { + return 0, newMaxRetriesExceededError(err, l.RetryLimit) + } + return l.Delay, nil + } +} + +// ExpBackoffRetryPolicy implements RetryPolicy using an exponential backoff strategy +// where delays double with each retry attempt, with an optional maximum delay cap. +// +// The delay between retries doubles with each attempt, starting from BaseDelay: +// - Retry 1: BaseDelay +// - Retry 2: BaseDelay * 2 +// - Retry 3: BaseDelay * 4 +// - Retry N: BaseDelay * 2^(N-1) +// +// If MaxDelay is set (> 0), the delay is capped at that value once reached. +// This prevents excessive wait times during high retry counts and provides a +// predictable upper bound for backoff duration. +// +// The policy will retry up to RetryLimit times. When the limit is exceeded or +// if the delay calculation overflows without a MaxDelay set, it returns a +// MaxRetriesExceededError. +// +// Example usage with capped exponential backoff: +// +// policy := &ExpBackoffRetryPolicy{ +// RetryLimit: 10, +// BaseDelay: 100 * time.Millisecond, +// MaxDelay: 5 * time.Second, +// } +// ctx := crdb.WithRetryPolicy(context.Background(), policy) +// err := crdb.ExecuteTx(ctx, db, nil, func(tx *sql.Tx) error { +// // transaction logic that may encounter retryable errors +// return tx.ExecContext(ctx, "UPDATE ...") +// }) +// +// This configuration produces delays: 100ms, 200ms, 400ms, 800ms, 1.6s, 3.2s, +// then stays at 5s for all subsequent retries. +// +// Example usage with unbounded exponential backoff: +// +// policy := &ExpBackoffRetryPolicy{ +// RetryLimit: 5, +// BaseDelay: 1 * time.Second, +// MaxDelay: 0, // no cap +// } +// +// This configuration produces delays: 1s, 2s, 4s, 8s, 16s. +// Note: Setting MaxDelay to 0 means no cap, but be aware that delay overflow +// will cause the policy to fail early. +type ExpBackoffRetryPolicy struct { + // RetryLimit is the maximum number of retries allowed. After this many + // retries, a MaxRetriesExceededError is returned. + RetryLimit int + + // BaseDelay is the initial delay before the first retry. Each subsequent + // retry doubles this value: delay = BaseDelay * 2^(attempt-1). + BaseDelay time.Duration + + // MaxDelay is the maximum delay cap. If > 0, delays are capped at this + // value once reached. If 0, delays grow unbounded (until overflow, which + // causes early termination). + MaxDelay time.Duration +} + +// NewRetry implements RetryPolicy. +func (l *ExpBackoffRetryPolicy) NewRetry() RetryFunc { + tryCount := 0 + return func(err error) (time.Duration, error) { + tryCount++ + if tryCount > l.RetryLimit { + return 0, newMaxRetriesExceededError(err, l.RetryLimit) + } + delay := l.BaseDelay << (tryCount - 1) + if l.MaxDelay > 0 && delay > l.MaxDelay { + return l.MaxDelay, nil + } + if delay < l.BaseDelay { + // We've overflowed. + if l.MaxDelay > 0 { + return l.MaxDelay, nil + } + // There's no max delay. Giving up is probably better in + // practice than using a 290-year MAX_INT delay. + return 0, newMaxRetriesExceededError(err, tryCount) + } + return delay, nil + } +} + +// Vargo adapts third-party backoff strategies (like those from github.com/sethvargo/go-retry) +// into a RetryPolicy without creating a direct dependency on those libraries. +// +// This function allows you to use any backoff implementation that conforms to the +// VargoBackoff interface, providing flexibility to integrate external retry strategies +// with CockroachDB transaction retries. +// +// Example usage with a hypothetical external backoff library: +// +// import "github.com/sethvargo/go-retry" +// +// // Create a retry policy using an external backoff strategy +// policy := crdb.Vargo(func() crdb.VargoBackoff { +// // Fibonacci backoff: 1s, 1s, 2s, 3s, 5s, 8s... +// return retry.NewFibonacci(1 * time.Second) +// }) +// ctx := crdb.WithRetryPolicy(context.Background(), policy) +// err := crdb.ExecuteTx(ctx, db, nil, func(tx *sql.Tx) error { +// // transaction logic +// }) +// +// The function parameter should return a fresh VargoBackoff instance for each +// transaction, as backoff state is not safe for concurrent use. +func Vargo(fn func() VargoBackoff) RetryPolicy { + return &vargoAdapter{ + DelegateFactory: fn, + } +} + +// VargoBackoff is an interface for external backoff strategies that provide +// delays through a Next() method. This allows adaptation of backoff policies +// from libraries like github.com/sethvargo/go-retry without creating a direct +// dependency. +// +// Next returns the next backoff duration and a boolean indicating whether to +// stop retrying. When stop is true, the retry loop terminates with a +// MaxRetriesExceededError. +type VargoBackoff interface { + // Next returns the next delay duration and whether to stop retrying. + // When stop is true, no more retries will be attempted. + Next() (next time.Duration, stop bool) +} + +// vargoAdapter adapts backoff policies in the style of github.com/sethvargo/go-retry. +type vargoAdapter struct { + DelegateFactory func() VargoBackoff +} + +// NewRetry implements RetryPolicy by delegating to the external backoff strategy. +// It creates a fresh backoff instance using DelegateFactory and wraps its Next() +// method to conform to the RetryFunc signature. +func (b *vargoAdapter) NewRetry() RetryFunc { + delegate := b.DelegateFactory() + count := 0 + return func(err error) (time.Duration, error) { + count++ + d, stop := delegate.Next() + if stop { + return 0, newMaxRetriesExceededError(err, count) + } + return d, nil + } +} diff --git a/crdb/retry_test.go b/crdb/retry_test.go new file mode 100644 index 0000000..85de403 --- /dev/null +++ b/crdb/retry_test.go @@ -0,0 +1,276 @@ +// Copyright 2025 The Cockroach Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the License. + +package crdb + +import ( + "errors" + "testing" + "time" +) + +func assertDelays(t *testing.T, policy RetryPolicy, expectedDelays []time.Duration) { + t.Helper() + actualDelays := make([]time.Duration, 0, len(expectedDelays)) + rf := policy.NewRetry() + + // Test with nil error (normal retry case) + for { + delay, err := rf(nil) + if err != nil { + break + } + + actualDelays = append(actualDelays, delay) + if len(actualDelays) > len(expectedDelays) { + t.Fatalf("too many retries: expected %d", len(expectedDelays)) + } + } + if len(actualDelays) != len(expectedDelays) { + t.Errorf("wrong number of retries: expected %d, got %d", len(expectedDelays), len(actualDelays)) + } + for i, delay := range actualDelays { + expected := expectedDelays[i] + if delay != expected { + t.Errorf("wrong delay at index %d: expected %d, got %d", i, expected, delay) + } + } + + // Test that RetryFunc also works when passed a non-nil error + // The error passed to RetryFunc should not affect the retry logic + rf2 := policy.NewRetry() + testErr := errors.New("test retryable error") + actualDelays2 := make([]time.Duration, 0, len(expectedDelays)) + for { + delay, err := rf2(testErr) + if err != nil { + break + } + actualDelays2 = append(actualDelays2, delay) + if len(actualDelays2) > len(expectedDelays) { + t.Fatalf("too many retries with non-nil err: expected %d", len(expectedDelays)) + } + } + if len(actualDelays2) != len(expectedDelays) { + t.Errorf("wrong number of retries with non-nil err: expected %d, got %d", len(expectedDelays), len(actualDelays2)) + } +} + +func TestLimitBackoffRetryPolicy(t *testing.T) { + policy := &LimitBackoffRetryPolicy{ + RetryLimit: 3, + Delay: 1 * time.Second, + } + assertDelays(t, policy, []time.Duration{ + 1 * time.Second, + 1 * time.Second, + 1 * time.Second, + }) +} + +func TestExpBackoffRetryPolicy(t *testing.T) { + policy := &ExpBackoffRetryPolicy{ + RetryLimit: 5, + BaseDelay: 1 * time.Second, + MaxDelay: 5 * time.Second, + } + assertDelays(t, policy, []time.Duration{ + 1 * time.Second, + 2 * time.Second, + 4 * time.Second, + 5 * time.Second, + 5 * time.Second, + }) +} + +func TestNoRetries(t *testing.T) { + policy := &LimitBackoffRetryPolicy{ + RetryLimit: NoRetries, + Delay: 0, + } + // NoRetries should fail immediately without any retries + assertDelays(t, policy, []time.Duration{}) + + // Verify the error is returned on first call + rf := policy.NewRetry() + testErr := errors.New("test error") + delay, err := rf(testErr) + if err == nil { + t.Error("expected error on first call with NoRetries, got nil") + } + if delay != 0 { + t.Errorf("expected delay 0, got %v", delay) + } +} + +func TestUnlimitedRetries(t *testing.T) { + policy := &LimitBackoffRetryPolicy{ + RetryLimit: UnlimitedRetries, + Delay: 10 * time.Millisecond, + } + + // Test that UnlimitedRetries continues beyond any reasonable limit + rf := policy.NewRetry() + testErr := errors.New("test error") + + // Try 1000 retries - should all succeed with no error + for i := 0; i < 1000; i++ { + delay, err := rf(testErr) + if err != nil { + t.Fatalf("unexpected error at retry %d: %v", i, err) + } + if delay != 10*time.Millisecond { + t.Errorf("wrong delay at retry %d: expected 10ms, got %v", i, delay) + } + } +} + +func TestLimitBackoffRetryPolicyEdgeCases(t *testing.T) { + t.Run("zero BaseDelay with LimitBackoffRetryPolicy", func(t *testing.T) { + policy := &LimitBackoffRetryPolicy{ + RetryLimit: 3, + Delay: 0, // zero delay = immediate retries + } + assertDelays(t, policy, []time.Duration{0, 0, 0}) + }) + + t.Run("negative RetryLimit less than NoRetries", func(t *testing.T) { + // Negative values other than NoRetries (-1) should be treated as invalid + // but the implementation currently treats any negative as "no retries" + policy := &LimitBackoffRetryPolicy{ + RetryLimit: -5, + Delay: 0, + } + rf := policy.NewRetry() + _, err := rf(errors.New("test")) + // Should fail immediately like NoRetries + if err == nil { + t.Error("expected error for negative RetryLimit < NoRetries, got nil") + } + }) + + t.Run("very large RetryLimit", func(t *testing.T) { + policy := &LimitBackoffRetryPolicy{ + RetryLimit: 1000000, + Delay: 0, + } + rf := policy.NewRetry() + // Should be able to retry many times + for i := 0; i < 100; i++ { + _, err := rf(nil) + if err != nil { + t.Fatalf("unexpected error at retry %d with large limit: %v", i, err) + } + } + }) +} + +func TestExpBackoffRetryPolicyEdgeCases(t *testing.T) { + t.Run("zero BaseDelay", func(t *testing.T) { + policy := &ExpBackoffRetryPolicy{ + RetryLimit: 3, + BaseDelay: 0, + MaxDelay: 1 * time.Second, + } + // With zero base delay, all delays should be 0 + assertDelays(t, policy, []time.Duration{0, 0, 0}) + }) + + t.Run("MaxDelay less than BaseDelay", func(t *testing.T) { + policy := &ExpBackoffRetryPolicy{ + RetryLimit: 3, + BaseDelay: 1 * time.Second, + MaxDelay: 100 * time.Millisecond, // smaller than base + } + // All delays should be capped at MaxDelay + assertDelays(t, policy, []time.Duration{ + 100 * time.Millisecond, + 100 * time.Millisecond, + 100 * time.Millisecond, + }) + }) + + t.Run("MaxDelay equals BaseDelay", func(t *testing.T) { + policy := &ExpBackoffRetryPolicy{ + RetryLimit: 3, + BaseDelay: 1 * time.Second, + MaxDelay: 1 * time.Second, // same as base + } + // All delays should be capped at MaxDelay (no exponential growth) + assertDelays(t, policy, []time.Duration{ + 1 * time.Second, + 1 * time.Second, + 1 * time.Second, + }) + }) + + t.Run("zero MaxDelay with potential overflow", func(t *testing.T) { + policy := &ExpBackoffRetryPolicy{ + RetryLimit: 100, + BaseDelay: 1 * time.Hour, + MaxDelay: 0, // no cap + } + rf := policy.NewRetry() + + // First few should work fine + for i := 0; i < 5; i++ { + delay, err := rf(nil) + if err != nil { + t.Fatalf("unexpected error at retry %d: %v", i, err) + } + expected := (1 * time.Hour) << i + if delay != expected { + t.Errorf("retry %d: expected delay %v, got %v", i, expected, delay) + } + } + + // Eventually should overflow and fail + var overflowed bool + for i := 5; i < 100; i++ { + _, err := rf(nil) + if err != nil { + overflowed = true + break + } + } + if !overflowed { + t.Error("expected overflow error with large base delay and no MaxDelay") + } + }) + + t.Run("single retry with exponential backoff", func(t *testing.T) { + policy := &ExpBackoffRetryPolicy{ + RetryLimit: 1, + BaseDelay: 100 * time.Millisecond, + MaxDelay: 0, + } + assertDelays(t, policy, []time.Duration{100 * time.Millisecond}) + }) + + t.Run("NoRetries with ExpBackoffRetryPolicy", func(t *testing.T) { + // ExpBackoffRetryPolicy doesn't have NoRetries logic, but testing + // with RetryLimit=0 to see behavior + policy := &ExpBackoffRetryPolicy{ + RetryLimit: 0, + BaseDelay: 1 * time.Second, + MaxDelay: 5 * time.Second, + } + rf := policy.NewRetry() + // With RetryLimit=0, first call should fail + _, err := rf(nil) + if err == nil { + t.Error("expected error with RetryLimit=0 on ExpBackoffRetryPolicy, got nil") + } + }) +} diff --git a/crdb/tx.go b/crdb/tx.go index 6e5f2d6..308f75e 100644 --- a/crdb/tx.go +++ b/crdb/tx.go @@ -20,6 +20,7 @@ import ( "context" "database/sql" "errors" + "time" ) // Execute runs fn and retries it as needed. It is used to add retry handling to @@ -48,20 +49,20 @@ import ( // following snippet, the original retryable error will be masked by the call to // fmt.Errorf, and the transaction will not be automatically retried. // -// crdb.Execute(func () error { -// rows, err := db.QueryContext(ctx, "SELECT ...") -// if err != nil { -// return fmt.Errorf("scanning row: %s", err) -// } -// defer rows.Close() -// for rows.Next() { -// // ... -// } -// if err := rows.Err(); err != nil { -// return fmt.Errorf("scanning row: %s", err) -// } -// return nil -// }) +// crdb.Execute(func () error { +// rows, err := db.QueryContext(ctx, "SELECT ...") +// if err != nil { +// return fmt.Errorf("scanning row: %s", err) +// } +// defer rows.Close() +// for rows.Next() { +// // ... +// } +// if err := rows.Err(); err != nil { +// return fmt.Errorf("scanning row: %s", err) +// } +// return nil +// }) // // Instead, add context by returning an error that implements either: // - a `Cause() error` method, in the manner of github.com/pkg/errors, or @@ -74,23 +75,22 @@ import ( // 1.13's special `%w` formatter with fmt.Errorf(), for example // fmt.Errorf("scanning row: %w", err). // -// import "github.com/pkg/errors" -// -// crdb.Execute(func () error { -// rows, err := db.QueryContext(ctx, "SELECT ...") -// if err != nil { -// return errors.Wrap(err, "scanning row") -// } -// defer rows.Close() -// for rows.Next() { -// // ... -// } -// if err := rows.Err(); err != nil { -// return errors.Wrap(err, "scanning row") -// } -// return nil -// }) +// import "github.com/pkg/errors" // +// crdb.Execute(func () error { +// rows, err := db.QueryContext(ctx, "SELECT ...") +// if err != nil { +// return errors.Wrap(err, "scanning row") +// } +// defer rows.Close() +// for rows.Next() { +// // ... +// } +// if err := rows.Err(); err != nil { +// return errors.Wrap(err, "scanning row") +// } +// return nil +// }) func Execute(fn func() error) (err error) { for { err = fn() @@ -105,7 +105,7 @@ func Execute(fn func() error) (err error) { // operations with configurable parameters. type ExecuteCtxFunc func(context.Context, ...interface{}) error -// ExecuteCtx runs fn and retries it as needed, respecting a maximum retry count +// ExecuteCtx runs fn and retries it as needed, respecting a retry policy // obtained from the context. It is used to add configurable retry handling to // the execution of a single statement. If a multi-statement transaction is // being run, use ExecuteTx instead. @@ -116,6 +116,8 @@ type ExecuteCtxFunc func(context.Context, ...interface{}) error // returns a max retries exceeded error wrapping the last retryable error // encountered. // +// Arbitrary retry policies can be configured using WithRetryPolicy(ctx, p). +// // The fn parameter accepts variadic arguments which are passed through on each // retry attempt, allowing for flexible parameterization of the retried operation. // @@ -143,8 +145,11 @@ type ExecuteCtxFunc func(context.Context, ...interface{}) error // return nil // }, userID) func ExecuteCtx(ctx context.Context, fn ExecuteCtxFunc, args ...interface{}) (err error) { - maxRetries := numRetriesFromContext(ctx) - for n := 0; n <= maxRetries; n++ { + // establish the retry policy + retryPolicy := getRetryPolicy(ctx) + // set up the retry policy state + retryFunc := retryPolicy.NewRetry() + for { if err = ctx.Err(); err != nil { return err } @@ -153,29 +158,93 @@ func ExecuteCtx(ctx context.Context, fn ExecuteCtxFunc, args ...interface{}) (er if err == nil || !errIsRetryable(err) { return err } + delay, retryErr := retryFunc(err) + if retryErr != nil { + return retryErr + } + if delay > 0 { + select { + case <-time.After(delay): + case <-ctx.Done(): + return ctx.Err() + } + } } - - return newMaxRetriesExceededError(err, maxRetries) } type txConfigKey struct{} -// WithMaxRetries configures context so that ExecuteTx retries tx specified -// number of times when encountering retryable errors. -// Setting retries to 0 will retry indefinitely. +// WithMaxRetries configures context so that ExecuteTx retries the transaction +// up to the specified number of times when encountering retryable errors. +// +// The retries parameter controls retry behavior: +// - Positive value (e.g., 10): Retry up to that many times before failing +// - 0 (UnlimitedRetries): Retry indefinitely until success or non-retryable error +// (not recommended in production as it can lead to infinite retry loops) +// +// This is a convenience function that creates a LimitBackoffRetryPolicy with +// no delay between retries (immediate retries). +// +// Example with limited retries: +// +// ctx := crdb.WithMaxRetries(context.Background(), 10) +// err := crdb.ExecuteTx(ctx, db, nil, func(tx *sql.Tx) error { +// // Will retry up to 10 times on retryable errors +// return tx.ExecContext(ctx, "UPDATE ...") +// }) +// +// Example with unlimited retries (use with caution): +// +// ctx := crdb.WithMaxRetries(context.Background(), 0) +// // Will retry indefinitely - ensure you have a context timeout! +// +// To disable retries entirely, use WithNoRetries(ctx) instead. func WithMaxRetries(ctx context.Context, retries int) context.Context { - return context.WithValue(ctx, txConfigKey{}, retries) + p := &LimitBackoffRetryPolicy{retries, 0} + return WithRetryPolicy(ctx, p) } -const defaultRetries = 50 +// WithNoRetries configures context so that ExecuteTx will not retry on +// retryable errors. The transaction will be attempted exactly once. +// +// This is useful when you want to handle retries manually or when operating +// in a context where automatic retries are not desired (e.g., in testing, +// or when implementing custom retry logic). +// +// Example usage: +// +// ctx := crdb.WithNoRetries(context.Background()) +// err := crdb.ExecuteTx(ctx, db, nil, func(tx *sql.Tx) error { +// // This will execute only once, no automatic retries +// return tx.ExecContext(ctx, "UPDATE ...") +// }) +// if err != nil { +// // Handle error manually, potentially implementing custom retry logic +// } +func WithNoRetries(ctx context.Context) context.Context { + p := &LimitBackoffRetryPolicy{NoRetries, 0} + return WithRetryPolicy(ctx, p) +} + +// WithRetryPolicy uses an arbitrary retry policy to perform retries. +func WithRetryPolicy(ctx context.Context, policy RetryPolicy) context.Context { + return context.WithValue(ctx, txConfigKey{}, policy) +} -func numRetriesFromContext(ctx context.Context) int { +// getRetryPolicy retrieves the RetryPolicy from the context or the default +func getRetryPolicy(ctx context.Context) RetryPolicy { + retryPolicy := defaultRetryPolicy if v := ctx.Value(txConfigKey{}); v != nil { - if retries, ok := v.(int); ok && retries >= 0 { - return retries - } + retryPolicy = v.(RetryPolicy) } - return defaultRetries + + return retryPolicy +} + +const defaultRetries = 50 + +var defaultRetryPolicy RetryPolicy = &LimitBackoffRetryPolicy{ + RetryLimit: defaultRetries, } // ExecuteTx runs fn inside a transaction and retries it as needed. On @@ -201,12 +270,12 @@ func numRetriesFromContext(ctx context.Context) int { // following snippet, the original retryable error will be masked by the call to // fmt.Errorf, and the transaction will not be automatically retried. // -// crdb.ExecuteTx(ctx, db, txopts, func (tx *sql.Tx) error { -// if err := tx.ExecContext(ctx, "UPDATE..."); err != nil { -// return fmt.Errorf("updating record: %s", err) -// } -// return nil -// }) +// crdb.ExecuteTx(ctx, db, txopts, func (tx *sql.Tx) error { +// if err := tx.ExecContext(ctx, "UPDATE..."); err != nil { +// return fmt.Errorf("updating record: %s", err) +// } +// return nil +// }) // // Instead, add context by returning an error that implements either: // - a `Cause() error` method, in the manner of github.com/pkg/errors, or @@ -219,15 +288,14 @@ func numRetriesFromContext(ctx context.Context) int { // 1.13's special `%w` formatter with fmt.Errorf(), for example // fmt.Errorf("scanning row: %w", err). // -// import "github.com/pkg/errors" -// -// crdb.ExecuteTx(ctx, db, txopts, func (tx *sql.Tx) error { -// if err := tx.ExecContext(ctx, "UPDATE..."); err != nil { -// return errors.Wrap(err, "updating record") -// } -// return nil -// }) +// import "github.com/pkg/errors" // +// crdb.ExecuteTx(ctx, db, txopts, func (tx *sql.Tx) error { +// if err := tx.ExecContext(ctx, "UPDATE..."); err != nil { +// return errors.Wrap(err, "updating record") +// } +// return nil +// }) func ExecuteTx(ctx context.Context, db *sql.DB, opts *sql.TxOptions, fn func(*sql.Tx) error) error { // Start a transaction. tx, err := db.BeginTx(ctx, opts) @@ -254,7 +322,7 @@ func (tx stdlibTxnAdapter) Commit(context.Context) error { return tx.tx.Commit() } -// Commit is part of the tx interface. +// Rollback is part of the tx interface. func (tx stdlibTxnAdapter) Rollback(context.Context) error { return tx.tx.Rollback() } diff --git a/crdb/tx_test.go b/crdb/tx_test.go index 5713be1..40cbbc5 100644 --- a/crdb/tx_test.go +++ b/crdb/tx_test.go @@ -101,13 +101,65 @@ func TestExecuteTx(t *testing.T) { // TestConfigureRetries verifies that the number of retries can be specified // via context. func TestConfigureRetries(t *testing.T) { - ctx := context.Background() - if numRetriesFromContext(ctx) != defaultRetries { - t.Fatal("expect default number of retries") - } + // Test no retries (using WithNoRetries) + ctx := WithNoRetries(context.Background()) + requireRetries(t, ctx, 0) + + // Test single retry + ctx = WithMaxRetries(context.Background(), 1) + requireRetries(t, ctx, 1) + + // Test default retries + ctx = context.Background() + requireRetries(t, ctx, defaultRetries) + + // Test custom retry limit ctx = WithMaxRetries(context.Background(), 123+defaultRetries) - if numRetriesFromContext(ctx) != defaultRetries+123 { - t.Fatal("expected default+123 retires") + requireRetries(t, ctx, 123+defaultRetries) + + // Test exponential backoff policy + ctx = WithRetryPolicy(context.Background(), &ExpBackoffRetryPolicy{ + RetryLimit: 10, + BaseDelay: 10, + MaxDelay: 1000, + }) + requireRetries(t, ctx, 10) + + // Test unlimited retries (0) - can't test easily without infinite loop, + // so we just verify the policy is set correctly + ctx = WithMaxRetries(context.Background(), 0) + p := getRetryPolicy(ctx) + if lbp, ok := p.(*LimitBackoffRetryPolicy); ok { + if lbp.RetryLimit != UnlimitedRetries { + t.Fatalf("expected UnlimitedRetries (0), got %d", lbp.RetryLimit) + } + } else { + t.Fatal("expected LimitBackoffRetryPolicy") + } +} + +func requireRetries(t *testing.T, ctx context.Context, numRetries int) { + p := getRetryPolicy(ctx) + if p == nil { + t.Fatal("expected non-nil retry policy") + } + + rf := p.NewRetry() + tryCount := 0 + for { + // we try + tryCount++ + + // Then, decide whether we're out of retries. + // The first try is not a retry, so we should + _, err := rf(nil) + if err != nil { + retryCount := tryCount - 1 + if retryCount != numRetries { + t.Fatalf("expected %d retries, got %d", numRetries, retryCount) + } + return + } } } diff --git a/testserver/version/version.go b/testserver/version/version.go index bc84d7f..11231d4 100644 --- a/testserver/version/version.go +++ b/testserver/version/version.go @@ -58,7 +58,8 @@ func (v *Version) Metadata() string { } // String returns the string representation, in the format: -// "v1.2.3-beta+md" +// +// "v1.2.3-beta+md" func (v Version) String() string { var b bytes.Buffer fmt.Fprintf(&b, "v%d.%d.%d", v.major, v.minor, v.patch) @@ -84,7 +85,9 @@ var numericRE = regexp.MustCompile(`^(0|[1-9][0-9]*)$`) // Parse creates a version from a string. The string must be a valid semantic // version (as per https://semver.org/spec/v2.0.0.html) in the format: -// "vMINOR.MAJOR.PATCH[-PRERELEASE][+METADATA]". +// +// "vMINOR.MAJOR.PATCH[-PRERELEASE][+METADATA]". +// // MINOR, MAJOR, and PATCH are numeric values (without any leading 0s). // PRERELEASE and METADATA can contain ASCII characters and digits, hyphens and // dots.