diff --git a/http/respChain.go b/http/respChain.go
--- a/http/respChain.go
+++ b/http/respChain.go
@@ -3,10 +3,12 @@ package httputil
 import (
 	"bytes"
 	"context"
+	"crypto/sha256"
+	"encoding/hex"
 	"fmt"
 	"net/http"
 	"sync"
 
 	"github.com/projectdiscovery/utils/sync/sizedpool"
 )
 
@@ -27,6 +29,141 @@ func GetPoolSize() int64 {
 // and reuse it for each request
 var bufPool *sizedpool.SizedPool[*bytes.Buffer]
 
+// CachedResponse is a snapshot of a response held by a ResponseCache.
+// Body and FullResponse are shared between all holders of the entry and
+// must be treated as read-only.
+type CachedResponse struct {
+	Body         []byte // copy of the body bytes
+	FullResponse []byte // copy of the full response bytes
+	RefCount     int32  // live references, guarded by the cache mutex
+}
+
+// cachedBody is a body shared by every cached response that carries it.
+type cachedBody struct {
+	data []byte
+	refs int32 // number of cached responses sharing this body
+}
+
+// ResponseCache deduplicates response data by content hash and keeps each
+// entry alive for as long as at least one reference to it is held.
+// The zero value is ready to use.
+type ResponseCache struct {
+	mu     sync.Mutex
+	full   map[string]*CachedResponse // keyed by full response hash
+	bodies map[string]*cachedBody     // keyed by body hash
+}
+
+var globalResponseCache = &ResponseCache{}
+
+// hashBytes returns the hex encoded SHA256 digest of data.
+func hashBytes(data []byte) string {
+	h := sha256.Sum256(data)
+	return hex.EncodeToString(h[:])
+}
+
+// cloneBytes returns a copy of data so the cache never aliases the caller's
+// buffer.
+func cloneBytes(data []byte) []byte {
+	out := make([]byte, len(data))
+	copy(out, data)
+	return out
+}
+
+// GetOrStore returns the entry stored under fullHash, creating it from
+// bodyData and fullData when it does not exist yet, and takes one reference
+// on it. The boolean reports whether the entry already existed. Every call
+// must be paired with a Release using the same hashes.
+func (rc *ResponseCache) GetOrStore(bodyHash, fullHash string, bodyData, fullData []byte) (*CachedResponse, bool) {
+	// Lookup and insert happen under one lock so that concurrent callers
+	// with the same hash cannot both store an entry and lose a reference.
+	rc.mu.Lock()
+	defer rc.mu.Unlock()
+
+	if cached, ok := rc.full[fullHash]; ok {
+		cached.RefCount++
+		return cached, true
+	}
+	if rc.full == nil {
+		rc.full = make(map[string]*CachedResponse)
+		rc.bodies = make(map[string]*cachedBody)
+	}
+
+	// Responses that differ only in their headers share a single body copy.
+	body, ok := rc.bodies[bodyHash]
+	if !ok {
+		body = &cachedBody{data: cloneBytes(bodyData)}
+		rc.bodies[bodyHash] = body
+	}
+	body.refs++
+
+	cr := &CachedResponse{
+		Body:         body.data,
+		FullResponse: cloneBytes(fullData),
+		RefCount:     1,
+	}
+	rc.full[fullHash] = cr
+	return cr, false
+}
+
+// Release drops one reference taken by GetOrStore. When the last reference
+// is gone the entry is evicted, together with its body if no other cached
+// response shares it. Releasing a hash that is not cached is a no-op.
+func (rc *ResponseCache) Release(fullHash, bodyHash string) {
+	rc.mu.Lock()
+	defer rc.mu.Unlock()
+
+	cached, ok := rc.full[fullHash]
+	if !ok {
+		return
+	}
+	cached.RefCount--
+	if cached.RefCount > 0 {
+		return
+	}
+	delete(rc.full, fullHash)
+
+	// The body outlives this entry while other full responses still use it.
+	if body, ok := rc.bodies[bodyHash]; ok {
+		body.refs--
+		if body.refs <= 0 {
+			delete(rc.bodies, bodyHash)
+		}
+	}
+}
+
+// GetCachedBody returns the body stored under a body hash, or the body of the
+// response stored under a full response hash, without taking a reference.
+// The returned slice is shared and must not be modified.
+// Used for runtime resolution of hashes in DSL evaluation
+func GetCachedBody(hash string) ([]byte, bool) {
+	rc := globalResponseCache
+	rc.mu.Lock()
+	defer rc.mu.Unlock()
+
+	if body, ok := rc.bodies[hash]; ok {
+		return body.data, true
+	}
+	if cached, ok := rc.full[hash]; ok {
+		return cached.Body, true
+	}
+	return nil, false
+}
+
+// GetCachedFullResponse returns the full response stored under a full
+// response hash without taking a reference. The returned slice is shared and
+// must not be modified.
+// Used for runtime resolution of hashes in DSL evaluation
+func GetCachedFullResponse(hash string) ([]byte, bool) {
+	rc := globalResponseCache
+	rc.mu.Lock()
+	defer rc.mu.Unlock()
+
+	if cached, ok := rc.full[hash]; ok {
+		return cached.FullResponse, true
+	}
+	return nil, false
+}
+
 func init() {
 	var p = &sync.Pool{
 		New: func() any {
@@ -75,6 +212,9 @@ type ResponseChain struct {
 	headers      *bytes.Buffer
 	body         *bytes.Buffer
 	fullResponse *bytes.Buffer
+	cachedResp   *CachedResponse // cache entry referenced by this chain, nil if none
+	bodyHash     string          // SHA256 hash of body
+	fullHash     string          // SHA256 hash of full response
 	resp         *http.Response
 	reloaded     bool // if response was reloaded to its previous redirect
 }
@@ -95,19 +235,48 @@ func NewResponseChain(resp *http.Response, maxBody int64) *ResponseChain {
 
-// Response returns the current response in the chain
+// Headers returns the current response headers in the chain
 func (r *ResponseChain) Headers() *bytes.Buffer {
 	return r.headers
 }
 
 // Body returns the current response body in the chain
 func (r *ResponseChain) Body() *bytes.Buffer {
 	return r.body
 }
 
 // FullResponse returns the current response in the chain
 func (r *ResponseChain) FullResponse() *bytes.Buffer {
 	return r.fullResponse
 }
 
+// BodyHash returns the SHA256 hash of the response body computed by the last
+// successful Fill, or an empty string when there is none.
+func (r *ResponseChain) BodyHash() string {
+	return r.bodyHash
+}
+
+// FullHash returns the SHA256 hash of the full response computed by the last
+// successful Fill, or an empty string when there is none.
+func (r *ResponseChain) FullHash() string {
+	return r.fullHash
+}
+
+// IsCached returns true if this chain currently holds a reference to an entry
+// in the shared response cache.
+func (r *ResponseChain) IsCached() bool {
+	return r.cachedResp != nil
+}
+
+// releaseCached drops the cache reference held by the chain, if any, and
+// clears the hashes that identified it. It is safe to call repeatedly.
+func (r *ResponseChain) releaseCached() {
+	if r.cachedResp != nil {
+		globalResponseCache.Release(r.fullHash, r.bodyHash)
+		r.cachedResp = nil
+	}
+	r.bodyHash = ""
+	r.fullHash = ""
+}
+
 // previous updates response pointer to previous response
 // if it was redirected and returns true else false
 func (r *ResponseChain) Previous() bool {
@@ -156,17 +325,40 @@ func (r *ResponseChain) Fill() error {
 	// join headers and body
 	r.fullResponse.Write(r.headers.Bytes())
 	r.fullResponse.Write(r.body.Bytes())
+
+	// Hash the normalized data and register it in the shared cache so it
+	// can be resolved by hash later. The chain keeps its own buffers:
+	// headers are not part of the cache entry, and reset() expects the
+	// buffers to still be there when the chain is filled again.
+	r.releaseCached()
+	bodyBytes := r.body.Bytes()
+	fullBytes := r.fullResponse.Bytes()
+	r.bodyHash = hashBytes(bodyBytes)
+	r.fullHash = hashBytes(fullBytes)
+	r.cachedResp, _ = globalResponseCache.GetOrStore(r.bodyHash, r.fullHash, bodyBytes, fullBytes)
+
 	return nil
 }
 
 // Close the response chain and releases the buffers.
 func (r *ResponseChain) Close() {
-	putBuffer(r.headers)
-	putBuffer(r.body)
-	putBuffer(r.fullResponse)
-	r.headers = nil
-	r.body = nil
-	r.fullResponse = nil
+	// Drop the cache reference before giving the buffers back
+	r.releaseCached()
+
+	// Guard against nil so that calling Close twice does not hand a nil
+	// buffer to putBuffer
+	if r.headers != nil {
+		putBuffer(r.headers)
+		r.headers = nil
+	}
+	if r.body != nil {
+		putBuffer(r.body)
+		r.body = nil
+	}
+	if r.fullResponse != nil {
+		putBuffer(r.fullResponse)
+		r.fullResponse = nil
+	}
 }
 
 // Has returns true if the response chain has a response
@@ -190,7 +382,9 @@ func (r *ResponseChain) Response() *http.Response {
 // reset without releasing the buffers
 // useful for redirect chain
 func (r *ResponseChain) reset() {
+	// The cached entry belongs to the response being replaced
+	r.releaseCached()
 	r.headers.Reset()
 	r.body.Reset()
 	r.fullResponse.Reset()
 }