diff --git a/examples/using-cache/README.md b/examples/using-cache/README.md
new file mode 100644
index 000000000..5fe04c102
--- /dev/null
+++ b/examples/using-cache/README.md
@@ -0,0 +1,301 @@
+# Gofr Cache: In-Memory & Redis
+
+## Overview
+
+Gofr provides a unified cache interface with two implementations: **in-memory** and **Redis-backed**. Both implement the same `cache.Cache` interface, making them interchangeable in your application.
+
+## Cache Interface
+
+Every cache implementation satisfies the `cache.Cache` interface:
+
+```go
+type Cache interface {
+    Get(ctx context.Context, key string) (any, bool, error)
+    Set(ctx context.Context, key string, value any) error
+    Delete(ctx context.Context, key string) error
+    Exists(ctx context.Context, key string) (bool, error)
+    Clear(ctx context.Context) error
+    Close(ctx context.Context) error
+    UseTracer(tracer trace.Tracer)
+}
+```
+
+## Usage
+
+### Method 1: Using App Convenience Methods (Recommended)
+
+```go
+import "gofr.dev/pkg/gofr"
+
+app := gofr.New()
+
+// Add in-memory cache
+app.AddInMemoryCache(ctx, "my-cache", 5*time.Minute, 1000)
+
+// Add Redis cache
+app.AddRedisCache(ctx, "my-redis-cache", 10*time.Minute, "localhost:6379")
+
+// Get cache instance
+cache := app.GetCache("my-cache")
+```
+
+### Method 2: Direct Instantiation
+
+#### In-Memory Cache
+
+```go
+import "gofr.dev/pkg/cache/inmemory"
+
+cache, err := inmemory.NewInMemoryCache(ctx,
+    inmemory.WithName("my-cache"),
+    inmemory.WithTTL(5*time.Minute),
+    inmemory.WithMaxItems(1000),
+)
+```
+
+**Configuration Options:**
+
+| Option | Type | Default | Description |
+|--------|------|---------|-------------|
+| `WithTTL(duration)` | `time.Duration` | `1 minute` | Default TTL for cache entries. A zero or negative TTL makes entries expire immediately |
+| `WithMaxItems(int)` | `int` | `0` (no limit) | Maximum items before LRU eviction |
+| `WithName(string)` | `string` | `"default"` | Cache name for logging/metrics |
+| `WithLogger(logger)` | `observability.Logger` | `NewStdLogger()` | Custom logger implementation |
+| `WithMetrics(metrics)` | `*observability.Metrics` | `NewMetrics("gofr", "inmemory_cache")` | Prometheus metrics collector |
+
+#### Redis Cache
+
+```go
+import "gofr.dev/pkg/cache/redis"
+
+cache, err := redis.NewRedisCache(ctx,
+    redis.WithName("my-redis-cache"),
+    redis.WithTTL(10*time.Minute),
+    redis.WithAddr("localhost:6379"),
+    redis.WithPassword("password"),
+    redis.WithDB(0),
+)
+```
+
+**Configuration Options:**
+
+| Option | Type | Default | Description |
+|--------|------|---------|-------------|
+| `WithTTL(duration)` | `time.Duration` | `1 minute` | Default TTL for cache entries |
+| `WithAddr(string)` | `string` | `"localhost:6379"` | Redis server address |
+| `WithPassword(string)` | `string` | `""` | Redis authentication password |
+| `WithDB(int)` | `int` | `0` | Redis database number (0-255) |
+| `WithName(string)` | `string` | `"default-redis"` | Cache name for logging/metrics |
+| `WithLogger(logger)` | `observability.Logger` | `NewStdLogger()` | Custom logger implementation |
+| `WithMetrics(metrics)` | `*observability.Metrics` | `NewMetrics("gofr", "redis_cache")` | Prometheus metrics collector |
+
+### Method 3: Factory Pattern
+
+```go
+import "gofr.dev/pkg/cache/factory"
+
+// In-memory cache
+cache, err := factory.NewInMemoryCache(ctx, "my-cache",
+    factory.WithTTL(5*time.Minute),
+    factory.WithMaxItems(1000),
+)
+
+// Redis cache
+cache, err := factory.NewRedisCache(ctx, "my-redis-cache",
+    factory.WithTTL(10*time.Minute),
+
factory.WithRedisAddr("localhost:6379"), +) +``` + +## Features + +### In-Memory Cache +- **LRU eviction**: Automatically removes least recently used items when capacity is reached +- **TTL support**: Automatic expiration of cache entries +- **Thread-safe**: Concurrent access supported with RWMutex +- **Background cleanup**: Periodic cleanup of expired items +- **O(1) operations**: Get, Set, Delete operations are constant time +- **Memory efficient**: Uses doubly-linked list for LRU implementation + +### Redis Cache +- **Persistence**: Data survives application restarts +- **TTL support**: Automatic expiration handled by Redis +- **Connection pooling**: Managed by go-redis client +- **Serialization**: Automatic JSON serialization for complex types +- **Network resilience**: Built-in retry and connection management +- **Cluster support**: Can connect to Redis Cluster or Sentinel + +### Common Features +- **Observability**: Built-in logging and Prometheus metrics +- **Tracing**: OpenTelemetry integration with span attributes +- **Error handling**: Comprehensive error types and validation +- **Context support**: All operations accept context for cancellation/timeout +- **Type safety**: Strong typing with proper error handling + +## Monitoring & Observability + +### Metrics + +Both cache implementations expose comprehensive Prometheus metrics: + +#### Common Metrics +- `gofr_{backend}_hits_total`: Total cache hits +- `gofr_{backend}_misses_total`: Total cache misses +- `gofr_{backend}_sets_total`: Total set operations +- `gofr_{backend}_deletes_total`: Total delete operations +- `gofr_{backend}_items_current`: Current number of items +- `gofr_{backend}_operation_latency_seconds`: Operation latency histogram + +#### In-Memory Only Metrics +- `gofr_inmemory_cache_evictions_total`: Items evicted due to capacity limits + +Replace `{backend}` with `inmemory_cache` or `redis_cache`. + +### Logging + +Cache operations are logged with structured information: +- Operation type (GET, SET, DELETE, etc.) +- Cache name +- Key being operated on +- Duration +- Success/failure status + +### Tracing + +OpenTelemetry spans are created for each cache operation with attributes: +- `cache.name`: Cache instance name +- `cache.key`: Key being operated on +- `cache.operation`: Operation type + +## Docker & Monitoring Stack + +### Quick Start with Monitoring + +The example includes a pre-configured monitoring stack with Prometheus and Grafana: + +```bash +# Start the application +go run main.go + +# In another terminal, start the monitoring stack +./monitoring.sh +``` + +### Monitoring Stack Components + +#### Prometheus Configuration +```yaml +# pkg/cache/monitoring/prometheus.yml +global: + scrape_interval: 5s +scrape_configs: + - job_name: 'gofr-cache' + static_configs: + - targets: ['host.docker.internal:8080'] +``` + +#### Grafana Setup +- **URL**: http://localhost:3000 +- **Username**: `admin` +- **Password**: `admin` +- **Pre-configured**: Prometheus data source and cache metrics dashboard + +#### Docker Compose Services +```yaml +services: + prometheus: + image: prom/prometheus:latest + ports: ["9090:9090"] + volumes: + - ./prometheus.yml:/etc/prometheus/prometheus.yml + + grafana: + image: grafana/grafana:latest + ports: ["3000:3000"] + environment: + - GF_SECURITY_ADMIN_PASSWORD=admin + volumes: + - ./provisioning:/etc/grafana/provisioning +``` + +### Manual Monitoring Setup + +If you prefer to set up monitoring manually: + +1. 
**Start Redis** (if using Redis cache): +```bash +docker run -d --name redis -p 6379:6379 redis:alpine +``` + +2. **Start Prometheus**: +```bash +docker run -d --name prometheus -p 9090:9090 \ + -v $(pwd)/prometheus.yml:/etc/prometheus/prometheus.yml \ + prom/prometheus:latest +``` + +3. **Start Grafana**: +```bash +docker run -d --name grafana -p 3000:3000 \ + -e GF_SECURITY_ADMIN_PASSWORD=admin \ + grafana/grafana:latest +``` + +## Example Application + +The `main.go` file demonstrates: + +1. **Cache initialization** using app convenience methods +2. **Metrics exposure** on HTTP endpoint (port 2121) +3. **Continuous operations** to generate observable metrics +4. **Graceful shutdown** handling with signal management + +### Running the Example + +```bash +# Start the application +go run main.go + +# View Grafana dashboard (after starting monitoring stack) +http://localhost:3000 +``` + +### Application Configuration + +The example application uses these default settings: +- **Metrics port**: 2121 (configurable via `METRICS_PORT` env var) +- **Cache TTL**: 5 minutes +- **Max items**: 1000 (in-memory only) +- **Redis address**: localhost:6379 + +## Error Handling + +### Common Errors + +| Error | Description | Resolution | +|-------|-------------|------------| +| `ErrCacheClosed` | Operation attempted on closed cache | Ensure cache is not closed | +| `ErrEmptyKey` | Empty key provided | Provide non-empty key | +| `ErrNilValue` | Nil value provided to Set | Provide non-nil value | +| `ErrInvalidMaxItems` | Negative maxItems value | Use non-negative value | +| `ErrAddressEmpty` | Empty Redis address | Provide valid Redis address | +| `ErrInvalidDatabaseNumber` | Invalid Redis DB number | Use 0-255 range | +| `ErrNegativeTTL` | Negative TTL value | Use non-negative duration | + +### Error Handling Best Practices + +```go +// Always check for errors +value, found, err := cache.Get(ctx, "key") +if err != nil { + log.Printf("Cache get error: %v", err) + // Handle error appropriately + return err +} + +// Check if key exists +if !found { + // Key doesn't exist, handle accordingly + return nil +} +``` diff --git a/examples/using-cache/configs/.env b/examples/using-cache/configs/.env new file mode 100644 index 000000000..2c5e1c93e --- /dev/null +++ b/examples/using-cache/configs/.env @@ -0,0 +1 @@ +METRICS_PORT=8080 \ No newline at end of file diff --git a/examples/using-cache/main.go b/examples/using-cache/main.go new file mode 100644 index 000000000..dc3da2a66 --- /dev/null +++ b/examples/using-cache/main.go @@ -0,0 +1,95 @@ +// DISCLAIMER: +// This is a simple simulation of using a cache with some seed metrics +// to demonstrate how the cache factory, metrics, and observability hooks work. +// It continuously sets, gets, and deletes cache keys in a loop to generate +// measurable metrics for Prometheus. +// NOTE: +// - This is not intended for production use as-is. +// - Actual implementations may differ significantly depending on requirements, +// storage backends, error handling, concurrency, and performance optimizations. + +package main + +import ( + "context" + "fmt" + "os" + "os/signal" + "syscall" + "time" + + "gofr.dev/pkg/gofr" +) + +func main() { + // Cancellable context that ends on SIGINT/SIGTERM. 
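	// signal.NotifyContext from os/signal would be a more compact way to get the
	// same behavior; the explicit goroutine below is used so a message can be
	// printed before cancellation.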
+ ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) + <-sigCh + fmt.Println("\nReceived shutdown signal, exiting...") + cancel() + }() + + // Initialize the GoFr app. + app := gofr.New() + + // Method 1: Using the app's convenience methods (recommended for most cases) + // Tracing is automatically handled by the factory - no manual setup required! + app.AddInMemoryCache(ctx, "default", 5*time.Minute, 1000) + // app.AddRedisCache(ctx, "default", 5*time.Minute, "localhost:6379") + + // Method 2: Using the factory directly (for more control) + // c, err := factory.NewInMemoryCache( + // ctx, + // "default", + // factory.WithLogger(app.Logger()), + // factory.WithTTL(5*time.Minute), + // factory.WithMaxItems(1000), + // ) + // if err != nil { + // panic(fmt.Sprintf("failed to create cache: %v", err)) + // } + // app.container.AddCache("default", c) + + // Get the cache instance into a variable 'c' in the main scope. + c := app.GetCache("default") + if c == nil { + panic("failed to get cache from app container") + } + + // Goroutine to run the app's metrics server. + go func() { + port := app.Config.Get("METRICS_PORT") + if port == "" { + port = "2121" + } + fmt.Printf("Metrics available at http://localhost:%s/metrics\n", port) + app.Run() + cancel() + }() + + // Goroutine to simulate cache usage. + go func() { + for { + select { + case <-ctx.Done(): + return + default: + c.Set(ctx, "alpha", 42) // triggers sets_total + c.Get(ctx, "alpha") // triggers hits_total + c.Get(ctx, "nonexistent") // triggers misses_total + c.Delete(ctx, "alpha") // triggers deletes_total + c.Set(ctx, "alpha", 100) + time.Sleep(2 * time.Second) + } + } + }() + + // Wait until the context is canceled. + <-ctx.Done() + fmt.Println("Shutdown complete.") +} diff --git a/examples/using-cache/monitoring.sh b/examples/using-cache/monitoring.sh new file mode 100755 index 000000000..1ffe0a7c2 --- /dev/null +++ b/examples/using-cache/monitoring.sh @@ -0,0 +1,9 @@ +#!/bin/bash +# This script starts the monitoring stack (Prometheus and Grafana) +# for the cache example. + +# Get the directory of the script +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" &> /dev/null && pwd)" + +# Navigate to the monitoring directory and start docker-compose +cd "${SCRIPT_DIR}/../../pkg/cache/monitoring" && docker-compose up --build \ No newline at end of file diff --git a/go.mod b/go.mod index 93f24d403..049ab743c 100644 --- a/go.mod +++ b/go.mod @@ -36,6 +36,7 @@ require ( go.opentelemetry.io/otel/sdk/metric v1.37.0 go.opentelemetry.io/otel/trace v1.37.0 go.uber.org/mock v0.6.0 + gofr.dev/pkg/cache v0.0.0-00010101000000-000000000000 golang.org/x/oauth2 v0.30.0 golang.org/x/sync v0.16.0 golang.org/x/term v0.34.0 @@ -46,6 +47,8 @@ require ( modernc.org/sqlite v1.38.2 ) +replace gofr.dev/pkg/cache => ./pkg/cache + require ( cloud.google.com/go v0.120.0 // indirect cloud.google.com/go/auth v0.16.3 // indirect diff --git a/go.work b/go.work index 4772f9d2f..01738e965 100644 --- a/go.work +++ b/go.work @@ -3,6 +3,7 @@ go 1.24.5 use ( . 
./examples/using-add-filestore + ./pkg/cache ./pkg/gofr/datasource/arangodb ./pkg/gofr/datasource/cassandra ./pkg/gofr/datasource/clickhouse diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go new file mode 100644 index 000000000..a7feec992 --- /dev/null +++ b/pkg/cache/cache.go @@ -0,0 +1,37 @@ +package cache + +import ( + "context" + + "go.opentelemetry.io/otel/trace" +) + +type Cache interface { + // Get retrieves the value associated with the given key. + // It returns the value, a boolean indicating if the key was found, and an error if any occurred. + // If the key is not found, it returns nil, false, nil. + Get(ctx context.Context, key string) (any, bool, error) + + // Set stores a key-value pair in the cache. + // If the key already exists, its value is overwritten. + // It may also set a time-to-live (TTL) for the entry, depending on the implementation. + Set(ctx context.Context, key string, value any) error + + // Delete removes the key-value pair associated with the given key from the cache. + // If the key does not exist, it does nothing and returns nil. + Delete(ctx context.Context, key string) error + + // Exists checks if a key exists in the cache and has not expired. + // It returns true if the key exists, and false otherwise. + Exists(ctx context.Context, key string) (bool, error) + + // Clear removes all key-value pairs from the cache. + // This operation is destructive and should be used with caution. + Clear(ctx context.Context) error + + // Close releases any resources used by the cache, such as background goroutines or network connections. + // After Close is called, the cache may no longer be usable. + Close(ctx context.Context) error + + UseTracer(tracer trace.Tracer) +} diff --git a/pkg/cache/factory/factory.go b/pkg/cache/factory/factory.go new file mode 100644 index 000000000..f8dfb9b5a --- /dev/null +++ b/pkg/cache/factory/factory.go @@ -0,0 +1,248 @@ +package factory + +import ( + "context" + "time" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/trace" + + "gofr.dev/pkg/cache" + "gofr.dev/pkg/cache/inmemory" + "gofr.dev/pkg/cache/observability" + "gofr.dev/pkg/cache/redis" + "gofr.dev/pkg/gofr/logging" +) + +type config struct { + inMemoryOptions []inmemory.Option + redisOptions []redis.Option + logger logging.Logger + metrics *observability.Metrics +} + +type Option func(*config) + +// WithLogger sets a custom logging.Logger for the cache. +func WithLogger(logger logging.Logger) Option { + return func(c *config) { + c.logger = logger + } +} + +// WithObservabilityLogger sets a custom observability.Logger for both in-memory and Redis caches. +func WithObservabilityLogger(logger observability.Logger) Option { + return func(c *config) { + c.inMemoryOptions = append(c.inMemoryOptions, inmemory.WithLogger(logger)) + c.redisOptions = append(c.redisOptions, redis.WithLogger(logger)) + } +} + +// WithMetrics sets a metrics collector for both in-memory and Redis caches. +func WithMetrics(metrics *observability.Metrics) Option { + return func(c *config) { + c.metrics = metrics + c.inMemoryOptions = append(c.inMemoryOptions, inmemory.WithMetrics(metrics)) + c.redisOptions = append(c.redisOptions, redis.WithMetrics(metrics)) + } +} + +// WithRedisAddr sets the Redis connection address. +func WithRedisAddr(addr string) Option { + return func(c *config) { + c.redisOptions = append(c.redisOptions, redis.WithAddr(addr)) + } +} + +// WithRedisPassword sets the password for Redis authentication. 
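+// An empty password (the go-redis default) disables AUTH.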
+func WithRedisPassword(password string) Option { + return func(c *config) { + c.redisOptions = append(c.redisOptions, redis.WithPassword(password)) + } +} + +// WithRedisDB sets the Redis database number. +func WithRedisDB(db int) Option { + return func(c *config) { + c.redisOptions = append(c.redisOptions, redis.WithDB(db)) + } +} + +// WithTTL sets the time-to-live for cache entries (applies to both in-memory and Redis caches). +func WithTTL(ttl time.Duration) Option { + return func(c *config) { + c.inMemoryOptions = append(c.inMemoryOptions, inmemory.WithTTL(ttl)) + c.redisOptions = append(c.redisOptions, redis.WithTTL(ttl)) + } +} + +// WithMaxItems sets the maximum number of items for in-memory cache. +func WithMaxItems(maxItems int) Option { + return func(c *config) { + c.inMemoryOptions = append(c.inMemoryOptions, inmemory.WithMaxItems(maxItems)) + } +} + +type contextAwareLogger struct { + logging.Logger +} + +func extractTraceID(ctx context.Context) map[string]any { + if ctx == nil { + return nil + } + + sc := trace.SpanFromContext(ctx).SpanContext() + if sc.IsValid() { + return map[string]any{"__trace_id__": sc.TraceID().String()} + } + + return nil +} + +func (l *contextAwareLogger) Errorf(ctx context.Context, format string, args ...any) { + if traceMap := extractTraceID(ctx); traceMap != nil { + args = append(args, traceMap) + } + + l.Logger.Errorf(format, args...) +} + +func (l *contextAwareLogger) Warnf(ctx context.Context, format string, args ...any) { + if traceMap := extractTraceID(ctx); traceMap != nil { + args = append(args, traceMap) + } + + l.Logger.Warnf(format, args...) +} + +func (l *contextAwareLogger) Infof(ctx context.Context, format string, args ...any) { + if traceMap := extractTraceID(ctx); traceMap != nil { + args = append(args, traceMap) + } + + l.Logger.Infof(format, args...) +} + +func (l *contextAwareLogger) Debugf(ctx context.Context, format string, args ...any) { + if traceMap := extractTraceID(ctx); traceMap != nil { + args = append(args, traceMap) + } + + l.Logger.Debugf(format, args...) +} + +func (l *contextAwareLogger) Hitf(ctx context.Context, message string, duration time.Duration, operation string) { + args := []any{message, operation, duration} + if traceMap := extractTraceID(ctx); traceMap != nil { + args = append(args, traceMap) + } + + l.Logger.Infof("%s: %s, duration: %s", args...) +} + +func (l *contextAwareLogger) Missf(ctx context.Context, message string, duration time.Duration, operation string) { + args := []any{message, operation, duration} + if traceMap := extractTraceID(ctx); traceMap != nil { + args = append(args, traceMap) + } + + l.Logger.Infof("%s: %s, duration: %s", args...) +} + +func (l *contextAwareLogger) LogRequest(ctx context.Context, level, message string, tag any, duration time.Duration, operation string) { + logMessage := "message: %s, tag: %v, duration: %v, operation: %s" + args := []any{message, tag, duration, operation} + + if traceMap := extractTraceID(ctx); traceMap != nil { + args = append(args, traceMap) + } + + switch level { + case "INFO": + l.Logger.Infof(logMessage, args...) + case "DEBUG": + l.Logger.Debugf(logMessage, args...) + case "WARN": + l.Logger.Warnf(logMessage, args...) + case "ERROR": + l.Logger.Errorf(logMessage, args...) + default: + args = append(args, "unsupported log level", level) + l.Logger.Logf(logMessage, args...) 
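+		// Logf is used as a fallback so messages with an unrecognized level are
+		// still emitted rather than silently dropped.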
+ } +} + +func getTracer(name string) trace.Tracer { + return otel.GetTracerProvider().Tracer(name) +} + +type cacheBuilder interface { + build(ctx context.Context, cfg *config) (cache.Cache, error) +} + +type inMemoryBuilder struct { + name string +} + +func (b *inMemoryBuilder) build(ctx context.Context, cfg *config) (cache.Cache, error) { + cfg.inMemoryOptions = append(cfg.inMemoryOptions, inmemory.WithName(b.name)) + + c, err := inmemory.NewInMemoryCache(ctx, cfg.inMemoryOptions...) + if err != nil { + return nil, err + } + + c.UseTracer(getTracer("gofr-inmemory-cache")) + + return c, nil +} + +type redisBuilder struct { + name string +} + +func (b *redisBuilder) build(ctx context.Context, cfg *config) (cache.Cache, error) { + cfg.redisOptions = append(cfg.redisOptions, redis.WithName(b.name)) + + c, err := redis.NewRedisCache(ctx, cfg.redisOptions...) + if err != nil { + return nil, err + } + + c.UseTracer(getTracer("gofr-redis-cache")) + + return c, nil +} + +func newCacheWithBuilder(ctx context.Context, builder cacheBuilder, opts ...Option) (cache.Cache, error) { + cfg := &config{} + for _, opt := range opts { + opt(cfg) + } + + if cfg.logger != nil { + adaptedLogger := &contextAwareLogger{Logger: cfg.logger} + cfg.inMemoryOptions = append(cfg.inMemoryOptions, inmemory.WithLogger(adaptedLogger)) + cfg.redisOptions = append(cfg.redisOptions, redis.WithLogger(adaptedLogger)) + } + + return builder.build(ctx, cfg) +} + +func NewInMemoryCache(ctx context.Context, name string, opts ...Option) (cache.Cache, error) { + return newCacheWithBuilder(ctx, &inMemoryBuilder{name: name}, opts...) +} + +func NewRedisCache(ctx context.Context, name string, opts ...Option) (cache.Cache, error) { + return newCacheWithBuilder(ctx, &redisBuilder{name: name}, opts...) +} + +func NewCache(ctx context.Context, cacheType, name string, opts ...Option) (cache.Cache, error) { + switch cacheType { + case "redis": + return NewRedisCache(ctx, name, opts...) + default: + return NewInMemoryCache(ctx, name, opts...) 
+ } +} diff --git a/pkg/cache/factory/factory_test.go b/pkg/cache/factory/factory_test.go new file mode 100644 index 000000000..e3647f01b --- /dev/null +++ b/pkg/cache/factory/factory_test.go @@ -0,0 +1,168 @@ +package factory + +import ( + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "gofr.dev/pkg/cache/observability" +) + +func TestNewInMemoryCache(t *testing.T) { + tests := []struct { + name string + cacheName string + opts []Option + expErr bool + }{ + { + name: "Successful creation with no options", + cacheName: "test-inmemory", + opts: nil, + expErr: false, + }, + { + name: "Successful creation with TTL and max items", + cacheName: "test-inmemory-with-config", + opts: []Option{WithTTL(5 * time.Minute), WithMaxItems(100)}, + expErr: false, + }, + { + name: "Successful creation with logger option", + cacheName: "test-inmemory-with-logger", + opts: []Option{WithObservabilityLogger(observability.NewStdLogger()), WithTTL(10 * time.Minute), WithMaxItems(50)}, + expErr: false, + }, + { + name: "Successful creation with metrics option", + cacheName: "test-inmemory-with-metrics", + opts: []Option{WithMetrics(observability.NewMetrics("test", "inmemory")), WithTTL(10 * time.Minute), WithMaxItems(50)}, + expErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + originalRegistry := prometheus.DefaultRegisterer + prometheus.DefaultRegisterer = prometheus.NewRegistry() + + t.Cleanup(func() { + prometheus.DefaultRegisterer = originalRegistry + }) + + c, err := NewInMemoryCache(t.Context(), tt.cacheName, tt.opts...) + + if tt.expErr { + assert.Error(t, err, "Expected an error for %v", tt.name) + } else { + require.NoError(t, err, "Did not expect an error for %v", tt.name) + assert.NotNil(t, c, "Expected a cache instance for %v", tt.name) + } + }) + } +} + +func TestNewRedisCache(t *testing.T) { + tests := []struct { + name string + cacheName string + opts []Option + expErr bool + }{ + { + name: "Initialization without options", + cacheName: "test-redis", + opts: nil, + expErr: false, + }, + { + name: "Initialization with address and TTL", + cacheName: "test-redis-with-addr", + opts: []Option{WithRedisAddr("localhost:6379"), WithTTL(10 * time.Minute)}, + expErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + originalRegistry := prometheus.DefaultRegisterer + prometheus.DefaultRegisterer = prometheus.NewRegistry() + + t.Cleanup(func() { + prometheus.DefaultRegisterer = originalRegistry + }) + + c, err := NewRedisCache(t.Context(), tt.cacheName, tt.opts...) 
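+			// NOTE: this assumes construction succeeds without a reachable server,
+			// since go-redis only dials on the first command.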
+ + if tt.expErr { + assert.Error(t, err, "Expected an error for %v", tt.name) + } else { + require.NoError(t, err, "Did not expect an error for %v", tt.name) + assert.NotNil(t, c, "Expected a cache instance for %v", tt.name) + } + }) + } +} + +func TestNewCache(t *testing.T) { + tests := []struct { + name string + cacheType string + opts []Option + expErr bool + }{ + { + name: "Create inmemory cache", + cacheType: "inmemory", + opts: []Option{WithTTL(5 * time.Minute), WithMaxItems(100)}, + expErr: false, + }, + { + name: "Create redis cache", + cacheType: "redis", + opts: []Option{WithTTL(5 * time.Minute)}, + expErr: false, + }, + { + name: "Create default cache (inmemory)", + cacheType: "unknown", + opts: []Option{WithTTL(5 * time.Minute), WithMaxItems(100)}, + expErr: false, + }, + { + name: "Create with empty type (default to inmemory)", + cacheType: "", + opts: []Option{WithTTL(5 * time.Minute), WithMaxItems(100)}, + expErr: false, + }, + { + name: "Create redis cache with options", + cacheType: "redis", + opts: []Option{WithRedisAddr("localhost:6379"), WithTTL(5 * time.Minute)}, + expErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + originalRegistry := prometheus.DefaultRegisterer + prometheus.DefaultRegisterer = prometheus.NewRegistry() + + t.Cleanup(func() { + prometheus.DefaultRegisterer = originalRegistry + }) + + c, err := NewCache(t.Context(), tt.cacheType, "test-cache", tt.opts...) + + if tt.expErr { + assert.Error(t, err, "Expected an error for %v", tt.name) + } else { + require.NoError(t, err, "Did not expect an error for %v", tt.name) + assert.NotNil(t, c, "Expected a cache instance for %v", tt.name) + } + }) + } +} diff --git a/pkg/cache/go.mod b/pkg/cache/go.mod new file mode 100644 index 000000000..ef61e97a6 --- /dev/null +++ b/pkg/cache/go.mod @@ -0,0 +1,34 @@ +module gofr.dev/pkg/cache + +go 1.24.5 + +require ( + github.com/prometheus/client_golang v1.23.0 + github.com/redis/go-redis/v9 v9.12.1 + github.com/stretchr/testify v1.10.0 + go.opentelemetry.io/otel v1.37.0 + go.opentelemetry.io/otel/trace v1.37.0 + gofr.dev v1.43.0 +) + +replace gofr.dev => ../.. 
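+// The replace directive above points gofr.dev at the repository root, so this
+// module builds against the local GoFr sources.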
+ +require ( + github.com/beorn7/perks v1.0.1 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/go-logr/logr v1.4.3 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/prometheus/client_model v0.6.2 // indirect + github.com/prometheus/common v0.65.0 // indirect + github.com/prometheus/procfs v0.17.0 // indirect + go.opentelemetry.io/auto/sdk v1.1.0 // indirect + go.opentelemetry.io/otel/metric v1.37.0 // indirect + golang.org/x/sys v0.35.0 // indirect + golang.org/x/term v0.34.0 // indirect + google.golang.org/protobuf v1.36.7 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/pkg/cache/go.sum b/pkg/cache/go.sum new file mode 100644 index 000000000..e8c20da1d --- /dev/null +++ b/pkg/cache/go.sum @@ -0,0 +1,62 @@ +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v1.23.0 h1:ust4zpdl9r4trLY/gSjlm07PuiBq2ynaXXlptpfy8Uc= 
+github.com/prometheus/client_golang v1.23.0/go.mod h1:i/o0R9ByOnHX0McrTMTyhYvKE4haaf2mW08I+jGAjEE= +github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= +github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= +github.com/prometheus/common v0.65.0 h1:QDwzd+G1twt//Kwj/Ww6E9FQq1iVMmODnILtW1t2VzE= +github.com/prometheus/common v0.65.0/go.mod h1:0gZns+BLRQ3V6NdaerOhMbwwRbNh9hkGINtQAsP5GS8= +github.com/prometheus/procfs v0.17.0 h1:FuLQ+05u4ZI+SS/w9+BWEM2TXiHKsUQ9TADiRH7DuK0= +github.com/prometheus/procfs v0.17.0/go.mod h1:oPQLaDAMRbA+u8H5Pbfq+dl3VDAvHxMUOVhe0wYB2zw= +github.com/redis/go-redis/v9 v9.12.1 h1:k5iquqv27aBtnTm2tIkROUDp8JBXhXZIVu1InSgvovg= +github.com/redis/go-redis/v9 v9.12.1/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ= +go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I= +go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE= +go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E= +go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4= +go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/term v0.34.0 h1:O/2T7POpk0ZZ7MAzMeWFSg6S5IpWd/RXDlM9hgM3DR4= +golang.org/x/term v0.34.0/go.mod h1:5jC53AEywhIVebHgPVeg0mj8OD3VO9OzclacVrqpaAw= +google.golang.org/protobuf v1.36.7 h1:IgrO7UwFQGJdRNXH/sQux4R1Dj1WAKcLElzeeRaXV2A= +google.golang.org/protobuf v1.36.7/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/pkg/cache/inmemory/inmemory.go b/pkg/cache/inmemory/inmemory.go new file mode 100644 index 000000000..fb747bb42 --- /dev/null +++ b/pkg/cache/inmemory/inmemory.go @@ -0,0 +1,539 @@ +package inmemory + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + + "gofr.dev/pkg/cache" + "gofr.dev/pkg/cache/observability" +) + +// Common errors. +var ( + // ErrCacheClosed is returned when an operation is attempted on a closed cache. 
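+	// Only Close itself returns this error; other operations remain usable on a
+	// closed cache.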
+	ErrCacheClosed = errors.New("cache is closed")
+	// ErrEmptyKey is returned when an operation is attempted with an empty key.
+	ErrEmptyKey = errors.New("key cannot be empty")
+	// ErrNilValue is returned when a nil value is provided to Set.
+	ErrNilValue = errors.New("value cannot be nil")
+	// ErrInvalidMaxItems is returned when a negative value is provided for maxItems.
+	ErrInvalidMaxItems = errors.New("maxItems must be non-negative")
+)
+
+// node represents an element in the LRU doubly linked list.
+// Used for O(1) insert, remove, and move-to-front operations.
+type node struct {
+	key        string
+	prev, next *node
+}
+
+type entry struct {
+	value     any
+	expiresAt time.Time
+	node      *node
+}
+
+type inMemoryCache struct {
+	mu       sync.RWMutex
+	items    map[string]entry
+	ttl      time.Duration
+	maxItems int
+	quit     chan struct{}
+	closed   bool
+
+	// LRU list head and tail
+	head, tail *node
+
+	name    string
+	logger  observability.Logger
+	metrics *observability.Metrics
+	tracer  *trace.Tracer
+}
+
+type Option func(*inMemoryCache) error
+
+// WithTTL sets the default time-to-live (TTL) for all entries in the cache.
+// Items will be automatically removed after this duration has passed since they were last set.
+// A TTL of zero or less causes entries to expire immediately.
+func WithTTL(ttl time.Duration) Option {
+	return func(c *inMemoryCache) error {
+		c.ttl = ttl
+		return nil
+	}
+}
+
+// WithMaxItems sets the maximum number of items the cache can hold.
+// When this limit is reached, the least recently used (LRU) item is evicted
+// to make space for a new one. A value of 0 means no limit.
+func WithMaxItems(maxItems int) Option {
+	return func(c *inMemoryCache) error {
+		if maxItems < 0 {
+			return ErrInvalidMaxItems
+		}
+
+		c.maxItems = maxItems
+
+		return nil
+	}
+}
+
+// WithName sets a descriptive name for the cache instance.
+// This name is used in logs and metrics to identify the cache.
+func WithName(name string) Option {
+	return func(c *inMemoryCache) error {
+		if name != "" {
+			c.name = name
+		}
+
+		return nil
+	}
+}
+
+// WithLogger provides a custom logger for the cache.
+// If not provided, a default standard library logger is used.
+func WithLogger(logger observability.Logger) Option {
+	return func(c *inMemoryCache) error {
+		if logger != nil {
+			c.logger = logger
+		}
+
+		return nil
+	}
+}
+
+// WithMetrics provides a metrics collector for the cache.
+// If provided, the cache will record metrics for operations like hits, misses, and sets.
+func WithMetrics(m *observability.Metrics) Option {
+	return func(c *inMemoryCache) error {
+		if m != nil {
+			c.metrics = m
+		}
+
+		return nil
+	}
+}
+
+// validateKey ensures key is non-empty.
+func validateKey(key string) error {
+	if key == "" {
+		return ErrEmptyKey
+	}
+
+	return nil
+}
+
+// withSpan creates a span for cache operations and ensures proper context propagation.
+func (c *inMemoryCache) withSpan(ctx context.Context, operation, key string) (context.Context, trace.Span) {
+	if c.tracer == nil {
+		return ctx, trace.SpanFromContext(ctx)
+	}
+
+	tracer := *c.tracer
+	spanCtx, span := tracer.Start(ctx, fmt.Sprintf("cache.%s", operation),
+		trace.WithAttributes(
+			attribute.String("cache.name", c.name),
+			attribute.String("cache.key", key),
+			attribute.String("cache.operation", operation),
+		))
+
+	return spanCtx, span
+}
+
+// NewInMemoryCache creates and returns a new in-memory cache instance.
+// It takes zero or more Option functions to customize its configuration.
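+// The returned cache is safe for concurrent use by multiple goroutines.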
+// By default, it creates a cache with a 1-minute TTL and no item limit. +// It also starts a background goroutine for periodic cleanup of expired items. +func NewInMemoryCache(ctx context.Context, opts ...Option) (cache.Cache, error) { + c := &inMemoryCache{ + items: make(map[string]entry), + ttl: time.Minute, + maxItems: 0, + quit: make(chan struct{}), + logger: observability.NewStdLogger(), + metrics: observability.NewMetrics("gofr", "inmemory_cache"), + } + + for _, opt := range opts { + if err := opt(c); err != nil { + return nil, err + } + } + + c.logger.Infof(ctx, "Cache '%s' initialized with TTL=%s, MaxItems=%d", c.name, c.ttl, c.maxItems) + + // Start cleanup goroutine + go c.startCleanup(ctx) + + return c, nil +} + +func (c *inMemoryCache) UseTracer(tracer trace.Tracer) { + c.tracer = &tracer +} + +// Set adds or updates a key-value pair in the cache. +// If the key already exists, its value is updated, and it's marked as the most recently used item. +// If the cache is at capacity, the least recently used item is evicted. +// This operation is thread-safe. +func (c *inMemoryCache) Set(ctx context.Context, key string, value any) error { + spanCtx, span := c.withSpan(ctx, "set", key) + defer span.End() + + if err := validateKey(key); err != nil { + c.logger.Errorf(spanCtx, "Set failed: %v", err) + return err + } + + if value == nil { + c.logger.Errorf(spanCtx, "Set failed: %v", ErrNilValue) + return ErrNilValue + } + + now := time.Now() + + c.mu.Lock() + defer c.mu.Unlock() + + // Clean up expired entries (O(n) TTL scan) + removed := c.cleanupExpired(now) + if removed > 0 { + c.logger.Debugf(spanCtx, "Cleaned %d expired items during Set", removed) + } + + // If present, update value and move node to front (most recently used) + if ent, ok := c.items[key]; ok { + ent.value = value + ent.expiresAt = c.computeExpiry(now) + c.items[key] = ent + // O(1) move-to-front + c.moveToFront(ent.node) + + duration := time.Since(now) + c.logger.LogRequest(spanCtx, "INFO", "SET", "UPDATE", duration, key) + c.metrics.Sets().WithLabelValues(c.name).Inc() + c.metrics.Items().WithLabelValues(c.name).Set(float64(len(c.items))) + c.metrics.Latency().WithLabelValues(c.name, "set").Observe(duration.Seconds()) + + return nil + } + + // Evict if at capacity (O(1) eviction) + if c.maxItems > 0 && len(c.items) >= c.maxItems { + c.evictTail(spanCtx) + } + + // Insert new node at head (most recently used) + node := &node{key: key} + // O(1) insert at front + c.insertAtFront(node) + c.items[key] = entry{value: value, expiresAt: c.computeExpiry(now), node: node} + + duration := time.Since(now) + c.logger.LogRequest(spanCtx, "INFO", "SET", "CREATE", duration, key) + c.metrics.Sets().WithLabelValues(c.name).Inc() + c.metrics.Items().WithLabelValues(c.name).Set(float64(len(c.items))) + c.metrics.Latency().WithLabelValues(c.name, "set").Observe(duration.Seconds()) + + return nil +} + +// Get retrieves the value for a given key. +// If the key is found and not expired, it returns the value and true. +// It also marks the accessed item as the most recently used. +// If the key is not found or has expired, it returns nil and false. +// This operation is thread-safe. 
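+// Note that a hit takes the write lock, not the read lock, because it mutates
+// the recency list and may remove an expired entry inline.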
+func (c *inMemoryCache) Get(ctx context.Context, key string) (value any, found bool, err error) {
+	spanCtx, span := c.withSpan(ctx, "get", key)
+	defer span.End()
+
+	if err := validateKey(key); err != nil {
+		c.logger.Errorf(spanCtx, "Get failed: %v", err)
+		return nil, false, err
+	}
+
+	start := time.Now()
+
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	ent, ok := c.items[key]
+	if !ok || time.Now().After(ent.expiresAt) {
+		if ok {
+			// Remove expired node (O(1))
+			c.removeNode(ent.node)
+			delete(c.items, key)
+		}
+
+		duration := time.Since(start)
+		c.logger.Missf(spanCtx, "GET", duration, key)
+		c.metrics.Misses().WithLabelValues(c.name).Inc()
+		c.metrics.Latency().WithLabelValues(c.name, "get").Observe(duration.Seconds())
+
+		return nil, false, nil
+	}
+
+	// Hit: move node to front to mark as most recently used (O(1))
+	c.moveToFront(ent.node)
+
+	duration := time.Since(start)
+
+	c.logger.Hitf(spanCtx, "GET", duration, key)
+	c.metrics.Hits().WithLabelValues(c.name).Inc()
+	c.metrics.Latency().WithLabelValues(c.name, "get").Observe(duration.Seconds())
+
+	return ent.value, true, nil
+}
+
+// Delete removes a key from the cache.
+// If the key does not exist, the operation is a no-op.
+// This operation is thread-safe.
+func (c *inMemoryCache) Delete(ctx context.Context, key string) error {
+	spanCtx, span := c.withSpan(ctx, "delete", key)
+	defer span.End()
+
+	if err := validateKey(key); err != nil {
+		c.logger.Errorf(spanCtx, "Delete failed: %v", err)
+		return err
+	}
+
+	start := time.Now()
+
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	if ent, existed := c.items[key]; existed {
+		// O(1) removal from LRU list
+		c.removeNode(ent.node)
+		delete(c.items, key)
+		c.logger.Debugf(spanCtx, "Deleted key '%s'", key)
+		c.metrics.Deletes().WithLabelValues(c.name).Inc()
+		c.metrics.Items().WithLabelValues(c.name).Set(float64(len(c.items)))
+	}
+
+	c.metrics.Latency().WithLabelValues(c.name, "delete").Observe(time.Since(start).Seconds())
+
+	return nil
+}
+
+// Exists checks if a key exists in the cache and has not expired.
+// It returns true if the key is present and valid, false otherwise.
+// This operation does not update the item's recency.
+// This operation is thread-safe.
+func (c *inMemoryCache) Exists(ctx context.Context, key string) (bool, error) {
+	spanCtx, span := c.withSpan(ctx, "exists", key)
+	defer span.End()
+
+	if err := validateKey(key); err != nil {
+		c.logger.Errorf(spanCtx, "Exists failed: %v", err)
+		return false, err
+	}
+
+	c.mu.RLock()
+	defer c.mu.RUnlock()
+
+	if ent, ok := c.items[key]; ok && time.Now().Before(ent.expiresAt) {
+		return true, nil
+	}
+
+	return false, nil
+}
+
+// Clear removes all items from the cache.
+// This operation is thread-safe.
+func (c *inMemoryCache) Clear(ctx context.Context) error {
+	spanCtx, span := c.withSpan(ctx, "clear", "")
+	defer span.End()
+
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	count := len(c.items)
+	c.items = make(map[string]entry)
+	c.head, c.tail = nil, nil
+	c.logger.Infof(spanCtx, "Cleared cache '%s', removed %d items", c.name, count)
+	c.metrics.Items().WithLabelValues(c.name).Set(0)
+
+	return nil
+}
+
+// Close stops the background cleanup goroutine and marks the cache as closed.
+// Calling Close on an already closed cache returns ErrCacheClosed; other
+// operations continue to work, but expired entries are no longer swept in the
+// background.
+// This operation is thread-safe.
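+// Close does not release stored entries; call Clear first if memory should be
+// reclaimed eagerly.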
+func (c *inMemoryCache) Close(ctx context.Context) error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.closed { + c.logger.Warnf(ctx, "Close called on already closed cache '%s'", c.name) + return ErrCacheClosed + } + + close(c.quit) + c.closed = true + c.logger.Infof(ctx, "Cache '%s' closed", c.name) + + return nil +} + +// returns expiration for now. +func (c *inMemoryCache) computeExpiry(now time.Time) time.Time { + if c.ttl <= 0 { + return now + } + + return now.Add(c.ttl) +} + +// unlinks expired entries. +func (c *inMemoryCache) cleanupExpired(now time.Time) int { + var removed int + + for k, ent := range c.items { + if now.After(ent.expiresAt) { + // O(1) remove from LRU list + c.removeNode(ent.node) + delete(c.items, k) + + removed++ + } + } + + return removed +} + +// removes the least-recently used item. +func (c *inMemoryCache) evictTail(ctx context.Context) { + if c.tail == nil { + return + } + + key := c.tail.key + c.removeNode(c.tail) + delete(c.items, key) + c.logger.Debugf(ctx, "Evicted key '%s'", key) + c.metrics.Evicts().WithLabelValues(c.name).Inc() + c.metrics.Items().WithLabelValues(c.name).Set(float64(len(c.items))) +} + +// places n at the head. +func (c *inMemoryCache) insertAtFront(n *node) { + n.prev = nil + n.next = c.head + + if c.head != nil { + c.head.prev = n + } + + c.head = n + + if c.tail == nil { + c.tail = n + } +} + +// unlinks then inserts n at head. +func (c *inMemoryCache) moveToFront(n *node) { + if c.head == n { + return + } + + c.removeNode(n) + + c.insertAtFront(n) +} + +// unlinks n from the list. +func (c *inMemoryCache) removeNode(n *node) { + if n.prev != nil { + n.prev.next = n.next + } else { + c.head = n.next + } + + if n.next != nil { + n.next.prev = n.prev + } else { + c.tail = n.prev + } + + n.prev, n.next = nil, nil +} + +// runs periodic TTL cleanup. +func (c *inMemoryCache) startCleanup(ctx context.Context) { + interval := c.ttl / 4 + if interval < 10*time.Second { + interval = 10 * time.Second + } + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + c.logger.Infof(ctx, "Started cleanup every %v for cache '%s'", interval, c.name) + + for { + select { + case <-ticker.C: + c.mu.Lock() + if !c.closed { + removed := c.cleanupExpired(time.Now()) + if removed > 0 { + c.logger.Debugf(ctx, "Cleanup removed %d items, remaining %d", removed, len(c.items)) + } + } + c.mu.Unlock() + + case <-ctx.Done(): + c.logger.Infof(ctx, "Context canceled: stopping cleanup for cache '%s'", c.name) + return + + case <-c.quit: + c.logger.Infof(ctx, "Quit channel closed: stopping cleanup for cache '%s'", c.name) + return + } + } +} + +const ( + DefaultTTL = 5 * time.Minute + DefaultMaxItems = 1000 + DebugMaxItems = 100 +) + +// NewDefaultCache creates a cache with sensible default settings for general use. +// It is configured with a 5-minute TTL and a 1000-item limit. +func NewDefaultCache(ctx context.Context, name string) (cache.Cache, error) { + return NewInMemoryCache( + ctx, + WithName(name), + WithTTL(DefaultTTL), + WithMaxItems(DefaultMaxItems), + ) +} + +// NewDebugCache creates a cache with settings suitable for debugging. +// It has a short TTL (1 minute) and a small capacity (100 items). +func NewDebugCache(ctx context.Context, name string) (cache.Cache, error) { + return NewInMemoryCache( + ctx, + WithName(name), + WithTTL(1*time.Minute), + WithMaxItems(DebugMaxItems), + ) +} + +// NewProductionCache creates a cache with settings suitable for production environments. +// It requires explicit configuration for TTL and maximum item count. 
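+// A negative maxItems is rejected with ErrInvalidMaxItems by WithMaxItems.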
+func NewProductionCache(ctx context.Context, name string, ttl time.Duration, maxItems int) (cache.Cache, error) { + return NewInMemoryCache( + ctx, + WithName(name), + WithTTL(ttl), + WithMaxItems(maxItems), + ) +} diff --git a/pkg/cache/inmemory/inmemory_test.go b/pkg/cache/inmemory/inmemory_test.go new file mode 100644 index 000000000..437ef741d --- /dev/null +++ b/pkg/cache/inmemory/inmemory_test.go @@ -0,0 +1,758 @@ +package inmemory + +import ( + "context" + "runtime" + "sync" + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// makeCache initializes the cache and fails the test on error. +func makeCache(ctx context.Context, t *testing.T, opts ...Option) *inMemoryCache { + t.Helper() + + ci, err := NewInMemoryCache(ctx, opts...) + require.NoError(t, err, "failed to initialize cache") + + return ci.(*inMemoryCache) +} + +// Test basic Set/Get/Delete/Exists operations. +func TestOperations(t *testing.T) { + originalRegistry := prometheus.DefaultRegisterer + prometheus.DefaultRegisterer = prometheus.NewRegistry() + + t.Cleanup(func() { + prometheus.DefaultRegisterer = originalRegistry + }) + + ctx := t.Context() + c := makeCache(ctx, t, WithName("name"), WithTTL(5*time.Second), WithMaxItems(10)) + defer c.Close(ctx) + + require.NoError(t, c.Set(ctx, "key1", 10)) + + v, found, err := c.Get(ctx, "key1") + require.NoError(t, err) + assert.True(t, found) + assert.Equal(t, 10, v) + + exists, err := c.Exists(ctx, "key1") + require.NoError(t, err) + assert.True(t, exists) + + require.NoError(t, c.Delete(ctx, "key1")) + + exists, err = c.Exists(ctx, "key1") + require.NoError(t, err) + assert.False(t, exists) +} + +// Test Clear method. +func TestClear(t *testing.T) { + originalRegistry := prometheus.DefaultRegisterer + prometheus.DefaultRegisterer = prometheus.NewRegistry() + + t.Cleanup(func() { + prometheus.DefaultRegisterer = originalRegistry + }) + + ctx := t.Context() + c := makeCache(ctx, t, WithTTL(time.Minute), WithMaxItems(10)) + defer c.Close(ctx) + + require.NoError(t, c.Set(ctx, "x", 1)) + require.NoError(t, c.Set(ctx, "y", 2)) + + require.NoError(t, c.Clear(ctx)) + + for _, k := range []string{"x", "y"} { + exist, err := c.Exists(ctx, k) + require.NoError(t, err) + assert.False(t, exist) + } +} + +// Test TTL expiration. +func TestTTLExpiry(t *testing.T) { + originalRegistry := prometheus.DefaultRegisterer + prometheus.DefaultRegisterer = prometheus.NewRegistry() + + t.Cleanup(func() { + prometheus.DefaultRegisterer = originalRegistry + }) + + ctx := t.Context() + c := makeCache(ctx, t, WithTTL(50*time.Millisecond), WithMaxItems(10)) + defer c.Close(ctx) + + require.NoError(t, c.Set(ctx, "foo", "bar")) + time.Sleep(60 * time.Millisecond) + + _, found, err := c.Get(ctx, "foo") + require.NoError(t, err) + assert.False(t, found) +} + +// Test eviction due to capacity. 
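+// "k1" is read just before "k3" is inserted, so "k2" is the least recently
+// used entry and is the one evicted.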
+func TestCapacityEviction(t *testing.T) { + originalRegistry := prometheus.DefaultRegisterer + prometheus.DefaultRegisterer = prometheus.NewRegistry() + + t.Cleanup(func() { + prometheus.DefaultRegisterer = originalRegistry + }) + + ctx := t.Context() + c := makeCache(ctx, t, WithTTL(time.Minute), WithMaxItems(2)) + defer c.Close(ctx) + + require.NoError(t, c.Set(ctx, "k1", 1)) + time.Sleep(time.Millisecond) + require.NoError(t, c.Set(ctx, "k2", 2)) + _, _, err := c.Get(ctx, "k1") // Access to keep recent + require.NoError(t, err) + require.NoError(t, c.Set(ctx, "k3", 3)) + + exists, err := c.Exists(ctx, "k2") + require.NoError(t, err) + assert.False(t, exists) +} + +// Test overwriting existing key. +func TestOverwrite(t *testing.T) { + originalRegistry := prometheus.DefaultRegisterer + prometheus.DefaultRegisterer = prometheus.NewRegistry() + + t.Cleanup(func() { + prometheus.DefaultRegisterer = originalRegistry + }) + + ctx := t.Context() + c := makeCache(ctx, t, WithTTL(5*time.Second), WithMaxItems(10)) + defer c.Close(ctx) + + require.NoError(t, c.Set(ctx, "dupKey", "first")) + require.NoError(t, c.Set(ctx, "dupKey", "second")) + + v, found, err := c.Get(ctx, "dupKey") + require.NoError(t, err) + assert.True(t, found) + assert.Equal(t, "second", v) +} + +// Test deleting non-existent key. +func TestDeleteNonExistent(t *testing.T) { + originalRegistry := prometheus.DefaultRegisterer + prometheus.DefaultRegisterer = prometheus.NewRegistry() + + t.Cleanup(func() { + prometheus.DefaultRegisterer = originalRegistry + }) + + ctx := t.Context() + c := makeCache(ctx, t, WithTTL(5*time.Second), WithMaxItems(10)) + defer c.Close(ctx) + + err := c.Delete(ctx, "ghost") + require.NoError(t, err) +} + +// Test clearing empty cache. +func TestClearEmpty(t *testing.T) { + originalRegistry := prometheus.DefaultRegisterer + prometheus.DefaultRegisterer = prometheus.NewRegistry() + + t.Cleanup(func() { + prometheus.DefaultRegisterer = originalRegistry + }) + + ctx := t.Context() + c := makeCache(ctx, t, WithTTL(5*time.Second), WithMaxItems(10)) + defer c.Close(ctx) + + err := c.Clear(ctx) + require.NoError(t, err) +} + +// Test concurrent Set/Get/Exists. +func TestConcurrentAccess(t *testing.T) { + originalRegistry := prometheus.DefaultRegisterer + prometheus.DefaultRegisterer = prometheus.NewRegistry() + + t.Cleanup(func() { + prometheus.DefaultRegisterer = originalRegistry + }) + + ctx := t.Context() + c := makeCache(ctx, t, WithTTL(5*time.Second), WithMaxItems(10)) + defer c.Close(ctx) + + var wg sync.WaitGroup + + f := func() { + defer wg.Done() + require.NoError(t, c.Set(ctx, "concurrent", "safe")) + _, _, err := c.Get(ctx, "concurrent") + require.NoError(t, err) + _, err = c.Exists(ctx, "concurrent") + require.NoError(t, err) + } + + for i := 0; i < 100; i++ { + wg.Add(1) + + go f() + } + + wg.Wait() +} + +// Test cleanup removes expired before eviction. 
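+// "a" has already expired when "c" is inserted, so the expired-entry sweep in
+// Set (not LRU eviction) frees the slot and both "b" and "c" survive.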
+func TestEvictionEdgeCase(t *testing.T) { + originalRegistry := prometheus.DefaultRegisterer + prometheus.DefaultRegisterer = prometheus.NewRegistry() + + t.Cleanup(func() { + prometheus.DefaultRegisterer = originalRegistry + }) + + ctx := t.Context() + c := makeCache(ctx, t, WithTTL(100*time.Millisecond), WithMaxItems(2)) + defer c.Close(ctx) + + require.NoError(t, c.Set(ctx, "a", 1)) + time.Sleep(110 * time.Millisecond) + require.NoError(t, c.Set(ctx, "b", 2)) + require.NoError(t, c.Set(ctx, "c", 3)) + + existsB, err := c.Exists(ctx, "b") + require.NoError(t, err) + assert.True(t, existsB) + + existsC, err := c.Exists(ctx, "c") + require.NoError(t, err) + assert.True(t, existsC) +} + +// Test default configuration values. +func TestDefaultConfiguration(t *testing.T) { + originalRegistry := prometheus.DefaultRegisterer + prometheus.DefaultRegisterer = prometheus.NewRegistry() + + t.Cleanup(func() { + prometheus.DefaultRegisterer = originalRegistry + }) + + ctx := t.Context() + + ci, err := NewInMemoryCache(ctx) + require.NoError(t, err) + + c := ci.(*inMemoryCache) + defer c.Close(ctx) + + assert.Equal(t, time.Minute, c.ttl) + assert.Equal(t, int(0), c.maxItems) + + require.NoError(t, c.Set(ctx, "test", "value")) + v, found, err := c.Get(ctx, "test") + require.NoError(t, err) + assert.True(t, found) + assert.Equal(t, "value", v) +} + +// Last option wins. +func TestMultipleOptions(t *testing.T) { + originalRegistry := prometheus.DefaultRegisterer + prometheus.DefaultRegisterer = prometheus.NewRegistry() + + t.Cleanup(func() { + prometheus.DefaultRegisterer = originalRegistry + }) + + ctx := t.Context() + c := makeCache(ctx, t, + WithTTL(30*time.Second), + WithMaxItems(5), + WithTTL(60*time.Second), + ) + defer c.Close(ctx) + + assert.Equal(t, 60*time.Second, c.ttl) + assert.Equal(t, 5, c.maxItems) +} + +// TTL=0 should expire immediately. +func TestZeroTTL(t *testing.T) { + originalRegistry := prometheus.DefaultRegisterer + prometheus.DefaultRegisterer = prometheus.NewRegistry() + + t.Cleanup(func() { + prometheus.DefaultRegisterer = originalRegistry + }) + + ctx := t.Context() + c := makeCache(ctx, t, WithTTL(0)) + defer c.Close(ctx) + + require.NoError(t, c.Set(ctx, "immediate", "expire")) + _, found, _ := c.Get(ctx, "immediate") + assert.False(t, found) +} + +// TTL<0 should expire immediately. +func TestNegativeTTL(t *testing.T) { + originalRegistry := prometheus.DefaultRegisterer + prometheus.DefaultRegisterer = prometheus.NewRegistry() + + t.Cleanup(func() { + prometheus.DefaultRegisterer = originalRegistry + }) + + ctx := t.Context() + c := makeCache(ctx, t, WithTTL(-time.Second)) + defer c.Close(ctx) + + require.NoError(t, c.Set(ctx, "neg", "ttl")) + _, found, _ := c.Get(ctx, "neg") + assert.False(t, found) +} + +// maxItems=0 means unlimited. +func TestUnlimitedCapacity(t *testing.T) { + originalRegistry := prometheus.DefaultRegisterer + prometheus.DefaultRegisterer = prometheus.NewRegistry() + + t.Cleanup(func() { + prometheus.DefaultRegisterer = originalRegistry + }) + + ctx := t.Context() + c := makeCache(ctx, t, WithTTL(time.Minute), WithMaxItems(0)) + defer c.Close(ctx) + + for i := 0; i < 500; i++ { + require.NoError(t, c.Set(ctx, string(rune(i)), i)) + } + + count := 0 + + for i := 0; i < 500; i++ { + exist, _ := c.Exists(ctx, string(rune(i))) + if exist { + count++ + } + } + + assert.Equal(t, 500, count) +} + +// maxItems=1 should only allow one item. 
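+// Exactly one of the two keys can remain, so the two Exists results must differ.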
+func TestSingleItemCapacity(t *testing.T) { + originalRegistry := prometheus.DefaultRegisterer + prometheus.DefaultRegisterer = prometheus.NewRegistry() + + t.Cleanup(func() { + prometheus.DefaultRegisterer = originalRegistry + }) + + ctx := t.Context() + c := makeCache(ctx, t, WithTTL(time.Minute), WithMaxItems(1)) + defer c.Close(ctx) + + require.NoError(t, c.Set(ctx, "first", 1)) + require.NoError(t, c.Set(ctx, "second", 2)) + + exist1, _ := c.Exists(ctx, "first") + exist2, _ := c.Exists(ctx, "second") + + assert.NotEqual(t, exist1, exist2) +} + +// Test LRU eviction order. +func TestLRUEvictionOrder(t *testing.T) { + originalRegistry := prometheus.DefaultRegisterer + prometheus.DefaultRegisterer = prometheus.NewRegistry() + + t.Cleanup(func() { + prometheus.DefaultRegisterer = originalRegistry + }) + + ctx := t.Context() + c := makeCache(ctx, t, WithTTL(time.Minute), WithMaxItems(3)) + defer c.Close(ctx) + + require.NoError(t, c.Set(ctx, "a", 1)) + require.NoError(t, c.Set(ctx, "b", 2)) + require.NoError(t, c.Set(ctx, "c", 3)) + _, _, err := c.Get(ctx, "a") + require.NoError(t, err) + require.NoError(t, c.Set(ctx, "d", 4)) + + existB, _ := c.Exists(ctx, "b") + assert.False(t, existB) +} + +// Updating key should refresh its usage. +func TestUpdateExistingKeyTiming(t *testing.T) { + originalRegistry := prometheus.DefaultRegisterer + prometheus.DefaultRegisterer = prometheus.NewRegistry() + + t.Cleanup(func() { + prometheus.DefaultRegisterer = originalRegistry + }) + + ctx := t.Context() + c := makeCache(ctx, t, WithTTL(time.Minute), WithMaxItems(2)) + defer c.Close(ctx) + + require.NoError(t, c.Set(ctx, "old", 1)) + require.NoError(t, c.Set(ctx, "new", 2)) + require.NoError(t, c.Set(ctx, "old", 10)) + require.NoError(t, c.Set(ctx, "third", 3)) + + existNew, _ := c.Exists(ctx, "new") + assert.False(t, existNew) +} + +// Support for multiple Go types. +func TestDifferentValueTypes(t *testing.T) { + originalRegistry := prometheus.DefaultRegisterer + prometheus.DefaultRegisterer = prometheus.NewRegistry() + + t.Cleanup(func() { + prometheus.DefaultRegisterer = originalRegistry + }) + + ctx := t.Context() + c := makeCache(ctx, t, WithTTL(time.Minute)) + defer c.Close(ctx) + + values := map[string]any{ + "str": "hello", + "int": 42, + "flt": 3.14, + "bool": true, + "slice": []int{1, 2, 3}, + "map": map[string]int{"k": 123}, + } + + err := c.Set(ctx, "nilval", nil) + require.Error(t, err, "Expected an error when setting a nil value") + + for k, val := range values { + require.NoError(t, c.Set(ctx, k, val)) + } + + for k, expected := range values { + v, found, _ := c.Get(ctx, k) + assert.True(t, found) + assert.Equal(t, expected, v) + } + + _, found, _ := c.Get(ctx, "nilval") + assert.False(t, found) +} + +// Using empty key should error. +func TestEmptyStringKey(t *testing.T) { + originalRegistry := prometheus.DefaultRegisterer + prometheus.DefaultRegisterer = prometheus.NewRegistry() + + t.Cleanup(func() { + prometheus.DefaultRegisterer = originalRegistry + }) + + ctx := t.Context() + c := makeCache(ctx, t, WithTTL(time.Minute)) + defer c.Close(ctx) + + err := c.Set(ctx, "", "v") + require.ErrorIs(t, err, ErrEmptyKey) +} + +// Long keys are supported. 
+func TestLongKey(t *testing.T) {
+	originalRegistry := prometheus.DefaultRegisterer
+	prometheus.DefaultRegisterer = prometheus.NewRegistry()
+
+	t.Cleanup(func() {
+		prometheus.DefaultRegisterer = originalRegistry
+	})
+
+	ctx := t.Context()
+	c := makeCache(ctx, t, WithTTL(time.Minute))
+	defer c.Close(ctx)
+
+	longKey := string(make([]byte, 10000))
+	err := c.Set(ctx, longKey, "v")
+	require.NoError(t, err)
+
+	v, found, _ := c.Get(ctx, longKey)
+	assert.True(t, found)
+	assert.Equal(t, "v", v)
+}
+
+// Concurrent Set/Get/Exists while the cache is at capacity.
+func TestConcurrentEviction(t *testing.T) {
+	originalRegistry := prometheus.DefaultRegisterer
+	prometheus.DefaultRegisterer = prometheus.NewRegistry()
+
+	t.Cleanup(func() {
+		prometheus.DefaultRegisterer = originalRegistry
+	})
+
+	ctx := t.Context()
+	c := makeCache(ctx, t, WithTTL(time.Minute), WithMaxItems(10))
+	defer c.Close(ctx)
+
+	for i := 0; i < 10; i++ {
+		require.NoError(t, c.Set(ctx, string(rune(i)), i))
+	}
+
+	var wg sync.WaitGroup
+
+	// assert, not require: these run on non-test goroutines.
+	f := func(val int) {
+		defer wg.Done()
+		assert.NoError(t, c.Set(ctx, string(rune(val+100)), val))
+		_, _, err := c.Get(ctx, string(rune(val%10)))
+		assert.NoError(t, err)
+		_, err = c.Exists(ctx, string(rune(val%10)))
+		assert.NoError(t, err)
+	}
+
+	for i := 0; i < 50; i++ {
+		wg.Add(1)
+
+		go f(i)
+	}
+
+	wg.Wait()
+
+	count := 0
+
+	for i := 0; i < 200; i++ {
+		if exist, _ := c.Exists(ctx, string(rune(i))); exist {
+			count++
+		}
+	}
+
+	assert.LessOrEqual(t, count, 10)
+}
+
+// Cleanup goroutine should stop after Close.
+func TestCleanupGoroutineStops(t *testing.T) {
+	originalRegistry := prometheus.DefaultRegisterer
+	prometheus.DefaultRegisterer = prometheus.NewRegistry()
+
+	t.Cleanup(func() {
+		prometheus.DefaultRegisterer = originalRegistry
+	})
+
+	ctx := t.Context()
+	before := runtime.NumGoroutine()
+	c := makeCache(ctx, t, WithTTL(time.Millisecond))
+	time.Sleep(10 * time.Millisecond)
+	c.Close(ctx)
+	time.Sleep(50 * time.Millisecond)
+
+	after := runtime.NumGoroutine()
+	assert.LessOrEqual(t, after, before)
+}
+
+// Calling Close multiple times is safe.
+func TestMultipleClose(t *testing.T) {
+	originalRegistry := prometheus.DefaultRegisterer
+	prometheus.DefaultRegisterer = prometheus.NewRegistry()
+
+	t.Cleanup(func() {
+		prometheus.DefaultRegisterer = originalRegistry
+	})
+
+	ctx := t.Context()
+	c := makeCache(ctx, t, WithTTL(time.Minute))
+	c.Close(ctx)
+	c.Close(ctx)
+	c.Close(ctx)
+}
+
+// Set/Get still work after Close.
+func TestOperationsAfterClose(t *testing.T) {
+	originalRegistry := prometheus.DefaultRegisterer
+	prometheus.DefaultRegisterer = prometheus.NewRegistry()
+
+	t.Cleanup(func() {
+		prometheus.DefaultRegisterer = originalRegistry
+	})
+
+	ctx := t.Context()
+	c := makeCache(ctx, t, WithTTL(time.Minute))
+	require.NoError(t, c.Set(ctx, "pre", "close"))
+	c.Close(ctx)
+	require.NoError(t, c.Set(ctx, "post", "close"))
+
+	v1, found1, _ := c.Get(ctx, "pre")
+	assert.True(t, found1)
+	assert.Equal(t, "close", v1)
+
+	v2, found2, _ := c.Get(ctx, "post")
+	assert.True(t, found2)
+	assert.Equal(t, "close", v2)
+}
+
+// Expired items should be cleaned in background.
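+// With a 100ms TTL, the 250ms sleep below leaves time for the background
+// cleanup goroutine to run at least once before the assertion.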
+func TestCleanupFrequency(t *testing.T) { + originalRegistry := prometheus.DefaultRegisterer + prometheus.DefaultRegisterer = prometheus.NewRegistry() + + t.Cleanup(func() { + prometheus.DefaultRegisterer = originalRegistry + }) + + ctx := t.Context() + c := makeCache(ctx, t, WithTTL(100*time.Millisecond)) + defer c.Close(ctx) + + require.NoError(t, c.Set(ctx, "expire_me", "v")) + time.Sleep(250 * time.Millisecond) + + exists, err := c.Exists(ctx, "expire_me") + require.NoError(t, err) + assert.False(t, exists) +} + +// Exists should clean expired keys. +func TestExistsWithExpiredItems(t *testing.T) { + originalRegistry := prometheus.DefaultRegisterer + prometheus.DefaultRegisterer = prometheus.NewRegistry() + + t.Cleanup(func() { + prometheus.DefaultRegisterer = originalRegistry + }) + + ctx := t.Context() + c := makeCache(ctx, t, WithTTL(50*time.Millisecond)) + defer c.Close(ctx) + + require.NoError(t, c.Set(ctx, "short", "v")) + + exists, err := c.Exists(ctx, "short") + require.NoError(t, err) + assert.True(t, exists) + + time.Sleep(60 * time.Millisecond) + + exists, err = c.Exists(ctx, "short") + require.NoError(t, err) + assert.False(t, exists) +} + +// Cleanup should free space for new items. +func TestPartialEvictionWithExpiredItems(t *testing.T) { + originalRegistry := prometheus.DefaultRegisterer + prometheus.DefaultRegisterer = prometheus.NewRegistry() + + t.Cleanup(func() { + prometheus.DefaultRegisterer = originalRegistry + }) + + ctx := t.Context() + c := makeCache(ctx, t, WithTTL(100*time.Millisecond), WithMaxItems(3)) + defer c.Close(ctx) + + require.NoError(t, c.Set(ctx, "a", 1)) + require.NoError(t, c.Set(ctx, "b", 2)) + time.Sleep(110 * time.Millisecond) + require.NoError(t, c.Set(ctx, "c", 3)) + require.NoError(t, c.Set(ctx, "d", 4)) + require.NoError(t, c.Set(ctx, "e", 5)) + + for _, k := range []string{"c", "d", "e"} { + exists, err := c.Exists(ctx, k) + require.NoError(t, err) + assert.True(t, exists) + } +} + +// Get should update lastUsed for LRU. +func TestGetUpdatesLastUsed(t *testing.T) { + originalRegistry := prometheus.DefaultRegisterer + prometheus.DefaultRegisterer = prometheus.NewRegistry() + + t.Cleanup(func() { + prometheus.DefaultRegisterer = originalRegistry + }) + + ctx := t.Context() + c := makeCache(ctx, t, WithTTL(time.Minute), WithMaxItems(2)) + defer c.Close(ctx) + + require.NoError(t, c.Set(ctx, "x", 1)) + require.NoError(t, c.Set(ctx, "y", 2)) + _, _, err := c.Get(ctx, "x") + require.NoError(t, err) + require.NoError(t, c.Set(ctx, "z", 3)) + + exists, err := c.Exists(ctx, "y") + require.NoError(t, err) + assert.False(t, exists) +} + +// Stress test with mixed operations. 
+func TestHighVolumeOperations(t *testing.T) {
+	originalRegistry := prometheus.DefaultRegisterer
+	prometheus.DefaultRegisterer = prometheus.NewRegistry()
+
+	t.Cleanup(func() {
+		prometheus.DefaultRegisterer = originalRegistry
+	})
+
+	ctx := t.Context()
+	c := makeCache(ctx, t, WithTTL(time.Minute), WithMaxItems(1000))
+	defer c.Close(ctx)
+
+	var wg sync.WaitGroup
+
+	ops := 1000
+
+	// assert, not require: these run on non-test goroutines.
+	f := func(id int) {
+		defer wg.Done()
+
+		key := string(rune(id % 100))
+		val := id
+
+		assert.NoError(t, c.Set(ctx, key, val))
+		_, _, err := c.Get(ctx, key)
+		assert.NoError(t, err)
+		_, err = c.Exists(ctx, key)
+		assert.NoError(t, err)
+
+		if id%10 == 0 {
+			assert.NoError(t, c.Delete(ctx, key))
+		}
+	}
+
+	for i := 0; i < ops; i++ {
+		wg.Add(1)
+
+		go f(i)
+	}
+
+	wg.Wait()
+
+	cnt := 0
+
+	for i := 0; i < 100; i++ {
+		if exist, _ := c.Exists(ctx, string(rune(i))); exist {
+			cnt++
+		}
+	}
+
+	assert.LessOrEqual(t, cnt, 100)
+}
diff --git a/pkg/cache/monitoring/README.md b/pkg/cache/monitoring/README.md
new file mode 100644
index 000000000..dfb36a038
--- /dev/null
+++ b/pkg/cache/monitoring/README.md
@@ -0,0 +1,30 @@
+# Pre-configured Monitoring for Gofr Cache
+
+This directory contains a pre-configured Docker Compose setup to run Prometheus and Grafana for monitoring your application's cache metrics out-of-the-box.
+
+## Features
+
+- **Prometheus**: Pre-configured to scrape metrics from your Gofr application.
+- **Grafana**: Pre-provisioned with a Prometheus data source and a dashboard for cache metrics.
+
+## Quick Start
+
+### 1. Prerequisites
+- Docker and Docker Compose are installed.
+- Your Gofr application is running and exposes Prometheus metrics on `/metrics` (this is the default for Gofr apps).
+
+### 2. Run the Monitoring Stack
+
+Navigate to this directory and run:
+```sh
+docker-compose up --build
+```
+- **Grafana**: Will be available at [http://localhost:3000](http://localhost:3000) (user: `admin`, pass: `admin`)
+- **Prometheus**: Will be available at [http://localhost:9090](http://localhost:9090)
+- **App Metrics**: Your app should expose metrics at a port, e.g., [http://localhost:8080/metrics](http://localhost:8080/metrics)
+
+### How it Works
+
+The `prometheus.yml` is configured to scrape metrics from `host.docker.internal:8080`. `host.docker.internal` is a special DNS name that resolves to the host machine's IP address from within a Docker container. If your application runs on a different port, you can modify `prometheus.yml` (see the example below).
+
+The Grafana service is provisioned with the Prometheus data source and a dashboard defined in `provisioning/`.
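+
+For example, if your application served its metrics on port 9000 instead (the port here is only an assumption for illustration), the scrape target would change like this:
+
+```yaml
+# prometheus.yml (sketch)
+global:
+  scrape_interval: 5s
+scrape_configs:
+  - job_name: 'gofr-cache'
+    static_configs:
+      - targets: ['host.docker.internal:9000']
+```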
\ No newline at end of file diff --git a/pkg/cache/monitoring/docker-compose.yml b/pkg/cache/monitoring/docker-compose.yml new file mode 100644 index 000000000..760bbd47a --- /dev/null +++ b/pkg/cache/monitoring/docker-compose.yml @@ -0,0 +1,23 @@ +services: + prometheus: + image: prom/prometheus:latest + container_name: prometheus + volumes: + - ./prometheus.yml:/etc/prometheus/prometheus.yml + ports: + - "9090:9090" + restart: unless-stopped + + grafana: + image: grafana/grafana:latest + container_name: grafana + ports: + - "3000:3000" + volumes: + - ./provisioning/datasources:/etc/grafana/provisioning/datasources + - ./provisioning/dashboards:/etc/grafana/provisioning/dashboards + environment: + - GF_SECURITY_ADMIN_PASSWORD=admin + depends_on: + - prometheus + restart: unless-stopped \ No newline at end of file diff --git a/pkg/cache/monitoring/prometheus.yml b/pkg/cache/monitoring/prometheus.yml new file mode 100644 index 000000000..4884d3917 --- /dev/null +++ b/pkg/cache/monitoring/prometheus.yml @@ -0,0 +1,6 @@ +global: + scrape_interval: 5s +scrape_configs: + - job_name: 'gofr-cache' + static_configs: + - targets: ['host.docker.internal:8080'] \ No newline at end of file diff --git a/pkg/cache/monitoring/provisioning/dashboards/cache-metrics.json b/pkg/cache/monitoring/provisioning/dashboards/cache-metrics.json new file mode 100644 index 000000000..8814b97d5 --- /dev/null +++ b/pkg/cache/monitoring/provisioning/dashboards/cache-metrics.json @@ -0,0 +1,60 @@ +{ + "dashboard": { + "id": null, + "title": "Gofr InMemory Cache Metrics", + "panels": [ + { + "type": "timeseries", + "title": "Cache Hits", + "targets": [ + {"expr": "gofr_inmemory_cache_hits_total", "legendFormat": "{{cache_name}}"} + ], + "gridPos": {"h": 8, "w": 12, "x": 0, "y": 0} + }, + { + "type": "timeseries", + "title": "Cache Misses", + "targets": [ + {"expr": "gofr_inmemory_cache_misses_total", "legendFormat": "{{cache_name}}"} + ], + "gridPos": {"h": 8, "w": 12, "x": 12, "y": 0} + }, + { + "type": "timeseries", + "title": "Cache Sets", + "targets": [ + {"expr": "gofr_inmemory_cache_sets_total", "legendFormat": "{{cache_name}}"} + ], + "gridPos": {"h": 8, "w": 12, "x": 0, "y": 8} + }, + { + "type": "timeseries", + "title": "Cache Deletes", + "targets": [ + {"expr": "gofr_inmemory_cache_deletes_total", "legendFormat": "{{cache_name}}"} + ], + "gridPos": {"h": 8, "w": 12, "x": 12, "y": 8} + }, + { + "type": "timeseries", + "title": "Cache Evictions", + "targets": [ + {"expr": "gofr_inmemory_cache_evictions_total", "legendFormat": "{{cache_name}}"} + ], + "gridPos": {"h": 8, "w": 12, "x": 0, "y": 16} + }, + { + "type": "timeseries", + "title": "Current Item Count", + "targets": [ + {"expr": "gofr_inmemory_cache_items_current", "legendFormat": "{{cache_name}}"} + ], + "gridPos": {"h": 8, "w": 12, "x": 12, "y": 16} + } + ], + "schemaVersion": 36, + "version": 1, + "refresh": "5s" + }, + "overwrite": true + } \ No newline at end of file diff --git a/pkg/cache/monitoring/provisioning/dashboards/dashboard.yaml b/pkg/cache/monitoring/provisioning/dashboards/dashboard.yaml new file mode 100644 index 000000000..34315b245 --- /dev/null +++ b/pkg/cache/monitoring/provisioning/dashboards/dashboard.yaml @@ -0,0 +1,13 @@ +apiVersion: 1 +providers: + - name: 'default' + orgId: 1 + folder: '' + type: file + disableDeletion: false + editable: true + options: + path: /etc/grafana/provisioning/dashboards + foldersFromFilesStructure: false + dashboards: + - file: /etc/grafana/provisioning/dashboards/cache-metrics.json \ No newline at end 
of file
diff --git a/pkg/cache/monitoring/provisioning/datasources/prometheus.yaml b/pkg/cache/monitoring/provisioning/datasources/prometheus.yaml
new file mode 100644
index 000000000..8049912b1
--- /dev/null
+++ b/pkg/cache/monitoring/provisioning/datasources/prometheus.yaml
@@ -0,0 +1,8 @@
+apiVersion: 1
+
+datasources:
+  - name: Prometheus
+    type: prometheus
+    access: proxy
+    url: http://prometheus:9090
+    isDefault: true
\ No newline at end of file
diff --git a/pkg/cache/observability/logger.go b/pkg/cache/observability/logger.go
new file mode 100644
index 000000000..b7a4e23c2
--- /dev/null
+++ b/pkg/cache/observability/logger.go
@@ -0,0 +1,236 @@
+package observability
+
+import (
+	"context"
+	"fmt"
+	"os"
+	"regexp"
+	"strings"
+	"time"
+
+	"go.opentelemetry.io/otel/trace"
+)
+
+// ANSI Color Codes.
+const (
+	reset  = "\033[0m"
+	red    = "\033[31m"
+	green  = "\033[32m"
+	yellow = "\033[33m"
+	blue   = "\033[34m"
+	cyan   = "\033[36m"
+	gray   = "\033[90m"
+)
+
+const (
+	infoLevel  = "INFO"
+	warnLevel  = "WARN"
+	errorLevel = "ERROR"
+	debugLevel = "DEBUG"
+)
+
+// ansiRegex matches ANSI escape sequences so they can be stripped from styled text.
+var ansiRegex = regexp.MustCompile("[\u001B\u009B][[\\]()#;?]*.{0,2}(?:(?:;\\d{1,3})*.[a-zA-Z\\d]|(?:\\d{1,4}/?)*[a-zA-Z])")
+
+// Logger defines a standard interface for logging with built-in context awareness.
+type Logger interface {
+	// Standard logging methods with context support
+	Errorf(ctx context.Context, format string, args ...any)
+	Warnf(ctx context.Context, format string, args ...any)
+	Infof(ctx context.Context, format string, args ...any)
+	Debugf(ctx context.Context, format string, args ...any)
+
+	// Cache-specific logging methods
+	Hitf(ctx context.Context, message string, duration time.Duration, operation string)
+	Missf(ctx context.Context, message string, duration time.Duration, operation string)
+
+	// Generic structured logging method
+	LogRequest(ctx context.Context, level, message string, tag any, duration time.Duration, operation string)
+}
+
+type nopLogger struct{}
+
+// NewNopLogger returns a Logger that discards all output; useful in tests.
+func NewNopLogger() Logger { return &nopLogger{} }
+
+func (*nopLogger) Errorf(_ context.Context, _ string, _ ...any)                                 {}
+func (*nopLogger) Warnf(_ context.Context, _ string, _ ...any)                                  {}
+func (*nopLogger) Infof(_ context.Context, _ string, _ ...any)                                  {}
+func (*nopLogger) Debugf(_ context.Context, _ string, _ ...any)                                 {}
+func (*nopLogger) Hitf(_ context.Context, _ string, _ time.Duration, _ string)                  {}
+func (*nopLogger) Missf(_ context.Context, _ string, _ time.Duration, _ string)                 {}
+func (*nopLogger) LogRequest(_ context.Context, _, _ string, _ any, _ time.Duration, _ string)  {}
+
+type styledLogger struct {
+	useColors bool
+}
+
+// NewStdLogger returns a Logger that writes styled output to stdout, using
+// colors only when attached to a terminal.
+func NewStdLogger() Logger {
+	return &styledLogger{
+		useColors: isTerminal(),
+	}
+}
+
+func (l *styledLogger) getTraceString(ctx context.Context) string {
+	if ctx == nil {
+		return ""
+	}
+
+	sc := trace.SpanFromContext(ctx).SpanContext()
+	if sc.IsValid() {
+		return " " + l.applyColor(gray, sc.TraceID().String())
+	}
+
+	return ""
+}
+
+func (l *styledLogger) Errorf(ctx context.Context, format string, args ...any) {
+	l.logSimple(ctx, errorLevel, red, format, args...)
+}
+
+func (l *styledLogger) Warnf(ctx context.Context, format string, args ...any) {
+	l.logSimple(ctx, warnLevel, yellow, format, args...)
+}
+
+func (l *styledLogger) Infof(ctx context.Context, format string, args ...any) {
+	l.logSimple(ctx, infoLevel, green, format, args...)
+}
+
+func (l *styledLogger) Debugf(ctx context.Context, format string, args ...any) {
+	l.logSimple(ctx, debugLevel, gray, format, args...)
+}
+
+func (l *styledLogger) Hitf(ctx context.Context, _ string, duration time.Duration, operation string) {
+	l.LogRequest(ctx, infoLevel, "Cache hit", "HIT", duration, operation)
+}
+
+func (l *styledLogger) Missf(ctx context.Context, _ string, duration time.Duration, operation string) {
+	// A miss isn't an error, but we'll color its tag yellow for attention.
+	l.LogRequest(ctx, infoLevel, "Cache miss", "MISS", duration, operation)
+}
+
+func (l *styledLogger) LogRequest(ctx context.Context, level, message string, tag any, duration time.Duration, operation string) {
+	const tagColumnStart = 45
+
+	const durationColumnStart = 60
+
+	levelStr, levelColor := getLevelStyle(level)
+	ts := l.applyColor(gray, "["+time.Now().Format(time.TimeOnly)+"]")
+	traceStr := l.getTraceString(ctx)
+	initialPart := fmt.Sprintf("%s %s%s %s", l.applyColor(levelColor, levelStr), ts, traceStr, message)
+
+	tagStr := l.formatTag(tag)
+	durationStr := l.applyColor(gray, fmt.Sprintf("%dµs", duration.Microseconds()))
+
+	padding1 := getPadding(tagColumnStart, len(stripAnsi(initialPart)))
+	padding2 := getPadding(durationColumnStart, tagColumnStart+len(stripAnsi(tagStr)))
+
+	fmt.Printf("%s%s%s%s%s %s\n",
+		initialPart,
+		padding1,
+		tagStr,
+		padding2,
+		durationStr,
+		operation,
+	)
+}
+
+func (l *styledLogger) logSimple(ctx context.Context, level, color, format string, args ...any) {
+	levelStr := l.applyColor(color, level)
+	ts := l.applyColor(gray, "["+time.Now().Format(time.TimeOnly)+"]")
+	traceStr := l.getTraceString(ctx)
+	msg := fmt.Sprintf(format, args...)
+	fmt.Printf("%s %s%s %s\n", levelStr, ts, traceStr, msg)
+}
+
+func getLevelStyle(level string) (levelStr, color string) {
+	switch level {
+	case errorLevel:
+		return "ERROR", red
+	case warnLevel:
+		return "WARN", yellow
+	case infoLevel:
+		return "INFO", green
+	case debugLevel:
+		return "DEBUG", gray
+	default:
+		return level, reset
+	}
+}
+
+func (l *styledLogger) formatTag(tag any) string {
+	switch t := tag.(type) {
+	case int:
+		return l.formatIntTag(t)
+	case string:
+		return l.formatStringTag(t)
+	default:
+		return fmt.Sprint(tag)
+	}
+}
+
+const (
+	statusOKRangeStart          = 200
+	statusOKRangeEnd            = 300
+	statusClientErrorRangeStart = 400
+	statusClientErrorRangeEnd   = 500
+	statusServerErrorRangeStart = 500
+)
+
+func (l *styledLogger) formatIntTag(t int) string {
+	var color string
+	if t >= statusOKRangeStart && t < statusOKRangeEnd {
+		color = green
+	} else if t >= statusClientErrorRangeStart && t < statusClientErrorRangeEnd {
+		color = yellow
+	} else if t >= statusServerErrorRangeStart {
+		color = red
+	} else {
+		color = gray
+	}
+
+	return l.applyColor(color, fmt.Sprintf("%d", t))
+}
+
+func (l *styledLogger) formatStringTag(t string) string {
+	var color string
+
+	switch t {
+	case "HIT":
+		color = green
+	case "MISS":
+		color = yellow
+	case "REDIS":
+		color = blue
+	case "SQL":
+		color = cyan
+	default:
+		color = gray
+	}
+
+	return l.applyColor(color, t)
+}
+
+func getPadding(columnTarget, currentLength int) string {
+	if padLen := columnTarget - currentLength; padLen > 0 {
+		return strings.Repeat(" ", padLen)
+	}
+
+	return " "
+}
+
+func (l *styledLogger) applyColor(color, text string) string {
+	if !l.useColors {
+		return text
+	}
+
+	return color + text + reset
+}
+
+func isTerminal() bool {
+	fi, err := os.Stdout.Stat()
+	if err != nil {
+		return false
+	}
+
+	return (fi.Mode() & os.ModeCharDevice) != 0
+}
+
+// stripAnsi removes ANSI escape codes from a string.
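+// It is used to measure the visible width of styled segments when padding
+// log columns.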
+func stripAnsi(str string) string {
+	return ansiRegex.ReplaceAllString(str, "")
+}
diff --git a/pkg/cache/observability/metrics.go b/pkg/cache/observability/metrics.go
new file mode 100644
index 000000000..61091a7d2
--- /dev/null
+++ b/pkg/cache/observability/metrics.go
@@ -0,0 +1,131 @@
+package observability
+
+import (
+	"sync"
+
+	"github.com/prometheus/client_golang/prometheus"
+	"github.com/prometheus/client_golang/prometheus/promauto"
+)
+
+// Metrics encapsulates a set of Prometheus metrics for monitoring a cache.
+// It includes counters for hits, misses, sets, deletes, and evictions,
+// a gauge for the current number of items, and a histogram for operation latency.
+// All metrics are labeled with 'cache_name'.
+type Metrics struct {
+	hits    *prometheus.CounterVec
+	misses  *prometheus.CounterVec
+	sets    *prometheus.CounterVec
+	deletes *prometheus.CounterVec
+	evicts  *prometheus.CounterVec
+	items   *prometheus.GaugeVec
+	latency *prometheus.HistogramVec
+}
+
+type metricsRegistry struct {
+	mtx        sync.Mutex
+	singletons map[string]*Metrics
+}
+
+// registryOnce guards lazy initialization of the package-level registry so
+// that every caller shares the same instance.
+var (
+	registryOnce sync.Once
+	registry     *metricsRegistry
+)
+
+// getRegistry returns the singleton metricsRegistry instance.
+func getRegistry() *metricsRegistry {
+	registryOnce.Do(func() {
+		registry = &metricsRegistry{
+			singletons: make(map[string]*Metrics),
+		}
+	})
+
+	return registry
+}
+
+// NewMetrics creates or retrieves a singleton Metrics instance for a given namespace and subsystem.
+// This ensures that metrics are registered with Prometheus only once per application lifecycle.
+func NewMetrics(namespace, subsystem string) *Metrics {
+	key := namespace + "/" + subsystem
+	reg := getRegistry()
+
+	reg.mtx.Lock()
+	defer reg.mtx.Unlock()
+
+	if m, ok := reg.singletons[key]; ok {
+		return m
+	}
+
+	// Create and register exactly once.
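+	// promauto panics if a collector with the same fully-qualified name is
+	// already registered, so the per-key memoization above must happen first.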
+ factory := promauto.With(prometheus.DefaultRegisterer) + m := &Metrics{ + hits: factory.NewCounterVec( + prometheus.CounterOpts{Namespace: namespace, Subsystem: subsystem, Name: "hits_total", Help: "Total number of cache hits."}, + []string{"cache_name"}, + ), + misses: factory.NewCounterVec( + prometheus.CounterOpts{Namespace: namespace, Subsystem: subsystem, Name: "misses_total", Help: "Total number of cache misses."}, + []string{"cache_name"}, + ), + sets: factory.NewCounterVec( + prometheus.CounterOpts{Namespace: namespace, Subsystem: subsystem, Name: "sets_total", Help: "Total number of set operations."}, + []string{"cache_name"}, + ), + deletes: factory.NewCounterVec( + prometheus.CounterOpts{Namespace: namespace, Subsystem: subsystem, Name: "deletes_total", Help: "Total number of delete operations."}, + []string{"cache_name"}, + ), + evicts: factory.NewCounterVec( + prometheus.CounterOpts{ + Namespace: namespace, + Subsystem: subsystem, + Name: "evictions_total", + Help: "Total number of items evicted from the cache.", + }, + []string{"cache_name"}, + ), + items: factory.NewGaugeVec( + prometheus.GaugeOpts{Namespace: namespace, Subsystem: subsystem, Name: "items_current", Help: "Current number of items in the cache."}, + []string{"cache_name"}, + ), + latency: factory.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: namespace, + Subsystem: subsystem, + Name: "operation_latency_seconds", + Help: "Latency of cache operations in seconds.", + Buckets: prometheus.DefBuckets, + }, + []string{"cache_name", "operation"}, + ), + } + + reg.singletons[key] = m + + return m +} + +// Hits returns the counter for cache hits. +func (m *Metrics) Hits() *prometheus.CounterVec { return m.hits } + +// Misses returns the counter for cache misses. +func (m *Metrics) Misses() *prometheus.CounterVec { return m.misses } + +// Sets returns the counter for set operations. +func (m *Metrics) Sets() *prometheus.CounterVec { return m.sets } + +// Deletes returns the counter for delete operations. +func (m *Metrics) Deletes() *prometheus.CounterVec { return m.deletes } + +// Evicts returns the counter for cache evictions. +func (m *Metrics) Evicts() *prometheus.CounterVec { return m.evicts } + +// Items returns the gauge for the current number of items in the cache. +func (m *Metrics) Items() *prometheus.GaugeVec { return m.items } + +// Latency returns the histogram for cache operation latencies. +func (m *Metrics) Latency() *prometheus.HistogramVec { return m.latency } diff --git a/pkg/cache/redis/redis.go b/pkg/cache/redis/redis.go new file mode 100644 index 000000000..df45fc7ea --- /dev/null +++ b/pkg/cache/redis/redis.go @@ -0,0 +1,394 @@ +package redis + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "time" + + "github.com/redis/go-redis/v9" + "go.opentelemetry.io/otel/trace" + + "gofr.dev/pkg/cache" + "gofr.dev/pkg/cache/observability" +) + +// Common errors. +var ( + // ErrEmptyKey is returned when an operation is attempted with an empty key. + ErrEmptyKey = errors.New("key cannot be empty") + // ErrNilValue is returned when a nil value is provided to Set. + ErrNilValue = errors.New("value cannot be nil") + // ErrNilClient is returned when the Redis client is not initialized. + ErrNilClient = errors.New("redis client is nil") + // ErrAddressEmpty is returned when an empty address is provided. + ErrAddressEmpty = errors.New("address cannot be empty") + // ErrInvalidDatabaseNumber is returned when a database number outside the valid range is provided. 
+ ErrInvalidDatabaseNumber = errors.New("database number must be between 0 and 255") + // ErrNegativeTTL is returned when a negative TTL is provided. + ErrNegativeTTL = errors.New("TTL cannot be negative") +) + +type redisCache struct { + client *redis.Client + ttl time.Duration + name string + logger observability.Logger + metrics *observability.Metrics + tracer *trace.Tracer +} + +type Option func(*redisCache) error + +// WithAddr sets the network address of the Redis server (e.g., "localhost:6379"). +func WithAddr(addr string) Option { + return func(c *redisCache) error { + if addr == "" { + return ErrAddressEmpty + } + + opts := c.client.Options() + opts.Addr = addr + c.client = redis.NewClient(opts) + + return nil + } +} + +// WithPassword sets the password for authenticating with the Redis server. +func WithPassword(password string) Option { + return func(c *redisCache) error { + opts := c.client.Options() + opts.Password = password + c.client = redis.NewClient(opts) + + return nil + } +} + +// WithDB sets the Redis database number to use. +// The database number must be between 0 and 255. +func WithDB(db int) Option { + return func(c *redisCache) error { + if db < 0 || db > 255 { + return ErrInvalidDatabaseNumber + } + + opts := c.client.Options() + opts.DB = db + c.client = redis.NewClient(opts) + + return nil + } +} + +// WithTTL sets the default time-to-live (TTL) for all entries in the cache. +// Redis will automatically remove items after this duration. +// A TTL of zero means items will not expire. +func WithTTL(ttl time.Duration) Option { + return func(c *redisCache) error { + if ttl < 0 { + return ErrNegativeTTL + } + + c.ttl = ttl + + return nil + } +} + +// WithName sets a descriptive name for the cache instance. +// This name is used in logs and metrics to identify the cache. +func WithName(name string) Option { + return func(c *redisCache) error { + if name != "" { + c.name = name + } + + return nil + } +} + +// WithLogger provides a custom logger for the cache. +// If not provided, a default standard library logger is used. +func WithLogger(logger observability.Logger) Option { + return func(c *redisCache) error { + if logger != nil { + c.logger = logger + } + + return nil + } +} + +// WithMetrics provides a metrics collector for the cache. +// If provided, the cache will record metrics for its operations. +func WithMetrics(m *observability.Metrics) Option { + return func(c *redisCache) error { + if m != nil { + c.metrics = m + } + + return nil + } +} + +// NewRedisCache creates and returns a new Redis-backed cache instance. +// It establishes a connection to the Redis server and pings it to ensure connectivity. +// It takes zero or more Option functions to customize its configuration. +// By default, it connects to "localhost:6379" with a 1-minute TTL. 
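+//
+// A minimal sketch (the address and TTL here are illustrative, not required):
+//
+//	c, err := NewRedisCache(ctx, WithAddr("localhost:6379"), WithTTL(5*time.Minute))
+//	if err != nil {
+//		return err
+//	}
+//	defer c.Close(ctx)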
+func NewRedisCache(ctx context.Context, opts ...Option) (cache.Cache, error) { + // Default client connects to localhost:6379 + defaultClient := redis.NewClient(&redis.Options{}) + + c := &redisCache{ + client: defaultClient, + ttl: time.Minute, + name: "default-redis", + logger: observability.NewStdLogger(), + metrics: observability.NewMetrics("gofr", "redis_cache"), + } + + for _, opt := range opts { + if err := opt(c); err != nil { + return nil, fmt.Errorf("failed to configure redis cache: %w", err) + } + } + + // Verify the connection is alive + if err := c.client.Ping(ctx).Err(); err != nil { + return nil, fmt.Errorf("failed to connect to redis: %w", err) + } + + c.logger.Infof(ctx, "Redis cache '%s' initialized on %s, DB %d, TTL=%v", + c.name, c.client.Options().Addr, c.client.Options().DB, c.ttl) + + return c, nil +} + +func (c *redisCache) UseTracer(tracer trace.Tracer) { + c.tracer = &tracer +} + +// validateKey ensures key is non-empty. +func validateKey(key string) error { + if key == "" { + return ErrEmptyKey + } + + return nil +} + +// serializeValue converts a value to JSON for storage. +func (*redisCache) serializeValue(value any) (string, error) { + // Handle simple types directly to maintain readability in Redis + switch v := value.(type) { + case string: + return v, nil + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + return fmt.Sprintf("%d", v), nil + case float32, float64: + return fmt.Sprintf("%g", v), nil + case bool: + if v { + return "true", nil + } + + return "false", nil + default: + // For complex types, use JSON + bytes, err := json.Marshal(value) + if err != nil { + return "", fmt.Errorf("failed to serialize value: %w", err) + } + + return string(bytes), nil + } +} + +// Set adds or updates a key-value pair in the Redis cache with the default TTL. +// The value is serialized before being stored. Simple types are stored as strings, +// while complex types are JSON-marshaled. +// This operation is thread-safe. +func (c *redisCache) Set(ctx context.Context, key string, value any) error { + start := time.Now() + + if err := validateKey(key); err != nil { + c.logger.Errorf(ctx, "Set failed: %v", err) + return err + } + + if value == nil { + c.logger.Errorf(ctx, "Set failed: %v", ErrNilValue) + return ErrNilValue + } + + serializedValue, err := c.serializeValue(value) + if err != nil { + c.logger.Errorf(ctx, "Set failed to serialize value for key '%s': %v", key, err) + return err + } + + if err := c.client.Set(ctx, key, serializedValue, c.ttl).Err(); err != nil { + c.logger.Errorf(ctx, "Redis Set failed for key '%s': %v", key, err) + return err + } + + duration := time.Since(start) + c.logger.LogRequest(ctx, "DEBUG", "Set new cache key", "SUCCESS", duration, key) + + if c.metrics != nil { + c.metrics.Sets().WithLabelValues(c.name).Inc() + c.metrics.Items().WithLabelValues(c.name).Set(float64(c.countKeys(ctx))) + c.metrics.Latency().WithLabelValues(c.name, "set").Observe(duration.Seconds()) + } + + return nil +} + +// Get retrieves an item from the Redis cache. +// If the key is found, it returns the stored value and true. +// The caller is responsible for deserializing it if necessary. +// If the key is not found, it returns nil and false. +// This operation is thread-safe. 
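+//
+// Because complex values are stored as JSON strings, callers typically
+// unmarshal them back; a sketch, where the user type is hypothetical:
+//
+//	raw, found, err := c.Get(ctx, "user:42")
+//	if err == nil && found {
+//		var u user
+//		err = json.Unmarshal([]byte(raw.(string)), &u)
+//	}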
+func (c *redisCache) Get(ctx context.Context, key string) (value any, found bool, err error) {
+	start := time.Now()
+
+	if keyerr := validateKey(key); keyerr != nil {
+		c.logger.Errorf(ctx, "Get failed: %v", keyerr)
+		return nil, false, keyerr
+	}
+
+	val, err := c.client.Get(ctx, key).Result()
+	if err != nil {
+		if errors.Is(err, redis.Nil) {
+			duration := time.Since(start)
+			c.logger.Missf(ctx, "GET", duration, key)
+
+			if c.metrics != nil {
+				c.metrics.Misses().WithLabelValues(c.name).Inc()
+				c.metrics.Latency().WithLabelValues(c.name, "get").Observe(duration.Seconds())
+			}
+
+			return nil, false, nil // Key does not exist
+		}
+
+		c.logger.Errorf(ctx, "Redis Get failed for key '%s': %v", key, err)
+
+		return nil, false, err
+	}
+
+	duration := time.Since(start)
+	c.logger.Hitf(ctx, "GET", duration, key)
+
+	if c.metrics != nil {
+		c.metrics.Hits().WithLabelValues(c.name).Inc()
+		c.metrics.Latency().WithLabelValues(c.name, "get").Observe(duration.Seconds())
+	}
+
+	return val, true, nil
+}
+
+// Delete removes a key from the Redis cache.
+// If the key does not exist, the operation is a no-op.
+// This operation is thread-safe.
+func (c *redisCache) Delete(ctx context.Context, key string) error {
+	start := time.Now()
+
+	if err := validateKey(key); err != nil {
+		c.logger.Errorf(ctx, "Delete failed: %v", err)
+		return err
+	}
+
+	if err := c.client.Del(ctx, key).Err(); err != nil {
+		c.logger.Errorf(ctx, "Redis Del failed for key '%s': %v", key, err)
+		return err
+	}
+
+	// Measure after the round trip so the latency metric covers the actual call.
+	duration := time.Since(start)
+
+	c.logger.LogRequest(ctx, "DEBUG", "Deleted cache key", "SUCCESS", duration, key)
+
+	if c.metrics != nil {
+		c.metrics.Deletes().WithLabelValues(c.name).Inc()
+		c.metrics.Items().WithLabelValues(c.name).Set(float64(c.countKeys(ctx)))
+		c.metrics.Latency().WithLabelValues(c.name, "delete").Observe(duration.Seconds())
+	}
+
+	return nil
+}
+
+// Exists checks if a key exists in the Redis cache.
+// It returns true if the key is present, false otherwise.
+// This operation is thread-safe.
+func (c *redisCache) Exists(ctx context.Context, key string) (bool, error) {
+	start := time.Now()
+
+	if err := validateKey(key); err != nil {
+		c.logger.Errorf(ctx, "Exists failed: %v", err)
+		return false, err
+	}
+
+	res, err := c.client.Exists(ctx, key).Result()
+	if err != nil {
+		c.logger.Errorf(ctx, "Redis Exists failed for key '%s': %v", key, err)
+		return false, err
+	}
+
+	if c.metrics != nil {
+		c.metrics.Latency().WithLabelValues(c.name, "exists").Observe(time.Since(start).Seconds())
+	}
+
+	return res > 0, nil
+}
+
+// Clear removes all keys from the current Redis database (using FLUSHDB).
+// This is a destructive operation and should be used with caution.
+// This operation is thread-safe.
+func (c *redisCache) Clear(ctx context.Context) error {
+	start := time.Now()
+
+	if err := c.client.FlushDB(ctx).Err(); err != nil {
+		c.logger.Errorf(ctx, "Redis FlushDB failed: %v", err)
+		return err
+	}
+
+	duration := time.Since(start)
+	c.logger.LogRequest(ctx, "WARN", "Cleared all keys", "SUCCESS", duration, c.name)
+
+	if c.metrics != nil {
+		c.metrics.Items().WithLabelValues(c.name).Set(0)
+		c.metrics.Latency().WithLabelValues(c.name, "clear").Observe(duration.Seconds())
+	}
+
+	return nil
+}
+
+// Close closes the connection to the Redis server.
+// It's important to call Close to release network resources.
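+//
+// Typically deferred right after construction (illustrative):
+//
+//	defer c.Close(ctx)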
+func (c *redisCache) Close(ctx context.Context) error {
+	if c.client == nil {
+		return ErrNilClient
+	}
+
+	if err := c.client.Close(); err != nil {
+		c.logger.Errorf(ctx, "Failed to close redis client: %v", err)
+		return err
+	}
+
+	c.logger.Infof(ctx, "Redis cache '%s' closed", c.name)
+
+	return nil
+}
+
+// countKeys returns the number of keys in the current Redis DB.
+func (c *redisCache) countKeys(ctx context.Context) int64 {
+	res, err := c.client.DBSize(ctx).Result()
+	if err != nil {
+		c.logger.Errorf(ctx, "DBSize failed: %v", err)
+		return 0
+	}
+
+	return res
+}
diff --git a/pkg/cache/redis/redis_test.go b/pkg/cache/redis/redis_test.go
new file mode 100644
index 000000000..4083b74fb
--- /dev/null
+++ b/pkg/cache/redis/redis_test.go
@@ -0,0 +1,405 @@
+package redis
+
+import (
+	"context"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/prometheus/client_golang/prometheus"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"gofr.dev/pkg/cache"
+)
+
+func makeRedisCache(t *testing.T, opts ...Option) cache.Cache {
+	t.Helper()
+
+	allOpts := append([]Option{WithDB(15)}, opts...)
+
+	c, err := NewRedisCache(t.Context(), allOpts...)
+	if err != nil {
+		t.Skipf("skipping redis tests: could not connect to redis. Error: %v", err)
+	}
+
+	t.Cleanup(func() {
+		// t.Context() is already canceled by the time cleanup functions run,
+		// so use a background context for teardown.
+		_ = c.Clear(context.Background())
+		_ = c.Close(context.Background())
+	})
+
+	return c
+}
+
+func TestOperations(t *testing.T) {
+	originalRegistry := prometheus.DefaultRegisterer
+	prometheus.DefaultRegisterer = prometheus.NewRegistry()
+
+	t.Cleanup(func() {
+		prometheus.DefaultRegisterer = originalRegistry
+	})
+
+	ctx := t.Context()
+	c := makeRedisCache(t)
+
+	require.NoError(t, c.Set(ctx, "key1", "value10"))
+
+	v, found, err := c.Get(ctx, "key1")
+	require.NoError(t, err)
+	assert.True(t, found)
+	assert.Equal(t, "value10", v)
+
+	exists, err := c.Exists(ctx, "key1")
+	require.NoError(t, err)
+	assert.True(t, exists)
+
+	require.NoError(t, c.Delete(ctx, "key1"))
+
+	exists, err = c.Exists(ctx, "key1")
+	require.NoError(t, err)
+	assert.False(t, exists)
+}
+
+func TestClear(t *testing.T) {
+	originalRegistry := prometheus.DefaultRegisterer
+	prometheus.DefaultRegisterer = prometheus.NewRegistry()
+
+	t.Cleanup(func() {
+		prometheus.DefaultRegisterer = originalRegistry
+	})
+
+	ctx := t.Context()
+	c := makeRedisCache(t)
+
+	require.NoError(t, c.Set(ctx, "x", 1))
+	require.NoError(t, c.Set(ctx, "y", 2))
+
+	require.NoError(t, c.Clear(ctx))
+
+	for _, k := range []string{"x", "y"} {
+		exist, err := c.Exists(ctx, k)
+		require.NoError(t, err)
+		assert.False(t, exist)
+	}
+}
+
+func TestTTLExpiry(t *testing.T) {
+	originalRegistry := prometheus.DefaultRegisterer
+	prometheus.DefaultRegisterer = prometheus.NewRegistry()
+
+	t.Cleanup(func() {
+		prometheus.DefaultRegisterer = originalRegistry
+	})
+
+	ctx := t.Context()
+	c := makeRedisCache(t, WithTTL(50*time.Millisecond))
+
+	require.NoError(t, c.Set(ctx, "foo", "bar"))
+
+	time.Sleep(60 * time.Millisecond)
+
+	_, found, err := c.Get(ctx, "foo")
+	require.NoError(t, err)
+	assert.False(t, found, "key should have expired and not be found")
+}
+
+func TestOverwrite(t *testing.T) {
+	originalRegistry := prometheus.DefaultRegisterer
+	prometheus.DefaultRegisterer = prometheus.NewRegistry()
+
+	t.Cleanup(func() {
+		prometheus.DefaultRegisterer = originalRegistry
+	})
+
+	ctx := t.Context()
+	c := makeRedisCache(t)
+
+	require.NoError(t, c.Set(ctx, "dupKey", "first"))
+	require.NoError(t, c.Set(ctx, "dupKey", "second"))
+
+	v, found, err := c.Get(ctx, "dupKey")
+	require.NoError(t, err)
+	assert.True(t, found)
+	assert.Equal(t, "second", v)
+}
+
+func TestDeleteNonExistent(t *testing.T) {
+	originalRegistry := prometheus.DefaultRegisterer
+	prometheus.DefaultRegisterer = prometheus.NewRegistry()
+
+	t.Cleanup(func() {
+		prometheus.DefaultRegisterer = originalRegistry
+	})
+
+	ctx := t.Context()
+	c := makeRedisCache(t)
+
+	err := c.Delete(ctx, "ghost")
+	require.NoError(t, err)
+}
+
+func TestClearEmpty(t *testing.T) {
+	originalRegistry := prometheus.DefaultRegisterer
+	prometheus.DefaultRegisterer = prometheus.NewRegistry()
+
+	t.Cleanup(func() {
+		prometheus.DefaultRegisterer = originalRegistry
+	})
+
+	ctx := t.Context()
+	c := makeRedisCache(t)
+
+	err := c.Clear(ctx)
+	require.NoError(t, err)
+}
+
+func TestConcurrentAccess(t *testing.T) {
+	originalRegistry := prometheus.DefaultRegisterer
+	prometheus.DefaultRegisterer = prometheus.NewRegistry()
+
+	t.Cleanup(func() {
+		prometheus.DefaultRegisterer = originalRegistry
+	})
+
+	ctx := t.Context()
+	c := makeRedisCache(t)
+
+	var wg sync.WaitGroup
+	for i := 0; i < 100; i++ {
+		wg.Add(1)
+
+		go func() {
+			defer wg.Done()
+			assert.NoError(t, c.Set(ctx, "concurrent", "safe"))
+			_, _, err := c.Get(ctx, "concurrent")
+			assert.NoError(t, err)
+			_, err = c.Exists(ctx, "concurrent")
+			assert.NoError(t, err)
+		}()
+	}
+
+	wg.Wait()
+}
+
+func TestMultipleOptions(t *testing.T) {
+	originalRegistry := prometheus.DefaultRegisterer
+	prometheus.DefaultRegisterer = prometheus.NewRegistry()
+
+	t.Cleanup(func() {
+		prometheus.DefaultRegisterer = originalRegistry
+	})
+
+	ctx := t.Context()
+	rc, err := NewRedisCache(ctx,
+		WithTTL(30*time.Second),
+		WithDB(14),
+		WithTTL(60*time.Second),
+	)
+	if err != nil {
+		t.Skipf("skipping redis tests: could not connect to redis. Error: %v", err)
+	}
+
+	c := rc.(*redisCache)
+	defer c.Close(ctx)
+
+	assert.Equal(t, 60*time.Second, c.ttl)
+	assert.Equal(t, 14, c.client.Options().DB)
+}
+
+func TestDifferentValueTypes(t *testing.T) {
+	originalRegistry := prometheus.DefaultRegisterer
+	prometheus.DefaultRegisterer = prometheus.NewRegistry()
+
+	t.Cleanup(func() {
+		prometheus.DefaultRegisterer = originalRegistry
+	})
+
+	ctx := t.Context()
+	c := makeRedisCache(t)
+
+	testCases := []struct {
+		key      string
+		value    any
+		expected string
+	}{
+		{"str", "hello", "hello"},
+		{"int", 42, "42"},
+		{"flt", 3.14, "3.14"},
+		{"bool_true", true, "true"},
+		{"bool_false", false, "false"},
+		{"int64", int64(123), "123"},
+		{"float32", float32(2.5), "2.5"},
+	}
+
+	for _, tc := range testCases {
+		require.NoError(t, c.Set(ctx, tc.key, tc.value), "Failed to set key %s", tc.key)
+	}
+
+	for _, tc := range testCases {
+		v, found, err := c.Get(ctx, tc.key)
+		require.NoError(t, err, "Failed to get key %s", tc.key)
+		assert.True(t, found, "Key %s not found", tc.key)
+		assert.Equal(t, tc.expected, v, "Value mismatch for key %s", tc.key)
+	}
+}
+
+func TestEmptyStringKey(t *testing.T) {
+	originalRegistry := prometheus.DefaultRegisterer
+	prometheus.DefaultRegisterer = prometheus.NewRegistry()
+
+	t.Cleanup(func() {
+		prometheus.DefaultRegisterer = originalRegistry
+	})
+
+	ctx := t.Context()
+	c := makeRedisCache(t)
+
+	err := c.Set(ctx, "", "v")
+	require.ErrorIs(t, err, ErrEmptyKey)
+
+	_, _, err = c.Get(ctx, "")
+	require.ErrorIs(t, err, ErrEmptyKey)
+
+	err = c.Delete(ctx, "")
+	require.ErrorIs(t, err, ErrEmptyKey)
+
+	_, err = c.Exists(ctx, "")
+	require.ErrorIs(t, err, ErrEmptyKey)
+}
+
+func TestNilValue(t *testing.T) {
+	originalRegistry := prometheus.DefaultRegisterer
+	prometheus.DefaultRegisterer = prometheus.NewRegistry()
+
+	t.Cleanup(func() {
+		prometheus.DefaultRegisterer = originalRegistry
+	})
+
+	ctx := t.Context()
+	c := makeRedisCache(t)
+
+	err := c.Set(ctx, "key", nil)
+	require.ErrorIs(t, err, ErrNilValue)
+}
+
+func TestConcurrentSetSameKey(t *testing.T) {
+	originalRegistry := prometheus.DefaultRegisterer
+	prometheus.DefaultRegisterer = prometheus.NewRegistry()
+
+	t.Cleanup(func() {
+		prometheus.DefaultRegisterer = originalRegistry
+	})
+
+	ctx := t.Context()
+	c := makeRedisCache(t)
+
+	var wg sync.WaitGroup
+	for i := 0; i < 100; i++ {
+		wg.Add(1)
+
+		go func(val int) {
+			defer wg.Done()
+			assert.NoError(t, c.Set(ctx, "race", val))
+		}(i)
+	}
+
+	wg.Wait()
+
+	_, found, _ := c.Get(ctx, "race")
+	assert.True(t, found)
+}
+
+func TestOptionValidation(t *testing.T) {
+	ctx := t.Context()
+
+	t.Run("Invalid DB number (-1)", func(t *testing.T) {
+		originalRegistry := prometheus.DefaultRegisterer
+		prometheus.DefaultRegisterer = prometheus.NewRegistry()
+
+		t.Cleanup(func() {
+			prometheus.DefaultRegisterer = originalRegistry
+		})
+
+		_, err := NewRedisCache(ctx, WithDB(-1))
+		require.Error(t, err)
+	})
+
+	t.Run("Invalid DB number (256)", func(t *testing.T) {
+		originalRegistry := prometheus.DefaultRegisterer
+		prometheus.DefaultRegisterer = prometheus.NewRegistry()
+
+		t.Cleanup(func() {
+			prometheus.DefaultRegisterer = originalRegistry
+		})
+
+		// 256 is outside the documented 0-255 range, so the option itself
+		// must fail regardless of whether a server is reachable.
+		_, err := NewRedisCache(ctx, WithDB(256))
+		require.Error(t, err)
+	})
+
+	t.Run("Negative TTL", func(t *testing.T) {
+		originalRegistry := prometheus.DefaultRegisterer
+		prometheus.DefaultRegisterer = prometheus.NewRegistry()
+
+		t.Cleanup(func() {
+			prometheus.DefaultRegisterer = originalRegistry
+		})
+
+		_, err := NewRedisCache(ctx, WithTTL(-1*time.Second))
+		require.Error(t, err)
+	})
+
+	t.Run("Empty Address", func(t *testing.T) {
+		originalRegistry := prometheus.DefaultRegisterer
+		prometheus.DefaultRegisterer = prometheus.NewRegistry()
+
+		t.Cleanup(func() {
+			prometheus.DefaultRegisterer = originalRegistry
+		})
+
+		_, err := NewRedisCache(ctx, WithAddr(""))
+		require.Error(t, err)
+	})
+}
+
+func TestCacheHitMiss(t *testing.T) {
+	originalRegistry := prometheus.DefaultRegisterer
+	prometheus.DefaultRegisterer = prometheus.NewRegistry()
+
+	t.Cleanup(func() {
+		prometheus.DefaultRegisterer = originalRegistry
+	})
+
+	ctx := t.Context()
+	c := makeRedisCache(t)
+
+	// Test cache miss
+	_, found, err := c.Get(ctx, "nonexistent")
+	require.NoError(t, err)
+	assert.False(t, found)
+
+	// Test cache hit
+	require.NoError(t, c.Set(ctx, "existing", "value"))
+	val, found, err := c.Get(ctx, "existing")
+	require.NoError(t, err)
+	assert.True(t, found)
+	assert.Equal(t, "value", val)
+}
+
+func TestZeroTTL(t *testing.T) {
+	originalRegistry := prometheus.DefaultRegisterer
+	prometheus.DefaultRegisterer = prometheus.NewRegistry()
+
+	t.Cleanup(func() {
+		prometheus.DefaultRegisterer = originalRegistry
+	})
+
+	ctx := t.Context()
+	c := makeRedisCache(t, WithTTL(0))
+
+	require.NoError(t, c.Set(ctx, "permanent", "value"))
+
+	// Should still be there after a short wait
+	time.Sleep(10 * time.Millisecond)
+
+	val, found, err := c.Get(ctx, "permanent")
+	require.NoError(t, err)
+	assert.True(t, found)
+	assert.Equal(t, "value", val)
+}
diff --git a/pkg/gofr/container/container.go b/pkg/gofr/container/container.go
index c3ee964b1..4de987720 100644
--- a/pkg/gofr/container/container.go
+++ b/pkg/gofr/container/container.go
@@ -36,6 +36,7 @@ import (
 	"gofr.dev/pkg/gofr/service"
 	"gofr.dev/pkg/gofr/version"
 	"gofr.dev/pkg/gofr/websocket"
+	"gofr.dev/pkg/cache"
 )
 
 // Container is a collection of all common application level concerns. Things like Logger, Connection Pool for Redis
@@ -71,6 +72,8 @@ type Container struct {
 	KVStore KVStore
 
 	File file.FileSystem
+
+	Cache map[string]cache.Cache
 }
 
 func NewContainer(conf config.Config) *Container {
@@ -192,6 +195,13 @@ func (c *Container) Close() error {
 		c.WSManager.CloseConnection(conn)
 	}
 
+	// Close all caches.
+	for _, cc := range c.Cache {
+		if cc != nil {
+			_ = cc.Close(context.Background())
+		}
+	}
+
 	return err
 }
@@ -288,6 +298,22 @@ func (c *Container) GetAppVersion() string {
 	return c.appVersion
 }
 
+// AddCache registers a cache instance on the container under the given name.
+func (c *Container) AddCache(name string, v cache.Cache) {
+	if c.Cache == nil {
+		c.Cache = make(map[string]cache.Cache)
+	}
+
+	c.Cache[name] = v
+}
+
+// GetCache returns the cache registered under name, or nil if none is registered.
+func (c *Container) GetCache(name string) cache.Cache {
+	if c.Cache == nil {
+		return nil
+	}
+
+	return c.Cache[name]
+}
+
 func (c *Container) GetPublisher() pubsub.Publisher {
 	return c.PubSub
 }
diff --git a/pkg/gofr/external_cache.go b/pkg/gofr/external_cache.go
new file mode 100644
index 000000000..e8ef30511
--- /dev/null
+++ b/pkg/gofr/external_cache.go
@@ -0,0 +1,67 @@
+package gofr
+
+import (
+	"context"
+	"time"
+
+	"gofr.dev/pkg/cache"
+	"gofr.dev/pkg/cache/factory"
+)
+
+// AddInMemoryCache adds an in-memory cache to the app's container.
+func (a *App) AddInMemoryCache(ctx context.Context, name string, ttl time.Duration, maxItems int) {
+	c, err := factory.NewInMemoryCache(
+		ctx,
+		name,
+		factory.WithLogger(a.Logger()),
+		factory.WithTTL(ttl),
+		factory.WithMaxItems(maxItems),
+	)
+	if err != nil {
+		a.Logger().Errorf("inmemory cache init failed: %v", err)
+		return
+	}
+
+	a.container.AddCache(name, c)
+}
+
+// AddRedisCache adds a Redis cache to the app's container.
+func (a *App) AddRedisCache(ctx context.Context, name string, ttl time.Duration, addr string) {
+	c, err := factory.NewRedisCache(
+		ctx,
+		name,
+		factory.WithLogger(a.Logger()),
+		factory.WithTTL(ttl),
+		factory.WithRedisAddr(addr),
+	)
+	if err != nil {
+		a.Logger().Errorf("redis cache init failed: %v", err)
+		return
+	}
+
+	a.container.AddCache(name, c)
+}
+
+// AddRedisCacheDirect adds a Redis cache with full configuration.
+func (a *App) AddRedisCacheDirect(ctx context.Context, name, addr, password string, db int, ttl time.Duration) {
+	c, err := factory.NewRedisCache(
+		ctx,
+		name,
+		factory.WithLogger(a.Logger()),
+		factory.WithTTL(ttl),
+		factory.WithRedisAddr(addr),
+		factory.WithRedisPassword(password),
+		factory.WithRedisDB(db),
+	)
+	if err != nil {
+		a.Logger().Errorf("redis cache init failed: %v", err)
+		return
+	}
+
+	a.container.AddCache(name, c)
+}
+
+// GetCache retrieves a cache instance from the app's container by name.
+func (a *App) GetCache(name string) cache.Cache {
+	return a.container.GetCache(name)
+}
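+
+// Illustrative wiring (a sketch, not part of the API; the cache name and
+// sizes are assumptions for the example):
+//
+//	app := gofr.New()
+//	app.AddInMemoryCache(context.Background(), "sessions", 5*time.Minute, 1000)
+//
+//	if c := app.GetCache("sessions"); c != nil {
+//		_ = c.Set(context.Background(), "user:42", "alice")
+//	}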