diff --git a/pkg/util/resource/tracker.go b/pkg/util/resource/tracker.go new file mode 100644 index 00000000000..805ae137353 --- /dev/null +++ b/pkg/util/resource/tracker.go @@ -0,0 +1,163 @@ +package resource + +import ( + "context" + "sync" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + + "github.com/cortexproject/cortex/pkg/util/services" +) + +type memoryBuckets struct { + buckets []uint64 + lastUpdate time.Time + currentIdx int +} + +type ResourceTracker struct { + services.Service + + memoryData map[string]*memoryBuckets + windowSize int + maxActiveRequests int + + mu sync.RWMutex +} + +type IResourceTracker interface { + AddBytes(requestID string, bytes uint64) + GetHeaviestQuery() (requestID string, bytes uint64) +} + +func NewResourceTracker(windowSize, maxActiveRequests int, registerer prometheus.Registerer) *ResourceTracker { + rt := &ResourceTracker{ + memoryData: make(map[string]*memoryBuckets), + windowSize: windowSize, + maxActiveRequests: maxActiveRequests, + } + + promauto.With(registerer).NewGaugeFunc(prometheus.GaugeOpts{ + Name: "cortex_resource_tracker_active_requests", + }, rt.activeRequestCount) + + rt.Service = services.NewBasicService(nil, rt.running, nil) + return rt +} + +func (rt *ResourceTracker) AddBytes(requestID string, bytes uint64) { + rt.mu.Lock() + defer rt.mu.Unlock() + + now := time.Now().Truncate(time.Second) + + buckets, exists := rt.memoryData[requestID] + if !exists { + // Check if we're at capacity + if len(rt.memoryData) >= rt.maxActiveRequests { + // Evict oldest request + rt.evictOldest() + } + + buckets = &memoryBuckets{ + buckets: make([]uint64, rt.windowSize), + lastUpdate: now, + currentIdx: 0, + } + rt.memoryData[requestID] = buckets + } + + // Calculate seconds drift and rotate buckets if needed + secondsDrift := int(now.Sub(buckets.lastUpdate).Seconds()) + if secondsDrift > 0 { + // Clear old buckets + for i := 0; i < min(secondsDrift, rt.windowSize); i++ { + nextIdx := (buckets.currentIdx + 1 + i) % rt.windowSize + buckets.buckets[nextIdx] = 0 + } + // Update current index + buckets.currentIdx = (buckets.currentIdx + secondsDrift) % rt.windowSize + buckets.lastUpdate = now + } + + // Add bytes to current bucket + buckets.buckets[buckets.currentIdx] += bytes +} + +func (rt *ResourceTracker) GetHeaviestQuery() (string, uint64) { + rt.mu.RLock() + defer rt.mu.RUnlock() + + var maxID string + var maxBytes uint64 + + for id, buckets := range rt.memoryData { + // Sum all buckets + var totalBytes uint64 + for _, bytes := range buckets.buckets { + totalBytes += bytes + } + if totalBytes > maxBytes { + maxBytes = totalBytes + maxID = id + } + } + + return maxID, maxBytes +} + +func (rt *ResourceTracker) running(ctx context.Context) error { + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return nil + case <-ticker.C: + rt.cleanup() + } + } +} + +func (rt *ResourceTracker) activeRequestCount() float64 { + rt.mu.RLock() + defer rt.mu.RUnlock() + + return float64(len(rt.memoryData)) +} + +func (rt *ResourceTracker) cleanup() { + rt.mu.Lock() + defer rt.mu.Unlock() + + now := time.Now() + cutoff := now.Add(-time.Duration(rt.windowSize) * time.Second) + + // Remove stale requestIDs + for requestID, buckets := range rt.memoryData { + if buckets.lastUpdate.Before(cutoff) { + delete(rt.memoryData, requestID) + } + } +} + +func (rt *ResourceTracker) evictOldest() { + var oldestID string + var oldestTime time.Time + + // Find oldest request + for requestID, buckets := range rt.memoryData { + if oldestID == "" || buckets.lastUpdate.Before(oldestTime) { + oldestID = requestID + oldestTime = buckets.lastUpdate + } + } + + // Remove oldest request + if oldestID != "" { + delete(rt.memoryData, oldestID) + } +} diff --git a/pkg/util/resource/tracker_test.go b/pkg/util/resource/tracker_test.go new file mode 100644 index 00000000000..6f055211bac --- /dev/null +++ b/pkg/util/resource/tracker_test.go @@ -0,0 +1,224 @@ +package resource + +import ( + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestResourceTracker_AddBytes(t *testing.T) { + rt := &ResourceTracker{ + memoryData: make(map[string]*memoryBuckets), + windowSize: 3, + maxActiveRequests: 10, + } + + rt.AddBytes("req1", 1000) + + assert.Len(t, rt.memoryData, 1) + assert.Contains(t, rt.memoryData, "req1") + assert.Equal(t, uint64(1000), rt.memoryData["req1"].buckets[0]) +} + +func TestResourceTracker_GetHeaviestQuery(t *testing.T) { + rt := &ResourceTracker{ + memoryData: make(map[string]*memoryBuckets), + windowSize: 3, + maxActiveRequests: 10, + } + + rt.AddBytes("req1", 1000) + rt.AddBytes("req2", 2000) + rt.AddBytes("req3", 500) + + requestID, bytes := rt.GetHeaviestQuery() + assert.Equal(t, "req2", requestID) + assert.Equal(t, uint64(2000), bytes) +} + +func TestResourceTracker_EmptyTracker(t *testing.T) { + rt := &ResourceTracker{ + memoryData: make(map[string]*memoryBuckets), + windowSize: 3, + maxActiveRequests: 10, + } + + requestID, bytes := rt.GetHeaviestQuery() + assert.Equal(t, "", requestID) + assert.Equal(t, uint64(0), bytes) +} + +func TestResourceTracker_SlidingWindow(t *testing.T) { + rt := &ResourceTracker{ + memoryData: make(map[string]*memoryBuckets), + windowSize: 3, + maxActiveRequests: 10, + } + + // Add bytes at different times + rt.AddBytes("req1", 1000) + + // Simulate 1 second later + rt.mu.Lock() + rt.memoryData["req1"].lastUpdate = rt.memoryData["req1"].lastUpdate.Add(-1 * time.Second) + rt.mu.Unlock() + + rt.AddBytes("req1", 2000) + + // Should have both values in different buckets + _, bytes := rt.GetHeaviestQuery() + assert.Equal(t, uint64(3000), bytes) // 1000 + 2000 +} + +func TestResourceTracker_BucketRotation(t *testing.T) { + rt := &ResourceTracker{ + memoryData: make(map[string]*memoryBuckets), + windowSize: 3, + maxActiveRequests: 10, + } + + rt.AddBytes("req1", 1000) + + // Simulate 4 seconds later (should clear old buckets) + rt.mu.Lock() + rt.memoryData["req1"].lastUpdate = rt.memoryData["req1"].lastUpdate.Add(-4 * time.Second) + rt.mu.Unlock() + + rt.AddBytes("req1", 2000) + + // Should only have the new value (old bucket cleared) + _, bytes := rt.GetHeaviestQuery() + assert.Equal(t, uint64(2000), bytes) +} + +func TestResourceTracker_Cleanup(t *testing.T) { + rt := &ResourceTracker{ + memoryData: make(map[string]*memoryBuckets), + windowSize: 3, + maxActiveRequests: 10, + } + + rt.AddBytes("req1", 1000) + rt.AddBytes("req2", 2000) + + // Simulate old lastUpdate time + rt.mu.Lock() + rt.memoryData["req1"].lastUpdate = time.Now().Add(-5 * time.Second) + rt.mu.Unlock() + + rt.cleanup() + + assert.Len(t, rt.memoryData, 1) + assert.Contains(t, rt.memoryData, "req2") + assert.NotContains(t, rt.memoryData, "req1") +} + +func TestResourceTracker_MaxActiveRequests(t *testing.T) { + rt := &ResourceTracker{ + memoryData: make(map[string]*memoryBuckets), + windowSize: 3, + maxActiveRequests: 100, + } + + // Manually set to limit for faster test + rt.mu.Lock() + for i := 0; i < rt.maxActiveRequests; i++ { + rt.memoryData[fmt.Sprintf("req%d", i)] = &memoryBuckets{ + buckets: make([]uint64, rt.windowSize), + lastUpdate: time.Now(), + } + } + rt.mu.Unlock() + + // Add one more request (should trigger eviction) + rt.AddBytes("new_req", 9999) + + assert.Len(t, rt.memoryData, rt.maxActiveRequests) + assert.Contains(t, rt.memoryData, "new_req") +} + +func TestResourceTracker_EvictOldest(t *testing.T) { + rt := &ResourceTracker{ + memoryData: make(map[string]*memoryBuckets), + windowSize: 3, + maxActiveRequests: 10, + } + + now := time.Now() + + // Add requests with different timestamps + rt.memoryData["req1"] = &memoryBuckets{ + buckets: make([]uint64, rt.windowSize), + lastUpdate: now.Add(-10 * time.Second), // Oldest + } + + rt.memoryData["req2"] = &memoryBuckets{ + buckets: make([]uint64, rt.windowSize), + lastUpdate: now.Add(-5 * time.Second), + } + + rt.memoryData["req3"] = &memoryBuckets{ + buckets: make([]uint64, rt.windowSize), + lastUpdate: now, + } + + rt.evictOldest() + + assert.Len(t, rt.memoryData, 2) + assert.NotContains(t, rt.memoryData, "req1") // Oldest should be evicted + assert.Contains(t, rt.memoryData, "req2") + assert.Contains(t, rt.memoryData, "req3") +} + +func TestResourceTracker_ConcurrentAccess(t *testing.T) { + rt := &ResourceTracker{ + memoryData: make(map[string]*memoryBuckets), + windowSize: 3, + maxActiveRequests: 10, + } + + // Test concurrent writes + done := make(chan bool, 20) + for i := 0; i < 10; i++ { + go func(id int) { + for j := 0; j < 10; j++ { + rt.AddBytes(fmt.Sprintf("req%d", id), uint64(j)) + } + done <- true + }(i) + } + + // Test concurrent reads + for i := 0; i < 10; i++ { + go func() { + rt.GetHeaviestQuery() + done <- true + }() + } + + // Wait for all goroutines + for i := 0; i < 20; i++ { + <-done + } + + // Should have 10 requests + assert.Len(t, rt.memoryData, 10) +} + +func TestResourceTracker_AccumulateBytes(t *testing.T) { + rt := &ResourceTracker{ + memoryData: make(map[string]*memoryBuckets), + windowSize: 3, + maxActiveRequests: 10, + } + + // Add bytes multiple times to same request + rt.AddBytes("req1", 1000) + rt.AddBytes("req1", 2000) + rt.AddBytes("req1", 3000) + + _, bytes := rt.GetHeaviestQuery() + assert.Equal(t, uint64(6000), bytes) // Should accumulate +}