Skip to content

Commit 194129d

Browse files
authored
Merge pull request #1699 from pafuent/improve_decompress_middleware
Adding sync.Pool to Decompress middleware
2 parents 8c27828 + 2386e17 commit 194129d

File tree

3 files changed

+131
-8
lines changed

3 files changed

+131
-8
lines changed

middleware/compress.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
5959
config.Level = DefaultGzipConfig.Level
6060
}
6161

62-
pool := gzipPool(config)
62+
pool := gzipCompressPool(config)
6363

6464
return func(next echo.HandlerFunc) echo.HandlerFunc {
6565
return func(c echo.Context) error {
@@ -133,7 +133,7 @@ func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error {
133133
return http.ErrNotSupported
134134
}
135135

136-
func gzipPool(config GzipConfig) sync.Pool {
136+
func gzipCompressPool(config GzipConfig) sync.Pool {
137137
return sync.Pool{
138138
New: func() interface{} {
139139
w, err := gzip.NewWriterLevel(ioutil.Discard, config.Level)

middleware/decompress.go

Lines changed: 68 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,53 +3,115 @@ package middleware
33
import (
44
"bytes"
55
"compress/gzip"
6-
"github.com/labstack/echo/v4"
76
"io"
87
"io/ioutil"
8+
"net/http"
9+
"sync"
10+
11+
"github.com/labstack/echo/v4"
912
)
1013

1114
type (
1215
// DecompressConfig defines the config for Decompress middleware.
1316
DecompressConfig struct {
1417
// Skipper defines a function to skip middleware.
1518
Skipper Skipper
19+
20+
// GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers
21+
GzipDecompressPool Decompressor
1622
}
1723
)
1824

1925
//GZIPEncoding content-encoding header if set to "gzip", decompress body contents.
2026
const GZIPEncoding string = "gzip"
2127

28+
// Decompressor is used to get the sync.Pool used by the middleware to get Gzip readers
29+
type Decompressor interface {
30+
gzipDecompressPool() sync.Pool
31+
}
32+
2233
var (
2334
//DefaultDecompressConfig defines the config for decompress middleware
24-
DefaultDecompressConfig = DecompressConfig{Skipper: DefaultSkipper}
35+
DefaultDecompressConfig = DecompressConfig{
36+
Skipper: DefaultSkipper,
37+
GzipDecompressPool: &DefaultGzipDecompressPool{},
38+
}
2539
)
2640

41+
// DefaultGzipDecompressPool is the default implementation of Decompressor interface
42+
type DefaultGzipDecompressPool struct {
43+
}
44+
45+
func (d *DefaultGzipDecompressPool) gzipDecompressPool() sync.Pool {
46+
return sync.Pool{
47+
New: func() interface{} {
48+
// create with an empty reader (but with GZIP header)
49+
w, err := gzip.NewWriterLevel(ioutil.Discard, gzip.BestSpeed)
50+
if err != nil {
51+
return err
52+
}
53+
54+
b := new(bytes.Buffer)
55+
w.Reset(b)
56+
w.Flush()
57+
w.Close()
58+
59+
r, err := gzip.NewReader(bytes.NewReader(b.Bytes()))
60+
if err != nil {
61+
return err
62+
}
63+
return r
64+
},
65+
}
66+
}
67+
2768
//Decompress decompresses request body based if content encoding type is set to "gzip" with default config
2869
func Decompress() echo.MiddlewareFunc {
2970
return DecompressWithConfig(DefaultDecompressConfig)
3071
}
3172

3273
//DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config
3374
func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc {
75+
// Defaults
76+
if config.Skipper == nil {
77+
config.Skipper = DefaultGzipConfig.Skipper
78+
}
79+
if config.GzipDecompressPool == nil {
80+
config.GzipDecompressPool = DefaultDecompressConfig.GzipDecompressPool
81+
}
82+
3483
return func(next echo.HandlerFunc) echo.HandlerFunc {
84+
pool := config.GzipDecompressPool.gzipDecompressPool()
3585
return func(c echo.Context) error {
3686
if config.Skipper(c) {
3787
return next(c)
3888
}
3989
switch c.Request().Header.Get(echo.HeaderContentEncoding) {
4090
case GZIPEncoding:
41-
gr, err := gzip.NewReader(c.Request().Body)
42-
if err != nil {
91+
b := c.Request().Body
92+
93+
i := pool.Get()
94+
gr, ok := i.(*gzip.Reader)
95+
if !ok {
96+
return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error())
97+
}
98+
99+
if err := gr.Reset(b); err != nil {
100+
pool.Put(gr)
43101
if err == io.EOF { //ignore if body is empty
44102
return next(c)
45103
}
46104
return err
47105
}
48-
defer gr.Close()
49106
var buf bytes.Buffer
50107
io.Copy(&buf, gr)
108+
109+
gr.Close()
110+
pool.Put(gr)
111+
112+
b.Close() // http.Request.Body is closed by the Server, but because we are replacing it, it must be closed here
113+
51114
r := ioutil.NopCloser(&buf)
52-
defer r.Close()
53115
c.Request().Body = r
54116
}
55117
return next(c)

middleware/decompress_test.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@ package middleware
33
import (
44
"bytes"
55
"compress/gzip"
6+
"errors"
67
"io/ioutil"
78
"net/http"
89
"net/http/httptest"
910
"strings"
11+
"sync"
1012
"testing"
1113

1214
"github.com/labstack/echo/v4"
@@ -43,6 +45,35 @@ func TestDecompress(t *testing.T) {
4345
assert.Equal(body, string(b))
4446
}
4547

48+
func TestDecompressDefaultConfig(t *testing.T) {
49+
e := echo.New()
50+
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
51+
rec := httptest.NewRecorder()
52+
c := e.NewContext(req, rec)
53+
54+
h := DecompressWithConfig(DecompressConfig{})(func(c echo.Context) error {
55+
c.Response().Write([]byte("test")) // For Content-Type sniffing
56+
return nil
57+
})
58+
h(c)
59+
60+
assert := assert.New(t)
61+
assert.Equal("test", rec.Body.String())
62+
63+
// Decompress
64+
body := `{"name": "echo"}`
65+
gz, _ := gzipString(body)
66+
req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
67+
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
68+
rec = httptest.NewRecorder()
69+
c = e.NewContext(req, rec)
70+
h(c)
71+
assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
72+
b, err := ioutil.ReadAll(req.Body)
73+
assert.NoError(err)
74+
assert.Equal(body, string(b))
75+
}
76+
4677
func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) {
4778
e := echo.New()
4879
body := `{"name":"echo"}`
@@ -108,6 +139,36 @@ func TestDecompressSkipper(t *testing.T) {
108139
assert.Equal(t, body, string(reqBody))
109140
}
110141

142+
type TestDecompressPoolWithError struct {
143+
}
144+
145+
func (d *TestDecompressPoolWithError) gzipDecompressPool() sync.Pool {
146+
return sync.Pool{
147+
New: func() interface{} {
148+
return errors.New("pool error")
149+
},
150+
}
151+
}
152+
153+
func TestDecompressPoolError(t *testing.T) {
154+
e := echo.New()
155+
e.Use(DecompressWithConfig(DecompressConfig{
156+
Skipper: DefaultSkipper,
157+
GzipDecompressPool: &TestDecompressPoolWithError{},
158+
}))
159+
body := `{"name": "echo"}`
160+
req := httptest.NewRequest(http.MethodPost, "/echo", strings.NewReader(body))
161+
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
162+
rec := httptest.NewRecorder()
163+
c := e.NewContext(req, rec)
164+
e.ServeHTTP(rec, req)
165+
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
166+
reqBody, err := ioutil.ReadAll(c.Request().Body)
167+
assert.NoError(t, err)
168+
assert.Equal(t, body, string(reqBody))
169+
assert.Equal(t, rec.Code, http.StatusInternalServerError)
170+
}
171+
111172
func BenchmarkDecompress(b *testing.B) {
112173
e := echo.New()
113174
body := `{"name": "echo"}`

0 commit comments

Comments
 (0)