Skip to content

Commit 55df21f

Browse files
author
Muir Manders
authored
Use httpsnoop to wrap ResponseWriter. (#193)
Wrapping http.ResponseWriter is fraught with danger. Our compress handler made sure to implement all the optional ResponseWriter interfaces, but that made it implement them even if the underlying writer did not. For example, if the underlying ResponseWriter was _not_ an http.Hijacker, the compress writer nonetheless appeared to implement http.Hijacker, but would panic if you called Hijack(). On the other hand, the logging handler checked for certain combinations of optional interfaces and only implemented them as appropriate. However, it didn't check for all optional interfaces or all combinations, so most optional interfaces would still get lost. Fix both problems by using httpsnoop to do the wrapping. It uses code generation to ensure correctness, and it handles std lib changes like the http.Pusher addition in Go 1.8. Fixes #169.
1 parent 2188616 commit 55df21f

File tree

10 files changed

+98
-145
lines changed

10 files changed

+98
-145
lines changed

compress.go

Lines changed: 27 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -10,35 +10,30 @@ import (
1010
"io"
1111
"net/http"
1212
"strings"
13+
14+
"github.com/felixge/httpsnoop"
1315
)
1416

1517
const acceptEncoding string = "Accept-Encoding"
1618

1719
type compressResponseWriter struct {
18-
io.Writer
19-
http.ResponseWriter
20-
http.Hijacker
21-
http.Flusher
22-
http.CloseNotifier
23-
}
24-
25-
func (w *compressResponseWriter) WriteHeader(c int) {
26-
w.ResponseWriter.Header().Del("Content-Length")
27-
w.ResponseWriter.WriteHeader(c)
20+
compressor io.Writer
21+
w http.ResponseWriter
2822
}
2923

30-
func (w *compressResponseWriter) Header() http.Header {
31-
return w.ResponseWriter.Header()
24+
func (cw *compressResponseWriter) WriteHeader(c int) {
25+
cw.w.Header().Del("Content-Length")
26+
cw.w.WriteHeader(c)
3227
}
3328

34-
func (w *compressResponseWriter) Write(b []byte) (int, error) {
35-
h := w.ResponseWriter.Header()
29+
func (cw *compressResponseWriter) Write(b []byte) (int, error) {
30+
h := cw.w.Header()
3631
if h.Get("Content-Type") == "" {
3732
h.Set("Content-Type", http.DetectContentType(b))
3833
}
3934
h.Del("Content-Length")
4035

41-
return w.Writer.Write(b)
36+
return cw.compressor.Write(b)
4237
}
4338

4439
type flusher interface {
@@ -47,12 +42,12 @@ type flusher interface {
4742

4843
func (w *compressResponseWriter) Flush() {
4944
// Flush compressed data if compressor supports it.
50-
if f, ok := w.Writer.(flusher); ok {
45+
if f, ok := w.compressor.(flusher); ok {
5146
f.Flush()
5247
}
5348
// Flush HTTP response.
54-
if w.Flusher != nil {
55-
w.Flusher.Flush()
49+
if f, ok := w.w.(http.Flusher); ok {
50+
f.Flush()
5651
}
5752
}
5853

@@ -119,28 +114,22 @@ func CompressHandlerLevel(h http.Handler, level int) http.Handler {
119114
w.Header().Set("Content-Encoding", encoding)
120115
r.Header.Del(acceptEncoding)
121116

122-
hijacker, ok := w.(http.Hijacker)
123-
if !ok { /* w is not Hijacker... oh well... */
124-
hijacker = nil
117+
cw := &compressResponseWriter{
118+
w: w,
119+
compressor: encWriter,
125120
}
126121

127-
flusher, ok := w.(http.Flusher)
128-
if !ok {
129-
flusher = nil
130-
}
131-
132-
closeNotifier, ok := w.(http.CloseNotifier)
133-
if !ok {
134-
closeNotifier = nil
135-
}
136-
137-
w = &compressResponseWriter{
138-
Writer: encWriter,
139-
ResponseWriter: w,
140-
Hijacker: hijacker,
141-
Flusher: flusher,
142-
CloseNotifier: closeNotifier,
143-
}
122+
w = httpsnoop.Wrap(w, httpsnoop.Hooks{
123+
Write: func(httpsnoop.WriteFunc) httpsnoop.WriteFunc {
124+
return cw.Write
125+
},
126+
WriteHeader: func(httpsnoop.WriteHeaderFunc) httpsnoop.WriteHeaderFunc {
127+
return cw.WriteHeader
128+
},
129+
Flush: func(httpsnoop.FlushFunc) httpsnoop.FlushFunc {
130+
return cw.Flush
131+
},
132+
})
144133

145134
h.ServeHTTP(w, r)
146135
})

compress_test.go

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ func compressedRequest(w *httptest.ResponseRecorder, compression string) {
2929
acceptEncoding: []string{compression},
3030
},
3131
})
32-
3332
}
3433

3534
func TestCompressHandlerNoCompression(t *testing.T) {
@@ -165,6 +164,7 @@ type fullyFeaturedResponseWriter struct{}
165164
func (fullyFeaturedResponseWriter) Header() http.Header {
166165
return http.Header{}
167166
}
167+
168168
func (fullyFeaturedResponseWriter) Write([]byte) (int, error) {
169169
return 0, nil
170170
}
@@ -193,9 +193,6 @@ func TestCompressHandlerPreserveInterfaces(t *testing.T) {
193193
)
194194
var h http.Handler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
195195
comp := r.Header.Get(acceptEncoding)
196-
if _, ok := rw.(*compressResponseWriter); !ok {
197-
t.Fatalf("ResponseWriter wasn't wrapped by compressResponseWriter, got %T type", rw)
198-
}
199196
if _, ok := rw.(http.Flusher); !ok {
200197
t.Errorf("ResponseWriter lost http.Flusher interface for %q", comp)
201198
}
@@ -207,9 +204,7 @@ func TestCompressHandlerPreserveInterfaces(t *testing.T) {
207204
}
208205
})
209206
h = CompressHandler(h)
210-
var (
211-
rw fullyFeaturedResponseWriter
212-
)
207+
var rw fullyFeaturedResponseWriter
213208
r, err := http.NewRequest("GET", "/", nil)
214209
if err != nil {
215210
t.Fatalf("Failed to create test request: %v", err)
@@ -220,3 +215,32 @@ func TestCompressHandlerPreserveInterfaces(t *testing.T) {
220215
r.Header.Set(acceptEncoding, "deflate")
221216
h.ServeHTTP(rw, r)
222217
}
218+
219+
type paltryResponseWriter struct{}
220+
221+
func (paltryResponseWriter) Header() http.Header {
222+
return http.Header{}
223+
}
224+
225+
func (paltryResponseWriter) Write([]byte) (int, error) {
226+
return 0, nil
227+
}
228+
func (paltryResponseWriter) WriteHeader(int) {}
229+
230+
func TestCompressHandlerDoesntInventInterfaces(t *testing.T) {
231+
var h http.Handler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
232+
if _, ok := rw.(http.Hijacker); ok {
233+
t.Error("ResponseWriter shouldn't implement http.Hijacker")
234+
}
235+
})
236+
237+
h = CompressHandler(h)
238+
239+
var rw paltryResponseWriter
240+
r, err := http.NewRequest("GET", "/", nil)
241+
if err != nil {
242+
t.Fatalf("Failed to create test request: %v", err)
243+
}
244+
r.Header.Set(acceptEncoding, "gzip")
245+
h.ServeHTTP(rw, r)
246+
}

go.mod

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
module github.com/gorilla/handlers
22

33
go 1.14
4+
5+
require github.com/felixge/httpsnoop v1.0.1

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
github.com/felixge/httpsnoop v1.0.1 h1:lvB5Jl89CsZtGIWuTcDM1E/vkVs49/Ml7JJe07l8SPQ=
2+
github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=

handlers.go

Lines changed: 4 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,6 @@ type responseLogger struct {
5151
size int
5252
}
5353

54-
func (l *responseLogger) Header() http.Header {
55-
return l.w.Header()
56-
}
57-
5854
func (l *responseLogger) Write(b []byte) (int, error) {
5955
size, err := l.w.Write(b)
6056
l.size += size
@@ -74,39 +70,16 @@ func (l *responseLogger) Size() int {
7470
return l.size
7571
}
7672

77-
func (l *responseLogger) Flush() {
78-
f, ok := l.w.(http.Flusher)
79-
if ok {
80-
f.Flush()
81-
}
82-
}
83-
84-
type hijackLogger struct {
85-
responseLogger
86-
}
87-
88-
func (l *hijackLogger) Hijack() (net.Conn, *bufio.ReadWriter, error) {
89-
h := l.responseLogger.w.(http.Hijacker)
90-
conn, rw, err := h.Hijack()
91-
if err == nil && l.responseLogger.status == 0 {
73+
func (l *responseLogger) Hijack() (net.Conn, *bufio.ReadWriter, error) {
74+
conn, rw, err := l.w.(http.Hijacker).Hijack()
75+
if err == nil && l.status == 0 {
9276
// The status will be StatusSwitchingProtocols if there was no error and
9377
// WriteHeader has not been called yet
94-
l.responseLogger.status = http.StatusSwitchingProtocols
78+
l.status = http.StatusSwitchingProtocols
9579
}
9680
return conn, rw, err
9781
}
9882

99-
type closeNotifyWriter struct {
100-
loggingResponseWriter
101-
http.CloseNotifier
102-
}
103-
104-
type hijackCloseNotifier struct {
105-
loggingResponseWriter
106-
http.Hijacker
107-
http.CloseNotifier
108-
}
109-
11083
// isContentType validates the Content-Type header matches the supplied
11184
// contentType. That is, its type and subtype match.
11285
func isContentType(h http.Header, contentType string) bool {

handlers_go18.go

Lines changed: 0 additions & 29 deletions
This file was deleted.

handlers_go18_test.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,15 @@ import (
99
"testing"
1010
)
1111

12+
// *httptest.ResponseRecorder doesn't implement Pusher, so wrap it.
13+
type pushRecorder struct {
14+
*httptest.ResponseRecorder
15+
}
16+
17+
func (pr pushRecorder) Push(target string, opts *http.PushOptions) error {
18+
return nil
19+
}
20+
1221
func TestLoggingHandlerWithPush(t *testing.T) {
1322
handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
1423
if _, ok := w.(http.Pusher); !ok {
@@ -18,7 +27,7 @@ func TestLoggingHandlerWithPush(t *testing.T) {
1827
})
1928

2029
logger := LoggingHandler(ioutil.Discard, handler)
21-
logger.ServeHTTP(httptest.NewRecorder(), newRequest("GET", "/"))
30+
logger.ServeHTTP(pushRecorder{httptest.NewRecorder()}, newRequest("GET", "/"))
2231
}
2332

2433
func TestCombinedLoggingHandlerWithPush(t *testing.T) {
@@ -30,5 +39,5 @@ func TestCombinedLoggingHandlerWithPush(t *testing.T) {
3039
})
3140

3241
logger := CombinedLoggingHandler(ioutil.Discard, handler)
33-
logger.ServeHTTP(httptest.NewRecorder(), newRequest("GET", "/"))
42+
logger.ServeHTTP(pushRecorder{httptest.NewRecorder()}, newRequest("GET", "/"))
3443
}

handlers_pre18.go

Lines changed: 0 additions & 7 deletions
This file was deleted.

logging.go

Lines changed: 14 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ import (
1212
"strconv"
1313
"time"
1414
"unicode/utf8"
15+
16+
"github.com/felixge/httpsnoop"
1517
)
1618

1719
// Logging
@@ -39,10 +41,10 @@ type loggingHandler struct {
3941

4042
func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
4143
t := time.Now()
42-
logger := makeLogger(w)
44+
logger, w := makeLogger(w)
4345
url := *req.URL
4446

45-
h.handler.ServeHTTP(logger, req)
47+
h.handler.ServeHTTP(w, req)
4648
if req.MultipartForm != nil {
4749
req.MultipartForm.RemoveAll()
4850
}
@@ -58,27 +60,16 @@ func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
5860
h.formatter(h.writer, params)
5961
}
6062

61-
func makeLogger(w http.ResponseWriter) loggingResponseWriter {
62-
var logger loggingResponseWriter = &responseLogger{w: w, status: http.StatusOK}
63-
if _, ok := w.(http.Hijacker); ok {
64-
logger = &hijackLogger{responseLogger{w: w, status: http.StatusOK}}
65-
}
66-
h, ok1 := logger.(http.Hijacker)
67-
c, ok2 := w.(http.CloseNotifier)
68-
if ok1 && ok2 {
69-
return hijackCloseNotifier{logger, h, c}
70-
}
71-
if ok2 {
72-
return &closeNotifyWriter{logger, c}
73-
}
74-
return logger
75-
}
76-
77-
type commonLoggingResponseWriter interface {
78-
http.ResponseWriter
79-
http.Flusher
80-
Status() int
81-
Size() int
63+
func makeLogger(w http.ResponseWriter) (*responseLogger, http.ResponseWriter) {
64+
logger := &responseLogger{w: w, status: http.StatusOK}
65+
return logger, httpsnoop.Wrap(w, httpsnoop.Hooks{
66+
Write: func(httpsnoop.WriteFunc) httpsnoop.WriteFunc {
67+
return logger.Write
68+
},
69+
WriteHeader: func(httpsnoop.WriteHeaderFunc) httpsnoop.WriteHeaderFunc {
70+
return logger.WriteHeader
71+
},
72+
})
8273
}
8374

8475
const lowerhex = "0123456789abcdef"
@@ -145,7 +136,6 @@ func appendQuoted(buf []byte, s string) []byte {
145136
}
146137
}
147138
return buf
148-
149139
}
150140

151141
// buildCommonLogLine builds a log entry for req in Apache Common Log Format.
@@ -160,7 +150,6 @@ func buildCommonLogLine(req *http.Request, url url.URL, ts time.Time, status int
160150
}
161151

162152
host, _, err := net.SplitHostPort(req.RemoteAddr)
163-
164153
if err != nil {
165154
host = req.RemoteAddr
166155
}

0 commit comments

Comments
 (0)