Skip to content

Commit 56ec015

Browse files
authored
Fix tests flakyness (#78)
1 parent 5df417e commit 56ec015

File tree

2 files changed

+29
-25
lines changed

2 files changed

+29
-25
lines changed

extension/amqp/listener.go

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@ import (
55
"io"
66
"time"
77

8-
"github.com/hellofresh/goengine/v2"
9-
"github.com/hellofresh/goengine/v2/driver/sql"
108
"github.com/mailru/easyjson"
119
"github.com/streadway/amqp"
10+
11+
"github.com/hellofresh/goengine/v2"
12+
"github.com/hellofresh/goengine/v2/driver/sql"
1213
)
1314

1415
// Ensure Listener implements sql.Listener
@@ -24,6 +25,7 @@ type (
2425
minReconnectInterval time.Duration
2526
maxReconnectInterval time.Duration
2627
logger goengine.Logger
28+
waitFn func(time.Duration)
2729
}
2830
)
2931

@@ -48,9 +50,15 @@ func NewListener(
4850
minReconnectInterval: minReconnectInterval,
4951
maxReconnectInterval: maxReconnectInterval,
5052
logger: logger,
53+
waitFn: time.Sleep,
5154
}, nil
5255
}
5356

57+
// WithWaitFn replaces the default function called to wait (time.Sleep)
58+
func (l *Listener) WithWaitFn(fn func(time.Duration)) {
59+
l.waitFn = fn
60+
}
61+
5462
// Listen receives messages from a queue, transforms them into a sql.ProjectionNotification and calls the trigger
5563
func (l *Listener) Listen(ctx context.Context, trigger sql.ProjectionTrigger) error {
5664
var nextReconnect time.Time
@@ -69,7 +77,7 @@ func (l *Listener) Listen(ctx context.Context, trigger sql.ProjectionTrigger) er
6977
entry.String("reconnect_in", reconnectInterval.String())
7078
})
7179

72-
time.Sleep(reconnectInterval)
80+
l.waitFn(reconnectInterval)
7381
reconnectInterval *= 2
7482
if reconnectInterval > l.maxReconnectInterval {
7583
reconnectInterval = l.maxReconnectInterval
@@ -85,7 +93,7 @@ func (l *Listener) Listen(ctx context.Context, trigger sql.ProjectionTrigger) er
8593
case <-ctx.Done():
8694
return context.Canceled
8795
default:
88-
time.Sleep(time.Until(nextReconnect))
96+
l.waitFn(time.Until(nextReconnect))
8997
}
9098
}
9199
}

extension/amqp/listener_test.go

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,16 @@ import (
1010
"testing"
1111
"time"
1212

13-
"github.com/hellofresh/goengine/v2"
14-
"github.com/hellofresh/goengine/v2/driver/sql"
15-
"github.com/hellofresh/goengine/v2/extension/amqp"
16-
goengineLogger "github.com/hellofresh/goengine/v2/extension/logrus"
1713
"github.com/sirupsen/logrus"
1814
"github.com/sirupsen/logrus/hooks/test"
1915
libamqp "github.com/streadway/amqp"
2016
"github.com/stretchr/testify/assert"
2117
"github.com/stretchr/testify/require"
18+
19+
"github.com/hellofresh/goengine/v2"
20+
"github.com/hellofresh/goengine/v2/driver/sql"
21+
"github.com/hellofresh/goengine/v2/extension/amqp"
22+
goengineLogger "github.com/hellofresh/goengine/v2/extension/logrus"
2223
)
2324

2425
func TestListener_Listen(t *testing.T) {
@@ -79,47 +80,42 @@ func TestListener_Listen(t *testing.T) {
7980
ctx, ctxCancel := context.WithTimeout(context.Background(), time.Second)
8081
defer ctxCancel()
8182

82-
var consumeCalls []time.Time
83+
var waitCalls []time.Duration
8384
consume := func() (io.Closer, <-chan libamqp.Delivery, error) {
84-
consumeCalls = append(consumeCalls, time.Now())
85-
if len(consumeCalls) == 5 {
85+
if len(waitCalls) == 5 {
8686
ctxCancel()
8787
}
8888

89-
return nil, nil, fmt.Errorf("failure %d", len(consumeCalls))
89+
return nil, nil, fmt.Errorf("failure %d", len(waitCalls))
9090
}
9191

9292
logger, loggerHook := getLogger()
9393

9494
listener, err := amqp.NewListener(consume, time.Millisecond, 6*time.Millisecond, logger)
9595
ensure.NoError(err)
9696

97+
listener.WithWaitFn(func(d time.Duration) {
98+
waitCalls = append(waitCalls, d)
99+
})
100+
97101
err = listener.Listen(ctx, func(ctx context.Context, notification *sql.ProjectionNotification) error {
98102
ensure.Fail("Trigger should ever be called")
99103
return nil
100104
})
101105

102106
ensure.Equal(context.Canceled, err)
103107

104-
reconnectIntervals := []time.Duration{time.Millisecond, time.Millisecond * 2, time.Millisecond * 4, time.Millisecond * 6, time.Millisecond * 6}
105-
ensure.Len(consumeCalls, len(reconnectIntervals))
106-
for i := 1; i < len(reconnectIntervals); i++ {
107-
expectedInterval := reconnectIntervals[i-1]
108-
interval := consumeCalls[i].Sub(consumeCalls[i-1])
109-
110-
if expectedInterval > interval || interval > (expectedInterval+time.Millisecond*2) {
111-
assert.Fail(t, fmt.Sprintf("Invalid interval after consume %d (got %s expected between %s and %s)", i, interval, expectedInterval, expectedInterval+time.Millisecond))
112-
}
113-
}
108+
reconnectIntervals := []time.Duration{time.Millisecond, time.Millisecond * 2, time.Millisecond * 4, time.Millisecond * 6, time.Millisecond * 6, time.Millisecond * 6}
109+
ensure.Equal(waitCalls, reconnectIntervals)
114110

115111
// Ensure we get log output
116112
logEntries := loggerHook.AllEntries()
117-
ensure.Len(logEntries, len(reconnectIntervals))
113+
ensure.Len(logEntries, len(waitCalls))
118114
for i, log := range logEntries {
119115
assert.Equal(t, log.Level, logrus.ErrorLevel)
120116
assert.Equal(t, log.Message, "failed to start consuming amqp messages")
121-
assert.Equal(t, fmt.Errorf("failure %d", i+1), log.Data["error"])
122-
assert.Equal(t, reconnectIntervals[i].String(), log.Data["reconnect_in"])
117+
assert.Equal(t, fmt.Errorf("failure %d", i), log.Data["error"])
118+
assert.Equal(t, waitCalls[i].String(), log.Data["reconnect_in"])
123119
}
124120
})
125121

0 commit comments

Comments
 (0)