Skip to content

Commit 704a825

Browse files
Merge pull request #105 from basecamp/handler-refactor
Refactor middleware naming
2 parents cccdad8 + 6b05f45 commit 704a825

9 files changed

+145
-50
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import (
99
"github.com/klauspost/compress/gzhttp"
1010
)
1111

12-
func NewCompressionGuardMiddleware(next http.Handler) http.Handler {
12+
func NewCompressionGuardHandler(next http.Handler) http.Handler {
1313
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1414
// Check for user-specific headers in the request
1515
if hasUserSpecificRequestHeaders(r) {

internal/compression_guard_middleware_test.go renamed to internal/compression_guard_handler_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import (
99
"github.com/stretchr/testify/assert"
1010
)
1111

12-
func TestCompressionGuardMiddleware(t *testing.T) {
12+
func TestCompressionGuardHandler(t *testing.T) {
1313
tests := []struct {
1414
name string
1515
requestHeaders map[string]string
@@ -85,7 +85,7 @@ func TestCompressionGuardMiddleware(t *testing.T) {
8585

8686
for _, tt := range tests {
8787
t.Run(tt.name, func(t *testing.T) {
88-
handler := NewCompressionGuardMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
88+
handler := NewCompressionGuardHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
8989
for k, v := range tt.responseHeader {
9090
w.Header().Set(k, v)
9191
}

internal/compression_handler.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package internal
2+
3+
import (
4+
"net/http"
5+
6+
"github.com/klauspost/compress/gzhttp"
7+
)
8+
9+
func NewCompressionHandler(jitter int, disableOnAuth bool, next http.Handler) http.Handler {
10+
var wrapper func(http.Handler) http.HandlerFunc
11+
var err error
12+
13+
if jitter > 0 {
14+
wrapper, err = gzhttp.NewWrapper(
15+
gzhttp.MinSize(1024),
16+
gzhttp.CompressionLevel(6),
17+
gzhttp.RandomJitter(jitter, 0, false),
18+
)
19+
} else {
20+
wrapper, err = gzhttp.NewWrapper(
21+
gzhttp.MinSize(1024),
22+
gzhttp.CompressionLevel(6),
23+
)
24+
}
25+
26+
if err != nil {
27+
panic("failed to create gzip wrapper: " + err.Error())
28+
}
29+
30+
handler := wrapper(next)
31+
32+
if disableOnAuth {
33+
return NewCompressionGuardHandler(handler)
34+
}
35+
36+
return handler
37+
}
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
package internal
2+
3+
import (
4+
"compress/gzip"
5+
"io"
6+
"net/http"
7+
"net/http/httptest"
8+
"strings"
9+
"testing"
10+
11+
"github.com/stretchr/testify/assert"
12+
"github.com/stretchr/testify/require"
13+
)
14+
15+
func TestCompressionHandler(t *testing.T) {
16+
largeBody := strings.Repeat("A", 2000)
17+
18+
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
19+
w.Header().Set("Content-Type", "text/plain")
20+
_, err := w.Write([]byte(largeBody))
21+
require.NoError(t, err)
22+
})
23+
24+
t.Run("compresses responses", func(t *testing.T) {
25+
handler := NewCompressionHandler(0, false, upstream)
26+
27+
req := httptest.NewRequest("GET", "/", nil)
28+
req.Header.Set("Accept-Encoding", "gzip")
29+
rr := httptest.NewRecorder()
30+
31+
handler.ServeHTTP(rr, req)
32+
33+
assert.Equal(t, "gzip", rr.Header().Get("Content-Encoding"))
34+
35+
reader, err := gzip.NewReader(rr.Body)
36+
require.NoError(t, err)
37+
defer reader.Close()
38+
body, err := io.ReadAll(reader)
39+
require.NoError(t, err)
40+
assert.Equal(t, largeBody, string(body))
41+
})
42+
43+
t.Run("applies jitter when configured", func(t *testing.T) {
44+
handler := NewCompressionHandler(32, false, upstream)
45+
46+
req := httptest.NewRequest("GET", "/", nil)
47+
req.Header.Set("Accept-Encoding", "gzip")
48+
rr := httptest.NewRecorder()
49+
50+
handler.ServeHTTP(rr, req)
51+
52+
require.Equal(t, "gzip", rr.Header().Get("Content-Encoding"))
53+
54+
// Check for GZIP header with FCOMMENT flag (0x10)
55+
bodyBytes := rr.Body.Bytes()
56+
require.Greater(t, len(bodyBytes), 10)
57+
hasComment := (bodyBytes[3] & 0x10) != 0
58+
assert.True(t, hasComment, "Expected FCOMMENT flag due to jitter")
59+
})
60+
61+
t.Run("wraps with guard when disableOnAuth is true", func(t *testing.T) {
62+
handler := NewCompressionHandler(0, true, upstream)
63+
64+
req := httptest.NewRequest("GET", "/", nil)
65+
req.Header.Set("Accept-Encoding", "gzip")
66+
req.Header.Set("Cookie", "session=secret")
67+
rr := httptest.NewRecorder()
68+
69+
handler.ServeHTTP(rr, req)
70+
71+
// Should NOT be compressed due to Cookie header
72+
assert.Empty(t, rr.Header().Get("Content-Encoding"))
73+
assert.Equal(t, largeBody, rr.Body.String())
74+
})
75+
76+
t.Run("compresses authenticated requests when disableOnAuth is false", func(t *testing.T) {
77+
handler := NewCompressionHandler(0, false, upstream)
78+
79+
req := httptest.NewRequest("GET", "/", nil)
80+
req.Header.Set("Accept-Encoding", "gzip")
81+
req.Header.Set("Cookie", "session=secret")
82+
rr := httptest.NewRecorder()
83+
84+
handler.ServeHTTP(rr, req)
85+
86+
assert.Equal(t, "gzip", rr.Header().Get("Content-Encoding"))
87+
})
88+
}

internal/handler.go

Lines changed: 3 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@ import (
44
"log/slog"
55
"net/http"
66
"net/url"
7-
8-
"github.com/klauspost/compress/gzhttp"
97
)
108

119
type HandlerOptions struct {
@@ -26,46 +24,18 @@ func NewHandler(options HandlerOptions) http.Handler {
2624
handler := NewProxyHandler(options.targetUrl, options.badGatewayPage, options.forwardHeaders)
2725
handler = NewCacheHandler(options.cache, options.maxCacheableResponseBody, handler)
2826
handler = NewSendfileHandler(options.xSendfileEnabled, handler)
29-
handler = NewRequestStartMiddleware(handler)
27+
handler = NewRequestStartHandler(handler)
3028

3129
if options.gzipCompressionEnabled {
32-
var wrapper func(http.Handler) http.HandlerFunc
33-
var err error
34-
35-
if options.gzipCompressionJitter > 0 {
36-
wrapper, err = gzhttp.NewWrapper(
37-
gzhttp.MinSize(1024),
38-
gzhttp.CompressionLevel(6),
39-
gzhttp.RandomJitter(options.gzipCompressionJitter, 0, false),
40-
)
41-
} else {
42-
wrapper, err = gzhttp.NewWrapper(
43-
gzhttp.MinSize(1024),
44-
gzhttp.CompressionLevel(6),
45-
)
46-
}
47-
48-
if err != nil {
49-
// If we cannot create the wrapper with the requested configuration (including jitter),
50-
// we must fail hard rather than silently downgrading security or performance.
51-
panic("failed to create gzip wrapper: " + err.Error())
52-
}
53-
54-
gzipHandler := wrapper(handler)
55-
56-
if options.gzipCompressionDisableOnAuth {
57-
handler = NewCompressionGuardMiddleware(gzipHandler)
58-
} else {
59-
handler = gzipHandler
60-
}
30+
handler = NewCompressionHandler(options.gzipCompressionJitter, options.gzipCompressionDisableOnAuth, handler)
6131
}
6232

6333
if options.maxRequestBody > 0 {
6434
handler = http.MaxBytesHandler(handler, int64(options.maxRequestBody))
6535
}
6636

6737
if options.logRequests {
68-
handler = NewLoggingMiddleware(slog.Default(), handler)
38+
handler = NewLoggingHandler(slog.Default(), handler)
6939
}
7040

7141
return handler
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,19 @@ import (
99
"time"
1010
)
1111

12-
type LoggingMiddleware struct {
12+
type LoggingHandler struct {
1313
logger *slog.Logger
1414
next http.Handler
1515
}
1616

17-
func NewLoggingMiddleware(logger *slog.Logger, next http.Handler) *LoggingMiddleware {
18-
return &LoggingMiddleware{
17+
func NewLoggingHandler(logger *slog.Logger, next http.Handler) *LoggingHandler {
18+
return &LoggingHandler{
1919
logger: logger,
2020
next: next,
2121
}
2222
}
2323

24-
func (h *LoggingMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
24+
func (h *LoggingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
2525
writer := newResponseWriter(w)
2626

2727
started := time.Now()
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ import (
1414
"github.com/stretchr/testify/require"
1515
)
1616

17-
func TestMiddleware_LoggingMiddleware(t *testing.T) {
17+
func TestLoggingHandler(t *testing.T) {
1818
out := &strings.Builder{}
1919
logger := slog.New(slog.NewJSONHandler(out, nil))
20-
middleware := NewLoggingMiddleware(logger, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
20+
handler := NewLoggingHandler(logger, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2121
w.Header().Set("X-Cache", "miss")
2222
w.Header().Set("Content-Type", "text/html")
2323
w.WriteHeader(http.StatusCreated)
@@ -29,7 +29,7 @@ func TestMiddleware_LoggingMiddleware(t *testing.T) {
2929
req.Header.Set("User-Agent", "Robot/1")
3030
req.Header.Set("Content-Type", "application/json")
3131

32-
middleware.ServeHTTP(httptest.NewRecorder(), req)
32+
handler.ServeHTTP(httptest.NewRecorder(), req)
3333

3434
logline := struct {
3535
Path string `json:"path"`
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import (
66
"time"
77
)
88

9-
func NewRequestStartMiddleware(next http.Handler) http.Handler {
9+
func NewRequestStartHandler(next http.Handler) http.Handler {
1010
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1111
if r.Header.Get("X-Request-Start") == "" {
1212
timestamp := time.Now().UnixMilli()
Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,18 @@ import (
1010
"github.com/stretchr/testify/assert"
1111
)
1212

13-
func TestRequestStartMiddleware(t *testing.T) {
13+
func TestRequestStartHandler(t *testing.T) {
1414
var capturedHeader string
1515
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1616
capturedHeader = r.Header.Get("X-Request-Start")
1717
})
1818

19-
middleware := NewRequestStartMiddleware(nextHandler)
19+
handler := NewRequestStartHandler(nextHandler)
2020

2121
before := time.Now().UnixMilli()
2222
req := httptest.NewRequest("GET", "/", nil)
2323
w := httptest.NewRecorder()
24-
middleware.ServeHTTP(w, req)
24+
handler.ServeHTTP(w, req)
2525
after := time.Now().UnixMilli()
2626

2727
assert.NotEmpty(t, capturedHeader)
@@ -34,19 +34,19 @@ func TestRequestStartMiddleware(t *testing.T) {
3434
assert.LessOrEqual(t, timestamp, after)
3535
}
3636

37-
func TestRequestStartMiddlewareDoesNotOverwriteExistingHeader(t *testing.T) {
37+
func TestRequestStartHandlerDoesNotOverwriteExistingHeader(t *testing.T) {
3838
existingHeader := "t=1234567890"
3939
var capturedHeader string
4040
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
4141
capturedHeader = r.Header.Get("X-Request-Start")
4242
})
4343

44-
middleware := NewRequestStartMiddleware(nextHandler)
44+
handler := NewRequestStartHandler(nextHandler)
4545

4646
req := httptest.NewRequest("GET", "/", nil)
4747
req.Header.Set("X-Request-Start", existingHeader)
4848
w := httptest.NewRecorder()
49-
middleware.ServeHTTP(w, req)
49+
handler.ServeHTTP(w, req)
5050

5151
assert.Equal(t, existingHeader, capturedHeader)
5252
}

0 commit comments

Comments
 (0)