Skip to content

Commit 84b2b86

Browse files
authored
feat: add recovery middleware (#44)
1 parent e5e767f commit 84b2b86

File tree

8 files changed

+176
-0
lines changed

8 files changed

+176
-0
lines changed

context.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,3 +194,8 @@ func (c *Context) Get(key string) (any, bool) {
194194
val, ok := c.storage[key]
195195
return val, ok
196196
}
197+
198+
// Debug returns whether we are in debug mode or not.
199+
func (c *Context) Debug() bool {
200+
return c.kid.Debug()
201+
}

context_test.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,3 +457,12 @@ func TestContext_HTMLString(t *testing.T) {
457457
assert.Equal(t, "<p>Hello</p>", res.Body.String())
458458
assert.Equal(t, "text/html", res.Header().Get("Content-Type"))
459459
}
460+
461+
func TestContext_Debug(t *testing.T) {
462+
k := New()
463+
k.debug = true
464+
465+
ctx := newContext(k)
466+
467+
assert.True(t, ctx.Debug())
468+
}

kid.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,13 @@ func (k *Kid) Debug() bool {
236236
return k.debug
237237
}
238238

239+
// NewContext basically is a helper function and can be used in testing.
240+
func (k *Kid) NewContext(req *http.Request, res http.ResponseWriter) *Context {
241+
ctx := newContext(k)
242+
ctx.reset(req, res)
243+
return ctx
244+
}
245+
239246
// ApplyOptions applies the given options.
240247
func (k *Kid) ApplyOptions(opts ...Option) {
241248
for _, opt := range opts {

kid_test.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -694,3 +694,11 @@ func TestPanicIfNil(t *testing.T) {
694694
panicIfNil(x, "")
695695
})
696696
}
697+
698+
func TestKid_NewContext(t *testing.T) {
699+
k := New()
700+
701+
ctx := k.NewContext(nil, nil)
702+
703+
assert.NotNil(t, ctx)
704+
}

middlewares/recovery.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
package middlewares
2+
3+
import (
4+
"fmt"
5+
"io"
6+
"net/http"
7+
"os"
8+
"runtime/debug"
9+
10+
"github.com/mojixcoder/kid"
11+
)
12+
13+
// RecoveryConfig is the config used to build a Recovery middleware.
14+
type RecoveryConfig struct {
15+
// LogRecovers logs when a recovery happens, only in debug mode.
16+
LogRecovers bool
17+
18+
// PrintStacktrace prints the entire stacktrace if true, only in debug mode.
19+
PrintStacktrace bool
20+
21+
// Writer is the writer for logging recoveries and stacktraces.
22+
Writer io.Writer
23+
24+
// OnRecovery is the function which will be called when a recovery occurs.
25+
OnRecovery func(c *kid.Context, err any)
26+
}
27+
28+
// DefaultRecoverConfig is the default Recovery config.
29+
var DefaultRecoveryConfig = RecoveryConfig{
30+
LogRecovers: true,
31+
Writer: os.Stdout,
32+
OnRecovery: func(c *kid.Context, err any) {
33+
c.JSON(http.StatusInternalServerError, kid.Map{"message": http.StatusText(http.StatusInternalServerError)})
34+
},
35+
}
36+
37+
// NewRecovery returns a new Recovery middleware.
38+
func NewRecovery() kid.MiddlewareFunc {
39+
return NewRecoveryWithConfig(DefaultRecoveryConfig)
40+
}
41+
42+
// NewRecoveryWithConfig returns a new Recovery middleware with the given config.
43+
func NewRecoveryWithConfig(cfg RecoveryConfig) kid.MiddlewareFunc {
44+
return func(next kid.HandlerFunc) kid.HandlerFunc {
45+
return func(c *kid.Context) {
46+
defer func() {
47+
if err := recover(); err != nil {
48+
if cfg.LogRecovers && c.Debug() {
49+
fmt.Fprintf(cfg.Writer, "[RECOVERY] panic recovered: %v\n", err)
50+
}
51+
52+
if cfg.PrintStacktrace && c.Debug() {
53+
stack := debug.Stack()
54+
fmt.Fprintf(cfg.Writer, "%s", string(stack))
55+
}
56+
57+
if cfg.OnRecovery != nil {
58+
cfg.OnRecovery(c, err)
59+
}
60+
}
61+
}()
62+
63+
next(c)
64+
}
65+
}
66+
}

middlewares/recovery_test.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package middlewares
2+
3+
import (
4+
"bytes"
5+
"net/http"
6+
"net/http/httptest"
7+
"testing"
8+
9+
"github.com/mojixcoder/kid"
10+
"github.com/stretchr/testify/assert"
11+
)
12+
13+
var flag bool
14+
15+
var recoveryHandler kid.HandlerFunc = func(c *kid.Context) {
16+
panic("err")
17+
}
18+
19+
func TestNewRecoveryWithConfig(t *testing.T) {
20+
k := kid.New()
21+
var buf bytes.Buffer
22+
23+
recovery := NewRecoveryWithConfig(RecoveryConfig{LogRecovers: true, Writer: &buf})
24+
25+
ctx := k.NewContext(nil, httptest.NewRecorder())
26+
recovery(recoveryHandler)(ctx)
27+
assert.Equal(t, "[RECOVERY] panic recovered: err\n", buf.String())
28+
29+
buf.Reset()
30+
recovery = NewRecoveryWithConfig(RecoveryConfig{PrintStacktrace: true, Writer: &buf})
31+
32+
ctx = k.NewContext(nil, httptest.NewRecorder())
33+
recovery(recoveryHandler)(ctx)
34+
assert.NotEmpty(t, buf.String())
35+
36+
buf.Reset()
37+
k.ApplyOptions(kid.WithDebug(false))
38+
recovery(recoveryHandler)(ctx)
39+
assert.Empty(t, buf.String())
40+
41+
buf.Reset()
42+
recovery = NewRecoveryWithConfig(RecoveryConfig{
43+
OnRecovery: func(c *kid.Context, err any) {
44+
flag = true
45+
},
46+
})
47+
48+
ctx = k.NewContext(nil, httptest.NewRecorder())
49+
recovery(recoveryHandler)(ctx)
50+
assert.True(t, flag)
51+
}
52+
53+
func TestNewRecovery(t *testing.T) {
54+
k := kid.New()
55+
56+
recovery := NewRecovery()
57+
58+
res := httptest.NewRecorder()
59+
ctx := k.NewContext(nil, res)
60+
recovery(recoveryHandler)(ctx)
61+
62+
assert.Equal(t, res.Code, http.StatusInternalServerError)
63+
assert.Equal(t, "{\"message\":\"Internal Server Error\"}\n", res.Body.String())
64+
}

response.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ type (
2222

2323
// Written returns true if response has already been written otherwise returns false.
2424
Written() bool
25+
26+
// Status returns the status code.
27+
Status() int
2528
}
2629

2730
// response implements ResponseWriter.
@@ -85,6 +88,11 @@ func (r *response) Written() bool {
8588
return r.written
8689
}
8790

91+
// Status returns the status code.
92+
func (r *response) Status() int {
93+
return r.status
94+
}
95+
8896
// Flush implements the http.Flusher interface.
8997
func (r *response) Flush() {
9098
r.WriteHeaderNow()

response_test.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,12 @@ func TestResponseWriter_Hijack(t *testing.T) {
9595
})
9696
assert.True(t, res.Written())
9797
}
98+
99+
func TestResponseWriter_Status(t *testing.T) {
100+
w := httptest.NewRecorder()
101+
res := newResponse(w).(*response)
102+
103+
res.WriteHeader(http.StatusAccepted)
104+
105+
assert.Equal(t, http.StatusAccepted, res.Status())
106+
}

0 commit comments

Comments
 (0)