Skip to content

Commit f596ec7

Browse files
committed
feat: accept multiple functions
1 parent eeda313 commit f596ec7

File tree

4 files changed

+113
-9
lines changed

4 files changed

+113
-9
lines changed

README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,16 @@ err := l.Do(func() error {
2323
})
2424
```
2525

26+
You can provide multiple functions:
27+
28+
```go
29+
err := l.Do(func() error {
30+
return nil
31+
}, func() error {
32+
return nil
33+
}}
34+
```
35+
2636
If you want to stop retrying you can return a special error:
2737

2838
```go

retry.go

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,25 +42,20 @@ type Retry struct {
4242
Attempts int
4343
}
4444

45+
type repeatFunc func() error
46+
4547
// Do calls fn until it returns nil or a StopError. It delays and retries if
4648
// the fn returns any errors or panics. The value fo the returned error, or the
4749
// Err of a StopError, or an error with the panic message will be returned at
4850
// the last cycle.
49-
func (r Retry) Do(fn func() error) error {
51+
func (r Retry) Do(fn1 repeatFunc, fns ...repeatFunc) error {
5052
method := r.Method
5153
if method == nil {
5254
method = StandardDelay
5355
}
5456
var err error
5557
for i := 0; i < r.Attempts; i++ {
56-
func() {
57-
defer func() {
58-
if e := recover(); e != nil {
59-
err = fmt.Errorf("function caused a panic: %v", e)
60-
}
61-
}()
62-
err = fn()
63-
}()
58+
err = r.do(fn1, fns...)
6459
if err == nil {
6560
return nil
6661
}
@@ -79,6 +74,24 @@ func (r Retry) Do(fn func() error) error {
7974
return err
8075
}
8176

77+
func (r Retry) do(fn1 repeatFunc, fns ...repeatFunc) error {
78+
var err error
79+
for _, fn := range append([]repeatFunc{fn1}, fns...) {
80+
func() {
81+
defer func() {
82+
if e := recover(); e != nil {
83+
err = fmt.Errorf("function caused a panic: %v", e)
84+
}
85+
}()
86+
err = fn()
87+
}()
88+
if err != nil {
89+
return err
90+
}
91+
}
92+
return nil
93+
}
94+
8295
// StandardDelay always delays the same amount of time.
8396
func StandardDelay(_ int, delay time.Duration) time.Duration { return delay }
8497

retry_example_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,3 +136,27 @@ func ExampleIncrementalDelay() {
136136
// Output:
137137
// Error: <nil>
138138
}
139+
140+
func ExampleRetry_Do_multipleFuncs() {
141+
l := &retry.Retry{
142+
Attempts: 4,
143+
Delay: time.Nanosecond,
144+
}
145+
err := l.Do(func() error {
146+
fmt.Println("Running func 1.")
147+
return nil
148+
}, func() error {
149+
fmt.Println("Running func 2.")
150+
return nil
151+
}, func() error {
152+
fmt.Println("Running func 3.")
153+
return nil
154+
})
155+
fmt.Println("Error:", err)
156+
157+
// Output:
158+
// Running func 1.
159+
// Running func 2.
160+
// Running func 3.
161+
// Error: <nil>
162+
}

retry_test.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ func TestLoopDo(t *testing.T) {
1717
t.Run("Stopping", testLoopDoStopping)
1818
t.Run("Panic", testLoopDoPanic)
1919
t.Run("Sleep", testLoopDoSleep)
20+
t.Run("MultiFunc", testLoopDoMultiFunc)
2021
}
2122

2223
func testLoopDoReturn(t *testing.T) {
@@ -256,3 +257,59 @@ func testLoopDoSleepIncrementalMethodZero(t *testing.T) {
256257
fmt.Sprintf("take (%s) more than expected: %s", finished.Sub(started), time.Second),
257258
)
258259
}
260+
261+
func testLoopDoMultiFunc(t *testing.T) {
262+
t.Parallel()
263+
t.Run("FirstErrors", testLoopDoMultiFuncFirstErrors)
264+
t.Run("SecondErrors", testLoopDoMultiFuncSecondErrors)
265+
t.Run("NoErrors", testLoopDoMultiFuncNoErrors)
266+
}
267+
268+
func testLoopDoMultiFuncFirstErrors(t *testing.T) {
269+
t.Parallel()
270+
l := &retry.Retry{
271+
Attempts: 3,
272+
}
273+
err := l.Do(func() error {
274+
return assert.AnError
275+
}, func() error {
276+
t.Error("should not be called")
277+
return nil
278+
})
279+
assert.Equal(t, assert.AnError, errors.Cause(err))
280+
}
281+
282+
func testLoopDoMultiFuncSecondErrors(t *testing.T) {
283+
t.Parallel()
284+
l := &retry.Retry{
285+
Attempts: 3,
286+
}
287+
288+
calls := 0
289+
err := l.Do(func() error {
290+
calls++
291+
return nil
292+
}, func() error {
293+
return assert.AnError
294+
})
295+
assert.Equal(t, assert.AnError, errors.Cause(err))
296+
assert.Equal(t, 3, calls)
297+
}
298+
299+
func testLoopDoMultiFuncNoErrors(t *testing.T) {
300+
t.Parallel()
301+
l := &retry.Retry{
302+
Attempts: 3,
303+
}
304+
305+
calls := 0
306+
err := l.Do(func() error {
307+
calls++
308+
return nil
309+
}, func() error {
310+
calls++
311+
return nil
312+
})
313+
assert.NoError(t, err)
314+
assert.Equal(t, 2, calls)
315+
}

0 commit comments

Comments
 (0)