Skip to content

Commit ba260b6

Browse files
Dean KarnDean Karn
authored andcommitted
update gzip middleware to handle no return content
1 parent 98b3480 commit ba260b6

File tree

2 files changed

+49
-11
lines changed

2 files changed

+49
-11
lines changed

middleware/gzip.go

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,23 +19,24 @@ type gzipWriter struct {
1919
sniffComplete bool
2020
}
2121

22-
func (w gzipWriter) Write(b []byte) (int, error) {
22+
func (w *gzipWriter) Write(b []byte) (int, error) {
2323

2424
if !w.sniffComplete {
2525
if w.Header().Get(lars.ContentType) == "" {
2626
w.Header().Set(lars.ContentType, http.DetectContentType(b))
2727
}
28+
2829
w.sniffComplete = true
2930
}
3031

3132
return w.Writer.Write(b)
3233
}
3334

34-
func (w gzipWriter) Flush() error {
35+
func (w *gzipWriter) Flush() error {
3536
return w.Writer.(*gzip.Writer).Flush()
3637
}
3738

38-
func (w gzipWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
39+
func (w *gzipWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
3940
return w.ResponseWriter.(http.Hijacker).Hijack()
4041
}
4142

@@ -45,7 +46,7 @@ func (w *gzipWriter) CloseNotify() <-chan bool {
4546

4647
var writerPool = sync.Pool{
4748
New: func() interface{} {
48-
return gzip.NewWriter(ioutil.Discard)
49+
return &gzipWriter{Writer: gzip.NewWriter(ioutil.Discard)}
4950
},
5051
}
5152

@@ -57,15 +58,25 @@ func Gzip(c lars.Context) {
5758

5859
if strings.Contains(c.Request().Header.Get(lars.AcceptEncoding), lars.Gzip) {
5960

60-
w := writerPool.Get().(*gzip.Writer)
61+
gw := writerPool.Get().(*gzipWriter)
62+
gw.sniffComplete = false
63+
w := gw.Writer.(*gzip.Writer)
6164
w.Reset(c.Response().Writer())
65+
gw.ResponseWriter = c.Response().Writer()
6266

6367
defer func() {
68+
69+
if !gw.sniffComplete {
70+
// We have to reset response to it's pristine state when
71+
// nothing is written to body.
72+
c.Response().Header().Del(lars.ContentEncoding)
73+
w.Reset(ioutil.Discard)
74+
}
75+
6476
w.Close()
65-
writerPool.Put(w)
77+
writerPool.Put(gw)
6678
}()
6779

68-
gw := gzipWriter{Writer: w, ResponseWriter: c.Response().Writer()}
6980
c.Response().Header().Set(lars.ContentEncoding, lars.Gzip)
7081
c.Response().SetWriter(gw)
7182
}
@@ -87,7 +98,8 @@ func GzipLevel(level int) lars.HandlerFunc {
8798
var pool = sync.Pool{
8899
New: func() interface{} {
89100
z, _ := gzip.NewWriterLevel(ioutil.Discard, level)
90-
return z
101+
102+
return &gzipWriter{Writer: z}
91103
},
92104
}
93105

@@ -96,15 +108,25 @@ func GzipLevel(level int) lars.HandlerFunc {
96108

97109
if strings.Contains(c.Request().Header.Get(lars.AcceptEncoding), lars.Gzip) {
98110

99-
w := pool.Get().(*gzip.Writer)
111+
gw := pool.Get().(*gzipWriter)
112+
gw.sniffComplete = false
113+
w := gw.Writer.(*gzip.Writer)
100114
w.Reset(c.Response().Writer())
115+
gw.ResponseWriter = c.Response().Writer()
101116

102117
defer func() {
118+
119+
if !gw.sniffComplete {
120+
// We have to reset response to it's pristine state when
121+
// nothing is written to body.
122+
c.Response().Header().Del(lars.ContentEncoding)
123+
w.Reset(ioutil.Discard)
124+
}
125+
103126
w.Close()
104-
pool.Put(w)
127+
pool.Put(gw)
105128
}()
106129

107-
gw := gzipWriter{Writer: w, ResponseWriter: c.Response().Writer()}
108130
c.Response().Header().Set(lars.ContentEncoding, lars.Gzip)
109131
c.Response().SetWriter(gw)
110132
}

middleware/gzip_test.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ func TestGzip(t *testing.T) {
3333
l.Get("/test", func(c lars.Context) {
3434
c.Response().Write([]byte("test"))
3535
})
36+
l.Get("/empty", func(c lars.Context) {
37+
})
3638

3739
server := httptest.NewServer(l.Serve())
3840
defer server.Close()
@@ -65,6 +67,12 @@ func TestGzip(t *testing.T) {
6567
b, err = ioutil.ReadAll(r)
6668
Equal(t, err, nil)
6769
Equal(t, string(b), "test")
70+
71+
req, _ = http.NewRequest(http.MethodGet, server.URL+"/empty", nil)
72+
73+
resp, err = client.Do(req)
74+
Equal(t, err, nil)
75+
Equal(t, resp.StatusCode, http.StatusOK)
6876
}
6977

7078
func TestGzipLevel(t *testing.T) {
@@ -77,6 +85,8 @@ func TestGzipLevel(t *testing.T) {
7785
l.Get("/test", func(c lars.Context) {
7886
c.Response().Write([]byte("test"))
7987
})
88+
l.Get("/empty", func(c lars.Context) {
89+
})
8090

8191
server := httptest.NewServer(l.Serve())
8292
defer server.Close()
@@ -109,6 +119,12 @@ func TestGzipLevel(t *testing.T) {
109119
b, err = ioutil.ReadAll(r)
110120
Equal(t, err, nil)
111121
Equal(t, string(b), "test")
122+
123+
req, _ = http.NewRequest(http.MethodGet, server.URL+"/empty", nil)
124+
125+
resp, err = client.Do(req)
126+
Equal(t, err, nil)
127+
Equal(t, resp.StatusCode, http.StatusOK)
112128
}
113129

114130
func TestGzipFlush(t *testing.T) {

0 commit comments

Comments
 (0)