From b715fdbb9b6509f7d06426d7f3a897ea65b59f00 Mon Sep 17 00:00:00 2001 From: Bobby Donchev Date: Sun, 25 May 2025 22:11:26 +0400 Subject: [PATCH 1/2] feat: add headers config --- config/config.go | 7 +- config/config_test.go | 125 +++++++++++++++++++++++++++++ config/helpers.go | 13 +-- pkg/http/cache.go | 129 +++++++++++++++++++----------- pkg/http/cache_test.go | 175 +++++++++++++++++++++++++++++++++++++++++ 5 files changed, 393 insertions(+), 56 deletions(-) diff --git a/config/config.go b/config/config.go index a58aa55..d465eea 100644 --- a/config/config.go +++ b/config/config.go @@ -37,7 +37,9 @@ type CacheConfig struct { IgnorePaths []string ShouldHashQuery bool HashQueryIgnore map[string]bool + HashHeaders []string } + type Config struct { ServerConfig ServerConfig CacheConfig CacheConfig @@ -48,7 +50,7 @@ func New() Config { serverConfig := ServerConfig{ Port: getEnv("SERVER_PORT", "8000"), GracePeriod: getEnvAsInt("SHUTDOWN_GRACE_PERIOD", "30"), - LogLevel: getEnvAsLogLevel("SERVER_LOG_LEVEL"), + LogLevel: getEnvAsLogLevel("LOG_LEVEL"), Storage: getEnv("SERVER_STORAGE", ""), } @@ -57,9 +59,10 @@ func New() Config { DownstreamHost: getEnvAsURL("DOWNSTREAM_HOST", ""), Size: getEnvAsFloat("CACHE_SIZE_MB", "10"), IgnorePaths: getEnvAsSlice("CACHE_IGNORE_ENDPOINTS"), - StaleInSeconds: getEnvAsInt("CACHE_STALE_WHILE_REVALIDATE_SEC", "5"), + StaleInSeconds: getEnvAsInt("CACHE_STALE_WHILE_REVALIDATE_SEC", "60"), ShouldHashQuery: getEnvAsBool("CACHE_SHOULD_HASH_QUERY", "true"), HashQueryIgnore: hashQueryIgnoreMap(getEnvAsSlice("CACHE_HASH_QUERY_IGNORE")), + HashHeaders: getEnvAsSlice("CACHE_HASH_HEADERS"), } if strings.ToLower(serverConfig.Storage) == "memory" { diff --git a/config/config_test.go b/config/config_test.go index c38c9aa..794bb18 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -9,6 +9,8 @@ import ( func TestConfig(t *testing.T) { t.Run("inMemory config", func(t *testing.T) { defer setupEnv(t, "SERVER_STORAGE", "memory")() + defer setupEnv(t, "CACHE_STRATEGY", "memory")() + defer setupEnv(t, "DOWNSTREAM_HOST", "http://localhost:8080")() config := New() assert.NotNil(t, config) @@ -17,9 +19,132 @@ func TestConfig(t *testing.T) { t.Run("redis config", func(t *testing.T) { defer setupEnv(t, "SERVER_STORAGE", "redis")() defer setupEnv(t, "REDIS_CONNECTION_STRING", "redis")() + defer setupEnv(t, "CACHE_STRATEGY", "redis")() + defer setupEnv(t, "DOWNSTREAM_HOST", "http://localhost:8080")() config := New() assert.NotNil(t, config) assert.Equal(t, "redis", config.RedisConfig.ConnectionString) }) + + t.Run("header configuration", func(t *testing.T) { + tests := []struct { + name string + hashHeaders string + hashHeadersIgnore string + wantHeaders []string + }{ + { + name: "empty headers", + hashHeaders: "", + hashHeadersIgnore: "", + wantHeaders: []string{}, + }, + { + name: "single header", + hashHeaders: "Authorization", + hashHeadersIgnore: "", + wantHeaders: []string{"Authorization"}, + }, + { + name: "multiple headers", + hashHeaders: "Authorization,X-User-ID,Accept", + hashHeadersIgnore: "", + wantHeaders: []string{"Authorization", "X-User-ID", "Accept"}, + }, + { + name: "headers with ignore", + hashHeaders: "Authorization,X-User-ID,Accept", + hashHeadersIgnore: "X-User-ID", + wantHeaders: []string{"Authorization", "X-User-ID", "Accept"}, + }, + { + name: "case insensitive headers", + hashHeaders: "AUTHORIZATION,x-user-id,accept", + hashHeadersIgnore: "X-USER-ID", + wantHeaders: []string{"AUTHORIZATION", "x-user-id", "accept"}, + }, + { + name: "whitespace in headers", + hashHeaders: " Authorization , X-User-ID , Accept ", + hashHeadersIgnore: " X-User-ID ", + wantHeaders: []string{"Authorization", "X-User-ID", "Accept"}, + }, + { + name: "duplicate headers", + hashHeaders: "Authorization,Authorization,X-User-ID", + hashHeadersIgnore: "", + wantHeaders: []string{"Authorization", "Authorization", "X-User-ID"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer setupEnv(t, "CACHE_HASH_HEADERS", tt.hashHeaders)() + defer setupEnv(t, "CACHE_HASH_HEADERS_IGNORE", tt.hashHeadersIgnore)() + + config := New() + + assert.Equal(t, tt.wantHeaders, config.CacheConfig.HashHeaders) + }) + } + }) + + t.Run("query configuration", func(t *testing.T) { + tests := []struct { + name string + shouldHashQuery string + queryIgnore string + wantHashQuery bool + wantIgnore map[string]bool + }{ + { + name: "default query hashing", + shouldHashQuery: "", + queryIgnore: "", + wantHashQuery: true, + wantIgnore: map[string]bool{}, + }, + { + name: "disabled query hashing", + shouldHashQuery: "false", + queryIgnore: "", + wantHashQuery: false, + wantIgnore: map[string]bool{}, + }, + { + name: "query ignore parameters", + shouldHashQuery: "true", + queryIgnore: "timestamp,request_id", + wantHashQuery: true, + wantIgnore: map[string]bool{"timestamp": true, "request_id": true}, + }, + { + name: "whitespace in query ignore", + shouldHashQuery: "true", + queryIgnore: " timestamp , request_id ", + wantHashQuery: true, + wantIgnore: map[string]bool{"timestamp": true, "request_id": true}, + }, + { + name: "case sensitive query ignore", + shouldHashQuery: "true", + queryIgnore: "Timestamp,Request_ID", + wantHashQuery: true, + wantIgnore: map[string]bool{"timestamp": true, "request_id": true}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer setupEnv(t, "CACHE_SHOULD_HASH_QUERY", tt.shouldHashQuery)() + defer setupEnv(t, "CACHE_HASH_QUERY_IGNORE", tt.queryIgnore)() + + config := New() + + assert.Equal(t, tt.wantHashQuery, config.CacheConfig.ShouldHashQuery) + assert.Equal(t, tt.wantIgnore, config.CacheConfig.HashQueryIgnore) + }) + } + }) } diff --git a/config/helpers.go b/config/helpers.go index b3be4ce..ee104fe 100644 --- a/config/helpers.go +++ b/config/helpers.go @@ -12,11 +12,9 @@ import ( func hashQueryIgnoreMap(queryIgnore []string) map[string]bool { hashQueryIgnoreMap := make(map[string]bool) - - for i := 0; i < len(queryIgnore); i++ { - hashQueryIgnoreMap[queryIgnore[i]] = true + for _, q := range queryIgnore { + hashQueryIgnoreMap[strings.ToLower(strings.TrimSpace(q))] = true } - return hashQueryIgnoreMap } @@ -55,8 +53,11 @@ func getEnvAsSlice(key string) []string { if strSlice == "" { return []string{} } - - return strings.Split(strSlice, ",") + parts := strings.Split(strSlice, ",") + for i, p := range parts { + parts[i] = strings.TrimSpace(p) + } + return parts } func getEnvAsInt(key, defaultVal string) int { diff --git a/pkg/http/cache.go b/pkg/http/cache.go index 570c887..d94b9dc 100644 --- a/pkg/http/cache.go +++ b/pkg/http/cache.go @@ -6,13 +6,15 @@ import ( "crypto/sha256" "fmt" "io" - "io/ioutil" "net" "net/http" "net/http/httputil" + "sort" "strings" "time" + "maps" + "github.com/rs/zerolog/log" "github.com/neurocode-io/cache-offloader/config" @@ -39,33 +41,68 @@ type ( worker Worker metricsCollector MetricsCollector cfg config.CacheConfig + httpClient *http.Client } ) func (h handler) getCacheKey(req *http.Request) string { cacheKey := sha256.New() + + // Include HTTP method in the hash + cacheKey.Write([]byte(req.Method)) + cacheKey.Write([]byte(":")) + + // Add the path cacheKey.Write([]byte(req.URL.Path)) - if !h.cfg.ShouldHashQuery { - return fmt.Sprintf("% x", cacheKey.Sum(nil)) + // Add query parameters if enabled + if h.cfg.ShouldHashQuery { + // Sort query parameters to ensure consistent ordering + query := req.URL.Query() + keys := make([]string, 0, len(query)) + for k := range query { + if _, ok := h.cfg.HashQueryIgnore[k]; !ok { + keys = append(keys, k) + } + } + sort.Strings(keys) + + for _, key := range keys { + values := query[key] + sort.Strings(values) + for _, value := range values { + cacheKey.Write([]byte("&")) + cacheKey.Write([]byte(key)) + cacheKey.Write([]byte("=")) + cacheKey.Write([]byte(value)) + } + } } - for key, values := range req.URL.Query() { - if _, ok := h.cfg.HashQueryIgnore[key]; ok { - continue - } - for _, value := range values { - cacheKey.Write([]byte(fmt.Sprintf("&%s=%s", key, value))) + // Add headers if configured + if len(h.cfg.HashHeaders) > 0 { + // Sort headers to ensure consistent ordering + sort.Strings(h.cfg.HashHeaders) + + for _, headerName := range h.cfg.HashHeaders { + values := req.Header.Values(headerName) + if len(values) > 0 { + sort.Strings(values) + for _, value := range values { + cacheKey.Write([]byte("|")) + cacheKey.Write([]byte(headerName)) + cacheKey.Write([]byte("=")) + cacheKey.Write([]byte(value)) + } + } } } - return fmt.Sprintf("% x", cacheKey.Sum(nil)) + return fmt.Sprintf("%x", cacheKey.Sum(nil)) } -func serveResponseFromMemory(res http.ResponseWriter, result *model.Response) { - for key, values := range result.Header { - res.Header()[key] = values - } +func serveResponseFromMemory(res http.ResponseWriter, result model.Response) { + maps.Copy(res.Header(), result.Header) res.WriteHeader(result.Status) _, err := res.Write(result.Body) @@ -81,11 +118,27 @@ func errHandler(res http.ResponseWriter, req *http.Request, err error) { } func newCacheHandler(c Cacher, m MetricsCollector, w Worker, cfg config.CacheConfig) handler { + netTransport := &http.Transport{ + MaxIdleConnsPerHost: 1000, + DisableKeepAlives: false, + IdleConnTimeout: time.Hour * 1, + Dial: (&net.Dialer{ + Timeout: 10 * time.Second, + KeepAlive: 30 * time.Second, + }).Dial, + TLSHandshakeTimeout: 10 * time.Second, + ResponseHeaderTimeout: 10 * time.Second, + } + return handler{ cacher: c, worker: w, metricsCollector: m, cfg: cfg, + httpClient: &http.Client{ + Timeout: time.Second * 10, + Transport: netTransport, + }, } } @@ -94,34 +147,16 @@ func (h handler) asyncCacheRevalidate(hashKey string, req *http.Request) func() ctx := context.Background() newReq := req.WithContext(ctx) - netTransport := &http.Transport{ - MaxIdleConnsPerHost: 1000, - DisableKeepAlives: false, - IdleConnTimeout: time.Hour * 1, - Dial: (&net.Dialer{ - Timeout: 10 * time.Second, - KeepAlive: 30 * time.Second, - }).Dial, - TLSHandshakeTimeout: 10 * time.Second, - ResponseHeaderTimeout: 10 * time.Second, - } - client := &http.Client{ - Timeout: time.Second * 10, - Transport: netTransport, - } - newReq.URL.Host = h.cfg.DownstreamHost.Host newReq.URL.Scheme = h.cfg.DownstreamHost.Scheme newReq.RequestURI = "" - resp, err := client.Do(newReq) - if resp != nil { - defer resp.Body.Close() - } + + resp, err := h.httpClient.Do(newReq) if err != nil { log.Ctx(ctx).Error().Err(err).Msg("Errored when sending request to the server") - return } + defer resp.Body.Close() if err := h.cacheResponse(ctx, hashKey)(resp); err != nil { log.Ctx(ctx).Error().Err(err).Msg("Errored when caching response") @@ -171,43 +206,41 @@ func (h handler) ServeHTTP(res http.ResponseWriter, req *http.Request) { if result.IsStale() { go h.worker.Start(hashKey, h.asyncCacheRevalidate(hashKey, req)) } - serveResponseFromMemory(res, result) + serveResponseFromMemory(res, *result) } func (h handler) cacheResponse(ctx context.Context, hashKey string) func(*http.Response) error { return func(response *http.Response) error { - // if this function returns an error, the proxy will return a 502 Bad Gateway error to the client - // please see the proxy.ModifyResponse documentation for more information logger := log.Ctx(ctx) - logger.Debug().Msg("got response from downstream service") h.metricsCollector.CacheMiss(response.Request.Method, response.StatusCode) if response.StatusCode >= http.StatusInternalServerError { logger.Warn().Msg("won't cache 5XX downstream responses") - return nil } body, readErr := io.ReadAll(response.Body) - if readErr != nil { logger.Error().Err(readErr).Msg("error occurred reading response body") - return nil } - header := response.Header - statusCode := response.StatusCode - newBody := ioutil.NopCloser(bytes.NewReader(body)) + // Create a new reader for the response body + response.Body = io.NopCloser(bytes.NewReader(body)) - response.Body = newBody + // Create a copy of the header to prevent modification of the original + headerCopy := make(http.Header) + maps.Copy(headerCopy, response.Header) - entry := model.Response{Body: body, Header: header, Status: statusCode} + entry := model.Response{ + Body: body, + Header: headerCopy, + Status: response.StatusCode, + } if err := h.cacher.Store(ctx, hashKey, &entry); err != nil { logger.Error().Err(err).Msg("error occurred storing response in memory") - return nil } diff --git a/pkg/http/cache_test.go b/pkg/http/cache_test.go index aa342ad..ef7f5dc 100644 --- a/pkg/http/cache_test.go +++ b/pkg/http/cache_test.go @@ -1,6 +1,7 @@ package http import ( + "crypto/sha256" "errors" "fmt" "net/http" @@ -323,3 +324,177 @@ func TestCacheHandler(t *testing.T) { }) } } + +func TestGetCacheKey(t *testing.T) { + tests := []struct { + name string + path string + method string + query string + headers map[string]string + shouldHashQuery bool + ignoreParams []string + hashHeaders []string + ignoreHeaders []string + want string + }{ + { + name: "simple path without query", + path: "/api/users", + method: "GET", + query: "", + shouldHashQuery: true, + want: "GET:/api/users", + }, + { + name: "path with query parameters", + path: "/api/users", + method: "GET", + query: "name=john&age=30", + shouldHashQuery: true, + want: "GET:/api/users&age=30&name=john", + }, + { + name: "query parameters in different order", + path: "/api/users", + method: "GET", + query: "age=30&name=john", + shouldHashQuery: true, + want: "GET:/api/users&age=30&name=john", + }, + { + name: "multiple values for same parameter", + path: "/api/users", + method: "GET", + query: "role=admin&role=user", + shouldHashQuery: true, + want: "GET:/api/users&role=admin&role=user", + }, + { + name: "different HTTP method", + path: "/api/users", + method: "POST", + query: "name=john", + shouldHashQuery: true, + want: "POST:/api/users&name=john", + }, + { + name: "query parameters disabled", + path: "/api/users", + method: "GET", + query: "name=john&age=30", + shouldHashQuery: false, + want: "GET:/api/users", + }, + { + name: "ignored query parameters", + path: "/api/users", + method: "GET", + query: "name=john&age=30×tamp=123", + shouldHashQuery: true, + ignoreParams: []string{"timestamp"}, + want: "GET:/api/users&age=30&name=john", + }, + { + name: "special characters in path and query", + path: "/api/users/123/profile", + method: "GET", + query: "filter=active&sort=name", + shouldHashQuery: true, + want: "GET:/api/users/123/profile&filter=active&sort=name", + }, + { + name: "with authorization header", + path: "/api/users", + method: "GET", + query: "name=john", + headers: map[string]string{ + "Authorization": "Bearer token123", + }, + shouldHashQuery: true, + hashHeaders: []string{"Authorization"}, + want: "GET:/api/users&name=john|Authorization=Bearer token123", + }, + { + name: "multiple headers", + path: "/api/users", + method: "GET", + headers: map[string]string{ + "Authorization": "Bearer token123", + "X-User-ID": "user456", + "Accept": "application/json", + }, + shouldHashQuery: true, + hashHeaders: []string{"Authorization", "X-User-ID", "Accept"}, + want: "GET:/api/users|Accept=application/json|Authorization=Bearer token123|X-User-ID=user456", + }, + { + name: "ignored headers", + path: "/api/users", + method: "GET", + headers: map[string]string{ + "Authorization": "Bearer token123", + "X-User-ID": "user456", + }, + shouldHashQuery: true, + hashHeaders: []string{"Authorization", "X-User-ID"}, + ignoreHeaders: []string{"X-User-ID"}, + want: "GET:/api/users|Authorization=Bearer token123", + }, + { + name: "multiple values for same header", + path: "/api/users", + method: "GET", + headers: map[string]string{ + "Accept": "application/json,text/plain", + }, + shouldHashQuery: true, + hashHeaders: []string{"Accept"}, + want: "GET:/api/users|Accept=application/json,text/plain", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create request + req, err := http.NewRequest(tt.method, tt.path, nil) + if err != nil { + t.Fatal(err) + } + + // Add query parameters if any + if tt.query != "" { + req.URL.RawQuery = tt.query + } + + // Add headers + for key, value := range tt.headers { + req.Header.Set(key, value) + } + + // Create handler with config + h := handler{ + cfg: config.CacheConfig{ + ShouldHashQuery: tt.shouldHashQuery, + HashQueryIgnore: make(map[string]bool), + HashHeaders: tt.hashHeaders, + }, + } + + // Add ignored parameters + for _, param := range tt.ignoreParams { + h.cfg.HashQueryIgnore[param] = true + } + + // Get cache key + got := h.getCacheKey(req) + + // Calculate expected hash + expectedHash := sha256.New() + expectedHash.Write([]byte(tt.want)) + expected := fmt.Sprintf("%x", expectedHash.Sum(nil)) + + assert.Equal(t, expected, got, "cache key mismatch") + }) + } +} From fab7f0d12774e4e25aee145622901c4d4ad000d7 Mon Sep 17 00:00:00 2001 From: Bobby Donchev Date: Sun, 25 May 2025 22:18:17 +0400 Subject: [PATCH 2/2] fix: cache test --- pkg/http/cache_test.go | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/pkg/http/cache_test.go b/pkg/http/cache_test.go index ef7f5dc..f2de478 100644 --- a/pkg/http/cache_test.go +++ b/pkg/http/cache_test.go @@ -335,7 +335,6 @@ func TestGetCacheKey(t *testing.T) { shouldHashQuery bool ignoreParams []string hashHeaders []string - ignoreHeaders []string want string }{ { @@ -428,19 +427,6 @@ func TestGetCacheKey(t *testing.T) { hashHeaders: []string{"Authorization", "X-User-ID", "Accept"}, want: "GET:/api/users|Accept=application/json|Authorization=Bearer token123|X-User-ID=user456", }, - { - name: "ignored headers", - path: "/api/users", - method: "GET", - headers: map[string]string{ - "Authorization": "Bearer token123", - "X-User-ID": "user456", - }, - shouldHashQuery: true, - hashHeaders: []string{"Authorization", "X-User-ID"}, - ignoreHeaders: []string{"X-User-ID"}, - want: "GET:/api/users|Authorization=Bearer token123", - }, { name: "multiple values for same header", path: "/api/users",