Skip to content

Commit 169733a

Browse files
authored
Merge pull request avast#71 from Hrily/hrishi/attempts-for-error
feat: add support for attempts based on error types
2 parents 3472f1e + e575c17 commit 169733a

File tree

3 files changed

+63
-20
lines changed

3 files changed

+63
-20
lines changed

options.go

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,17 @@ type Timer interface {
2323
}
2424

2525
type Config struct {
26-
attempts uint
27-
delay time.Duration
28-
maxDelay time.Duration
29-
maxJitter time.Duration
30-
onRetry OnRetryFunc
31-
retryIf RetryIfFunc
32-
delayType DelayTypeFunc
33-
lastErrorOnly bool
34-
context context.Context
35-
timer Timer
26+
attempts uint
27+
attemptsForError map[error]uint
28+
delay time.Duration
29+
maxDelay time.Duration
30+
maxJitter time.Duration
31+
onRetry OnRetryFunc
32+
retryIf RetryIfFunc
33+
delayType DelayTypeFunc
34+
lastErrorOnly bool
35+
context context.Context
36+
timer Timer
3637

3738
maxBackOffN uint
3839
}
@@ -58,6 +59,15 @@ func Attempts(attempts uint) Option {
5859
}
5960
}
6061

62+
// AttemptsForError sets count of retry in case execution results in given `err`
63+
// Retries for the given `err` are also counted against total retries.
64+
// The retry will stop if any of given retries is exhausted.
65+
func AttemptsForError(attempts uint, err error) Option {
66+
return func(c *Config) {
67+
c.attemptsForError[err] = attempts
68+
}
69+
}
70+
6171
// Delay set delay between retry
6272
// default is 100ms
6373
func Delay(delay time.Duration) Option {

retry.go

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,14 @@ func Do(retryableFunc RetryableFunc, opts ...Option) error {
117117
errorLog = make(Error, 1)
118118
}
119119

120+
attemptsForError := make(map[error]uint, len(config.attemptsForError))
121+
for err, attempts := range config.attemptsForError {
122+
attemptsForError[err] = attempts
123+
}
124+
120125
lastErrIndex := n
121-
for n < config.attempts {
126+
shouldRetry := true
127+
for shouldRetry {
122128
err := retryableFunc()
123129

124130
if err != nil {
@@ -130,6 +136,14 @@ func Do(retryableFunc RetryableFunc, opts ...Option) error {
130136

131137
config.onRetry(n, err)
132138

139+
for errToCheck, attempts := range attemptsForError {
140+
if errors.Is(err, errToCheck) {
141+
attempts--
142+
attemptsForError[errToCheck] = attempts
143+
shouldRetry = shouldRetry && attempts > 0
144+
}
145+
}
146+
133147
// if this is last attempt - don't wait
134148
if n == config.attempts-1 {
135149
break
@@ -150,6 +164,8 @@ func Do(retryableFunc RetryableFunc, opts ...Option) error {
150164
}
151165

152166
n++
167+
shouldRetry = shouldRetry && n < config.attempts
168+
153169
if !config.lastErrorOnly {
154170
lastErrIndex = n
155171
}
@@ -163,15 +179,16 @@ func Do(retryableFunc RetryableFunc, opts ...Option) error {
163179

164180
func newDefaultRetryConfig() *Config {
165181
return &Config{
166-
attempts: uint(10),
167-
delay: 100 * time.Millisecond,
168-
maxJitter: 100 * time.Millisecond,
169-
onRetry: func(n uint, err error) {},
170-
retryIf: IsRecoverable,
171-
delayType: CombineDelay(BackOffDelay, RandomDelay),
172-
lastErrorOnly: false,
173-
context: context.Background(),
174-
timer: &timerImpl{},
182+
attempts: uint(10),
183+
attemptsForError: make(map[error]uint),
184+
delay: 100 * time.Millisecond,
185+
maxJitter: 100 * time.Millisecond,
186+
onRetry: func(n uint, err error) {},
187+
retryIf: IsRecoverable,
188+
delayType: CombineDelay(BackOffDelay, RandomDelay),
189+
lastErrorOnly: false,
190+
context: context.Background(),
191+
timer: &timerImpl{},
175192
}
176193
}
177194

retry_test.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,22 @@ func TestZeroAttemptsWithoutError(t *testing.T) {
110110
assert.Equal(t, count, 1)
111111
}
112112

113+
func TestAttemptsForError(t *testing.T) {
114+
count := uint(0)
115+
testErr := os.ErrInvalid
116+
attemptsForTestError := uint(3)
117+
err := Do(
118+
func() error {
119+
count++
120+
return testErr
121+
},
122+
AttemptsForError(attemptsForTestError, testErr),
123+
Attempts(5),
124+
)
125+
assert.Error(t, err)
126+
assert.Equal(t, attemptsForTestError, count)
127+
}
128+
113129
func TestDefaultSleep(t *testing.T) {
114130
start := time.Now()
115131
err := Do(

0 commit comments

Comments
 (0)