diff --git a/README.md b/README.md
index bb65197..fe26527 100644
--- a/README.md
+++ b/README.md
@@ -49,6 +49,48 @@ func main() {
 }
 ```
 
+### Compress only when response meets minimum byte size
+
+```go
+package main
+
+import (
+  "log"
+  "net/http"
+  "strconv"
+  "strings"
+
+  "github.com/gin-contrib/gzip"
+  "github.com/gin-gonic/gin"
+)
+
+func main() {
+  r := gin.Default()
+  r.Use(gzip.Gzip(gzip.DefaultCompression, gzip.WithMinLength(2048)))
+  r.GET("/ping", func(c *gin.Context) {
+    sizeStr := c.Query("size")
+    size, _ := strconv.Atoi(sizeStr)
+    c.String(http.StatusOK, strings.Repeat("a", size))
+  })
+
+  // Listen and serve on 0.0.0.0:8080
+  if err := r.Run(":8080"); err != nil {
+    log.Fatal(err)
+  }
+}
+```
+Test with curl (the first response is below the min length and is sent uncompressed; the second meets it and is gzipped):
+```bash
+curl -i --compressed 'http://localhost:8080/ping?size=2047'
+curl -i --compressed 'http://localhost:8080/ping?size=2048'
+```
+
+Notes:
+- If a "Content-Length" header is set, its value is compared against the given min length to decide whether to compress.
+- If no "Content-Length" header is set, a buffer is used to temporarily store writes until the min length is met or the request completes.
+  - Setting a high min length will result in more buffering (2048 bytes is a recommended default for most cases).
+  - The handler performs optimizations to avoid unnecessary operations, such as testing whether `len(data)` already meets the min length before writing to the buffer, and reusing buffers between requests.
+
 ### Customized Excluded Extensions
 
 ```go
diff --git a/gzip.go b/gzip.go
index 931945a..4a188ae 100644
--- a/gzip.go
+++ b/gzip.go
@@ -2,10 +2,12 @@ package gzip
 
 import (
 	"bufio"
+	"bytes"
 	"compress/gzip"
 	"errors"
 	"net"
 	"net/http"
+	"strconv"
 
 	"github.com/gin-gonic/gin"
 )
@@ -25,15 +27,62 @@ func Gzip(level int, options ...Option) gin.HandlerFunc {
 type gzipWriter struct {
 	gin.ResponseWriter
 	writer *gzip.Writer
+	// minLength is the minimum length of the response body (in bytes) to enable compression
+	minLength int
+	// shouldCompress indicates whether the minimum length for compression has been met
+	shouldCompress bool
+	// buffer to store response data in case compression limit not met
+	buffer bytes.Buffer
 }
 
 func (g *gzipWriter) WriteString(s string) (int, error) {
-	g.Header().Del("Content-Length")
-	return g.writer.Write([]byte(s))
+	return g.Write([]byte(s))
 }
 
+// Write routes data to the gzip writer, the internal buffer, or the underlying ResponseWriter,
+// depending on whether the minimum length for compression has been met.
+// It may be called multiple times within a single request. On success it reports len(data),
+// as required by io.Writer, even when the call also flushes previously buffered bytes.
 func (g *gzipWriter) Write(data []byte) (int, error) {
-	g.Header().Del("Content-Length")
+	// If a Content-Length header is set, use that to decide whether to compress the response.
+	if g.Header().Get("Content-Length") != "" {
+		contentLen, _ := strconv.Atoi(g.Header().Get("Content-Length")) // err intentionally ignored for invalid headers
+		if contentLen < g.minLength {
+			// The body is sent uncompressed, so drop the gzip headers before the write below
+			// sends the response headers.
+			g.Header().Del(headerContentEncoding)
+			g.Header().Del(headerVary)
+			return g.ResponseWriter.Write(data)
+		}
+		g.shouldCompress = true
+		g.Header().Del("Content-Length")
+	}
+
+	// Until the compression decision has been made, data is held in the buffer. (If the response
+	// ends while still below the minimum length, the caller should check shouldCompress and write
+	// the buffered data to the response itself.)
+	if !g.shouldCompress {
+		// Fast path: nothing is buffered and this single write already meets the minimum length.
+		// The buffer must be empty here, otherwise the earlier buffered writes would be dropped.
+		if g.buffer.Len() == 0 && len(data) >= g.minLength {
+			g.shouldCompress = true
+			return g.writer.Write(data)
+		}
+
+		lenWritten, err := g.buffer.Write(data)
+		if err != nil || g.buffer.Len() < g.minLength {
+			return lenWritten, err
+		}
+
+		// The buffer (which now includes data) has reached the minimum length: compress all of it.
+		g.shouldCompress = true
+		if _, err = g.writer.Write(g.buffer.Bytes()); err != nil {
+			return 0, err
+		}
+		g.buffer.Reset()
+		return len(data), nil
+	}
+
 	return g.writer.Write(data)
 }
 
@@ -42,12 +91,6 @@ func (g *gzipWriter) Flush() {
 	g.ResponseWriter.Flush()
 }
 
-// Fix: https://github.com/mholt/caddy/issues/38
-func (g *gzipWriter) WriteHeader(code int) {
-	g.Header().Del("Content-Length")
-	g.ResponseWriter.WriteHeader(code)
-}
-
 // Ensure gzipWriter implements the http.Hijacker interface.
 // This will cause a compile-time error if gzipWriter does not implement all methods of the http.Hijacker interface.
 var _ http.Hijacker = (*gzipWriter)(nil)
diff --git a/gzip_test.go b/gzip_test.go
index be64040..f3af03c 100644
--- a/gzip_test.go
+++ b/gzip_test.go
@@ -13,6 +13,7 @@ import (
 	"net/http/httputil"
 	"net/url"
 	"strconv"
+	"strings"
 	"testing"
 
 	"github.com/gin-gonic/gin"
@@ -136,6 +137,33 @@ func TestGzipPNG(t *testing.T) {
 	assert.Equal(t, w.Body.String(), "this is a PNG!")
 }
 
+func TestWriteString(t *testing.T) {
+	testC, _ := gin.CreateTestContext(httptest.NewRecorder())
+	gz := gzipWriter{
+		ResponseWriter: testC.Writer,
+		writer:         gzip.NewWriter(testC.Writer),
+	}
+	n, err := gz.WriteString("test")
+	assert.NoError(t, err)
+	assert.Equal(t, 4, n)
+}
+
+func TestWriteReportsOnlyCallerBytes(t *testing.T) {
+	testC, _ := gin.CreateTestContext(httptest.NewRecorder())
+	gz := gzipWriter{
+		ResponseWriter: testC.Writer,
+		writer:         gzip.NewWriter(testC.Writer),
+		minLength:      8,
+	}
+	n, err := gz.Write([]byte("abcd")) // below minLength: buffered
+	assert.NoError(t, err)
+	assert.Equal(t, 4, n)
+	n, err = gz.Write([]byte("efgh")) // reaches minLength: the buffer is flushed to gzip
+	assert.NoError(t, err)
+	assert.Equal(t, 4, n) // not 8: io.Writer forbids n > len(data)
+	assert.True(t, gz.shouldCompress)
+}
+
 func TestExcludedPathsAndExtensions(t *testing.T) {
 	tests := []struct {
 		path string
@@ -377,6 +405,138 @@ func TestCustomShouldCompressFn(t *testing.T) {
 	assert.Equal(t, testResponse, w.Body.String())
 }
 
+func TestMinLengthShortResponse(t *testing.T) {
+	req, _ := http.NewRequestWithContext(context.Background(), "GET", "/", nil)
+	req.Header.Add(headerAcceptEncoding, "gzip")
+
+	router := gin.New()
+	router.Use(Gzip(DefaultCompression, WithMinLength(2048)))
+	router.GET("/", func(c *gin.Context) {
+		c.String(200, testResponse)
+	})
+
+	w := httptest.NewRecorder()
+	router.ServeHTTP(w, req)
+
+	assert.Equal(t, 200, w.Code)
+	assert.Equal(t, "", w.Header().Get(headerContentEncoding)) // below min length: must not be gzipped
+	assert.Equal(t, "19", w.Header().Get("Content-Length"))
+	assert.Equal(t, testResponse, w.Body.String()) // body is passed through unmodified
+}
+
+func TestMinLengthLongResponse(t *testing.T) {
+	req, _ := http.NewRequestWithContext(context.Background(), "GET", "/", nil)
+	req.Header.Add(headerAcceptEncoding, "gzip")
+
+	router := gin.New()
+	router.Use(Gzip(DefaultCompression, WithMinLength(2048)))
+	router.GET("/", func(c *gin.Context) {
+		c.String(200, strings.Repeat("a", 2048)) // exactly the min length
+	})
+
+	w := httptest.NewRecorder()
+	router.ServeHTTP(w, req)
+
+	assert.Equal(t, 200, w.Code)
+	assert.Equal(t, "gzip", w.Header().Get(headerContentEncoding))
+	assert.NotEqual(t, "2048", w.Header().Get("Content-Length"))
+	assert.Less(t, w.Body.Len(), 2048) // recorded body is smaller than the 2048-byte payload
+}
+
+func TestMinLengthMultiWriteResponse(t *testing.T) {
+	req, _ := http.NewRequestWithContext(context.Background(), "GET", "/", nil)
+	req.Header.Add(headerAcceptEncoding, "gzip")
+
+	router := gin.New()
+	router.Use(Gzip(DefaultCompression, WithMinLength(2048)))
+	router.GET("/", func(c *gin.Context) {
+		c.String(200, strings.Repeat("a", 1024)) // first write alone is below the min length
+		c.String(200, strings.Repeat("b", 1024)) // both writes together reach the min length
+	})
+
+	w := httptest.NewRecorder()
+	router.ServeHTTP(w, req)
+
+	assert.Equal(t, 200, w.Code)
+	assert.Equal(t, "gzip", w.Header().Get(headerContentEncoding))
+	assert.NotEqual(t, "2048", w.Header().Get("Content-Length"))
+	assert.Less(t, w.Body.Len(), 2048)
+}
+
+// Note this test intentionally triggers gzipping even when the actual response doesn't meet min length. This is because
+// we use the Content-Length header as the primary determinant of compression to avoid the cost of buffering.
+func TestMinLengthUsesContentLengthHeaderInsteadOfBuffering(t *testing.T) {
+	req, _ := http.NewRequestWithContext(context.Background(), "GET", "/", nil)
+	req.Header.Add(headerAcceptEncoding, "gzip")
+
+	router := gin.New()
+	router.Use(Gzip(DefaultCompression, WithMinLength(2048)))
+	router.GET("/", func(c *gin.Context) {
+		c.Header("Content-Length", "2048") // declared length meets the min length; the real body does not
+		c.String(200, testResponse)
+	})
+
+	w := httptest.NewRecorder()
+	router.ServeHTTP(w, req)
+
+	assert.Equal(t, 200, w.Code)
+	assert.Equal(t, "gzip", w.Header().Get(headerContentEncoding))
+	assert.NotEmpty(t, w.Header().Get("Content-Length"))
+	assert.NotEqual(t, "19", w.Header().Get("Content-Length")) // not the uncompressed length
+}
+
+// Note this test intentionally does not trigger gzipping even when the actual response meets min length. This is
+// because we use the Content-Length header as the primary determinant of compression to avoid the cost of buffering.
+func TestMinLengthMultiWriteResponseUsesContentLengthHeaderInsteadOfBuffering(t *testing.T) {
+	req, _ := http.NewRequestWithContext(context.Background(), "GET", "/", nil)
+	req.Header.Add(headerAcceptEncoding, "gzip")
+
+	router := gin.New()
+	router.Use(Gzip(DefaultCompression, WithMinLength(1024)))
+	router.GET("/", func(c *gin.Context) {
+		c.Header("Content-Length", "999") // declared length is below the min length of 1024
+		c.String(200, strings.Repeat("a", 1024))
+		c.String(200, strings.Repeat("b", 1024))
+	})
+
+	w := httptest.NewRecorder()
+	router.ServeHTTP(w, req)
+
+	assert.Equal(t, 200, w.Code)
+	assert.NotEqual(t, "gzip", w.Header().Get(headerContentEncoding)) // no gzip due to Content-Length header
+	assert.Equal(t, "2048", w.Header().Get("Content-Length")) // actual number of bytes written, not the declared 999
+}
+
+func TestMinLengthWithInvalidContentLengthHeader(t *testing.T) {
+	req, _ := http.NewRequestWithContext(context.Background(), "GET", "/", nil)
+	req.Header.Add(headerAcceptEncoding, "gzip")
+
+	router := gin.New()
+	router.Use(Gzip(DefaultCompression, WithMinLength(2048)))
+	router.GET("/", func(c *gin.Context) {
+		c.Header("Content-Length", "xyz") // not a number: the writer parses it as 0, which is below the min length
+		c.String(200, testResponse)
+	})
+
+	w := httptest.NewRecorder()
+	router.ServeHTTP(w, req)
+
+	assert.Equal(t, 200, w.Code)
+	assert.Equal(t, "", w.Header().Get(headerContentEncoding))
+	assert.Equal(t, "19", w.Header().Get("Content-Length"))
+}
+
+func TestFlush(t *testing.T) {
+	testC, _ := gin.CreateTestContext(httptest.NewRecorder())
+	gz := gzipWriter{
+		ResponseWriter: testC.Writer,
+		writer:         gzip.NewWriter(testC.Writer),
+	}
+	_, _ = gz.WriteString("test")
+	gz.Flush()
+	assert.True(t, gz.Written())
+}
+
 type hijackableResponse struct {
 	Hijacked bool
 	header   http.Header
diff --git a/handler.go b/handler.go
index 412c838..e926c03 100644
--- a/handler.go
+++ b/handler.go
@@ -84,13 +84,31 @@ func (g *gzipHandler) Handle(c *gin.Context) {
 	if originalEtag != "" && !strings.HasPrefix(originalEtag, "W/") {
 		c.Header("ETag", "W/"+originalEtag)
 	}
-	c.Writer = &gzipWriter{c.Writer, gz}
+	gzWriter := gzipWriter{
+		ResponseWriter: c.Writer,
+		writer:         gz,
+		minLength:      g.minLength,
+	}
+	c.Writer = &gzWriter
 	defer func() {
+		// if compression limit not met after all write commands were executed, then the response data is stored in the
+		// internal buffer which should now be written to the response writer directly
+		if !gzWriter.shouldCompress {
+			c.Writer.Header().Del(headerContentEncoding)
+			c.Writer.Header().Del(headerVary)
+			// only write when something was buffered, so that a response without a body is left untouched
+			// (the check below relies on Size() staying negative when nothing was written)
+			if gzWriter.buffer.Len() > 0 {
+				_, _ = gzWriter.ResponseWriter.Write(gzWriter.buffer.Bytes())
+			}
+			gzWriter.writer.Reset(io.Discard)
+		}
+
 		if c.Writer.Size() < 0 {
 			// do not write gzip footer when nothing is written to the response body
-			gz.Reset(io.Discard)
+			gzWriter.writer.Reset(io.Discard)
 		}
-		_ = gz.Close()
+		_ = gzWriter.writer.Close()
 		if c.Writer.Size() > -1 {
 			c.Header("Content-Length", strconv.Itoa(c.Writer.Size()))
 		}
diff --git a/options.go b/options.go
index 67607f5..97a9e32 100644
--- a/options.go
+++ b/options.go
@@ -47,6 +47,7 @@ type config struct {
 	decompressFn           func(c *gin.Context)
 	decompressOnly         bool
 	customShouldCompressFn func(c *gin.Context) bool
+	minLength              int
} // WithExcludedExtensions returns an Option that sets the ExcludedExtensions field of the Options struct. @@ -117,6 +118,29 @@ func WithCustomShouldCompressFn(fn func(c *gin.Context) bool) Option { }) } +// WithMinLength returns an Option that sets the minLength field of the Options struct. +// Parameters: +// - minLength: int - The minimum length of the response body (in bytes) to trigger gzip compression. +// If the response body is smaller than this length, it will not be compressed. +// This option is useful for avoiding the overhead of compression on small responses, especially since gzip +// compression actually increases the size of small responses. 2048 is a recommended value for most cases. +// The minLength value must be non-negative; negative values will cause undefined behavior. +// +// Note that specifying this option does not override other options. If a path has been excluded (eg through +// WithExcludedPaths), it will continue to be excluded. +// +// Returns: +// - Option - An option that sets the MinLength field of the Options struct. +// +// Example: +// +// router.Use(gzip.Gzip(gzip.DefaultCompression, gzip.WithMinLength(2048))) +func WithMinLength(minLength int) Option { + return optionFunc(func(o *config) { + o.minLength = minLength + }) +} + // Using map for better lookup performance type ExcludedExtensions map[string]struct{}