Skip to content

Commit 4cf7478

Browse files
authored
test(hook): add more tests (#4)
1 parent 2527043 commit 4cf7478

File tree

7 files changed

+236
-66
lines changed

7 files changed

+236
-66
lines changed

driver_test.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package sqlbreaker
2+
3+
import (
4+
"database/sql/driver"
5+
"testing"
6+
7+
"github.com/chenquan/sqlbreaker/pkg/breaker"
8+
"github.com/stretchr/testify/assert"
9+
)
10+
11+
func TestNewDefaultDriver(t *testing.T) {
12+
assert.NotNil(t, NewDefaultDriver(driver.Driver(nil)))
13+
}
14+
15+
func TestNewDriver(t *testing.T) {
16+
assert.NotNil(t, NewDriver(breaker.Breaker(nil), driver.Driver(nil)))
17+
}
18+
19+
func TestNewBreakerHook(t *testing.T) {
20+
hook := NewBreakerHook(breaker.Breaker(nil))
21+
assert.Equal(t, &Hook{brk: breaker.Breaker(nil)}, hook)
22+
}

hook.go

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"database/sql"
66
"database/sql/driver"
7+
"errors"
78

89
"github.com/chenquan/sqlbreaker/pkg/breaker"
910
"github.com/chenquan/sqlplus"
@@ -123,30 +124,22 @@ func (h *Hook) allow(ctx context.Context) (context.Context, error) {
123124
if err != nil {
124125
return ctx, err
125126
}
126-
ctx = newContextWithAllow(ctx, allow)
127+
ctx = context.WithValue(ctx, allowKey{}, allow)
127128

128-
return ctx, nil
129+
return ctx, err
129130
}
130131

131132
func (h *Hook) handleAllow(ctx context.Context, err error) {
132-
allow := allowFromContext(ctx)
133-
if err == nil || sql.ErrNoRows == err {
134-
allow.Accept()
133+
value := ctx.Value(allowKey{})
134+
if value == nil {
135135
return
136136
}
137137

138-
allow.Reject(err.Error())
139-
}
140-
141-
func newContextWithAllow(ctx context.Context, allow breaker.Promise) context.Context {
142-
return context.WithValue(ctx, allowKey{}, allow)
143-
}
144-
145-
func allowFromContext(ctx context.Context) (allow breaker.Promise) {
146-
value := ctx.Value(allowKey{})
147-
if value != nil {
148-
return &breaker.NopPromise{}
138+
allow := value.(breaker.Promise)
139+
if err == nil || errors.Is(err, sql.ErrNoRows) {
140+
allow.Accept()
141+
return
149142
}
150143

151-
return value.(breaker.Promise)
144+
allow.Reject(err.Error())
152145
}

hook_test.go

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
package sqlbreaker
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"database/sql/driver"
7+
"errors"
8+
"testing"
9+
10+
"github.com/chenquan/sqlbreaker/pkg/breaker"
11+
"github.com/stretchr/testify/assert"
12+
)
13+
14+
func TestHook_BeginTx(t *testing.T) {
15+
breakerHook := NewBreakerHook(breaker.NewBreaker())
16+
checkWithContext(t, func(ctx context.Context) (context.Context, error) {
17+
ctx, _, err := breakerHook.BeforeBeginTx(ctx, driver.TxOptions{
18+
Isolation: 0,
19+
ReadOnly: false,
20+
}, nil)
21+
return ctx, err
22+
}, func(ctx context.Context, err error) (context.Context, error) {
23+
ctx, _, err = breakerHook.AfterBeginTx(ctx, driver.TxOptions{}, nil, err)
24+
return ctx, err
25+
})
26+
}
27+
28+
func TestHook_Connect(t *testing.T) {
29+
breakerHook := NewBreakerHook(breaker.NewBreaker())
30+
31+
check(t, func(ctx context.Context) (context.Context, error) {
32+
ctx, err := breakerHook.BeforeConnect(ctx, nil)
33+
return ctx, err
34+
}, func(ctx context.Context) (context.Context, error) {
35+
ctx, _, err := breakerHook.AfterConnect(ctx, nil, nil)
36+
return ctx, err
37+
})
38+
}
39+
40+
func TestHook_Commit(t *testing.T) {
41+
breakerHook := NewBreakerHook(breaker.NewBreaker())
42+
43+
ctx, err := breakerHook.BeforeCommit(context.Background(), nil)
44+
assert.True(t, ctx == context.Background())
45+
assert.NoError(t, err)
46+
47+
ctx, err = breakerHook.AfterCommit(context.Background(), nil)
48+
assert.True(t, ctx == context.Background())
49+
assert.NoError(t, err)
50+
}
51+
52+
func TestHook_ExecContext(t *testing.T) {
53+
54+
breakerHook := NewBreakerHook(breaker.NewBreaker())
55+
checkWithContext(t, func(ctx context.Context) (context.Context, error) {
56+
ctx, _, _, err := breakerHook.BeforeExecContext(ctx, "", nil, nil)
57+
return ctx, err
58+
}, func(ctx context.Context, err error) (context.Context, error) {
59+
ctx, _, err = breakerHook.AfterExecContext(ctx, "", nil, nil, err)
60+
return ctx, err
61+
})
62+
}
63+
64+
func TestHook_PrepareContext(t *testing.T) {
65+
b := breaker.NewBreaker()
66+
breakerHook := NewBreakerHook(b)
67+
checkWithContext(t, func(ctx context.Context) (context.Context, error) {
68+
ctx, _, err := breakerHook.BeforePrepareContext(ctx, "", nil)
69+
70+
return ctx, err
71+
}, func(ctx context.Context, err error) (context.Context, error) {
72+
ctx, _, err = breakerHook.AfterPrepareContext(ctx, "", nil, err)
73+
return ctx, err
74+
})
75+
}
76+
77+
func TestHook_QueryContext(t *testing.T) {
78+
breakerHook := NewBreakerHook(breaker.NewBreaker())
79+
checkWithContext(t, func(ctx context.Context) (context.Context, error) {
80+
ctx, _, _, err := breakerHook.BeforeQueryContext(ctx, "", nil, nil)
81+
82+
return ctx, err
83+
}, func(ctx context.Context, err error) (context.Context, error) {
84+
ctx, _, err = breakerHook.AfterQueryContext(ctx, "", nil, nil, err)
85+
return ctx, err
86+
})
87+
}
88+
89+
func TestHook_Rollback(t *testing.T) {
90+
breakerHook := NewBreakerHook(breaker.NewBreaker())
91+
check(t, func(ctx context.Context) (context.Context, error) {
92+
ctx, err := breakerHook.BeforeRollback(ctx, nil)
93+
return ctx, err
94+
}, func(ctx context.Context) (context.Context, error) {
95+
ctx, err := breakerHook.AfterRollback(ctx, nil)
96+
return ctx, err
97+
})
98+
}
99+
100+
func TestHook_StmtExecContext(t *testing.T) {
101+
breakerHook := NewBreakerHook(breaker.NewBreaker())
102+
checkWithContext(t, func(ctx context.Context) (context.Context, error) {
103+
ctx, _, err := breakerHook.BeforeStmtExecContext(ctx, "", nil, nil)
104+
return ctx, err
105+
}, func(ctx context.Context, err error) (context.Context, error) {
106+
ctx, _, err = breakerHook.AfterStmtExecContext(ctx, "", nil, nil, err)
107+
return ctx, err
108+
})
109+
}
110+
111+
func TestHook_StmtQueryContext(t *testing.T) {
112+
breakerHook := NewBreakerHook(breaker.NewBreaker())
113+
checkWithContext(t, func(ctx context.Context) (context.Context, error) {
114+
ctx, _, err := breakerHook.BeforeStmtQueryContext(ctx, "", nil, nil)
115+
return ctx, err
116+
}, func(ctx context.Context, err error) (context.Context, error) {
117+
ctx, _, err = breakerHook.AfterStmtQueryContext(ctx, "", nil, nil, err)
118+
return ctx, err
119+
})
120+
}
121+
122+
func checkWithContext(t *testing.T, before func(ctx context.Context) (context.Context, error), after func(ctx context.Context, err error) (context.Context, error)) {
123+
t.Run("allow", func(t *testing.T) {
124+
for i := 0; i < 100; i++ {
125+
ctx, err := before(context.Background())
126+
127+
assert.True(t, ctx.Value(allowKey{}) != nil)
128+
assert.NoError(t, err)
129+
130+
if i%2 == 0 {
131+
ctx, err = after(ctx, nil)
132+
assert.NoError(t, err)
133+
} else {
134+
ctx, err = after(ctx, sql.ErrNoRows)
135+
assert.ErrorIs(t, err, sql.ErrNoRows)
136+
}
137+
138+
assert.True(t, ctx.Value(allowKey{}) != nil)
139+
}
140+
})
141+
142+
t.Run("not allowed", func(t *testing.T) {
143+
b := breaker.NewBreaker()
144+
breakerHook := NewBreakerHook(b)
145+
146+
openBreaker := false
147+
for i := 0; i < 1000; i++ {
148+
ctx, _, _, err := breakerHook.BeforeExecContext(context.Background(), "", nil, nil)
149+
150+
if err == breaker.ErrServiceUnavailable {
151+
openBreaker = true
152+
assert.True(t, ctx.Value(allowKey{}) == nil)
153+
} else {
154+
assert.True(t, ctx.Value(allowKey{}) != nil)
155+
}
156+
157+
_, _, err = breakerHook.AfterExecContext(ctx, "", nil, nil, errors.New("any"))
158+
assert.Error(t, err)
159+
}
160+
161+
assert.True(t, openBreaker)
162+
})
163+
}
164+
165+
func check(t *testing.T, before, after func(ctx context.Context) (context.Context, error)) {
166+
167+
ctx, err := before(context.Background())
168+
assert.True(t, ctx == context.Background())
169+
assert.NoError(t, err)
170+
171+
ctx, err = after(ctx)
172+
assert.True(t, ctx == context.Background())
173+
assert.NoError(t, err)
174+
}

pkg/breaker/breaker.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,10 @@ func newLoggedThrottle(name string, t internalThrottle) loggedThrottle {
110110

111111
func (lt loggedThrottle) allow() (Promise, error) {
112112
promise, err := lt.internalThrottle.allow()
113+
if err != nil {
114+
return nil, err
115+
}
116+
113117
return promiseWithReason{
114118
promise: promise,
115119
errWin: lt.errWin,

pkg/breaker/breaker_test.go

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,32 @@ import (
99
)
1010

1111
func TestCircuitBreaker_Allow(t *testing.T) {
12-
b := NewBreaker()
13-
assert.True(t, len(b.Name()) > 0)
14-
_, err := b.Allow()
15-
assert.Nil(t, err)
12+
t.Run("allow", func(t *testing.T) {
13+
b := NewBreaker(WithName("any"))
14+
assert.True(t, len(b.Name()) > 0)
15+
for i := 0; i < 1000; i++ {
16+
allow, err := b.Allow()
17+
assert.Nil(t, err)
18+
allow.Accept()
19+
}
20+
})
21+
22+
t.Run("not allowed", func(t *testing.T) {
23+
b := NewBreaker()
24+
assert.True(t, len(b.Name()) > 0)
25+
openBreaker := false
26+
for i := 0; i < 1000; i++ {
27+
allow, err := b.Allow()
28+
if err == ErrServiceUnavailable {
29+
openBreaker = true
30+
} else {
31+
allow.Reject("any")
32+
assert.Nil(t, err)
33+
}
34+
}
35+
36+
assert.True(t, openBreaker)
37+
})
1638
}
1739

1840
func TestErrorWindow(t *testing.T) {

pkg/breaker/nopbreaker.go

Lines changed: 0 additions & 25 deletions
This file was deleted.

pkg/breaker/nopbreaker_test.go

Lines changed: 0 additions & 20 deletions
This file was deleted.

0 commit comments

Comments
 (0)