diff --git a/.env.example b/.env.example
index bfadfd44..7e7465e3 100644
--- a/.env.example
+++ b/.env.example
@@ -38,3 +38,13 @@ MCP_REGISTRY_OIDC_EXTRA_CLAIMS=[{"hd":"modelcontextprotocol.io"}]
 # Grant admin permissions to OIDC-authenticated users
 MCP_REGISTRY_OIDC_EDIT_PERMISSIONS=*
 MCP_REGISTRY_OIDC_PUBLISH_PERMISSIONS=*
+
+# Rate limiting configuration
+# Note: Rate limits are enforced per-pod, so in multi-replica deployments the
+# effective limits are approximate (e.g., 2 replicas = up to 2x the configured rate).
+# Enable or disable rate limiting (default: true)
+MCP_REGISTRY_RATE_LIMIT_ENABLED=true
+# Maximum requests per minute per IP address (default: 60)
+MCP_REGISTRY_RATE_LIMIT_REQUESTS_PER_MINUTE=60
+# Maximum requests per hour per IP address (default: 1000)
+MCP_REGISTRY_RATE_LIMIT_REQUESTS_PER_HOUR=1000
diff --git a/docs/reference/api/official-registry-api.md b/docs/reference/api/official-registry-api.md
index f9e3c7dd..371ec5b2 100644
--- a/docs/reference/api/official-registry-api.md
+++ b/docs/reference/api/official-registry-api.md
@@ -14,6 +14,28 @@ This API is based on the [generic registry API](./generic-registry-api.md) with
 - **[Live API Docs](https://registry.modelcontextprotocol.io/docs)** - Stoplight elements with try-it-now functionality
 - **[OpenAPI Spec](https://registry.modelcontextprotocol.io/openapi.yaml)** - Complete machine-readable specification
 
+## Rate Limiting
+
+The official registry enforces rate limits to protect against abuse:
+
+- **60 requests per minute** per IP address
+- **1,000 requests per hour** per IP address
+
+When rate limited, the API returns HTTP `429 Too Many Requests` with a `Retry-After: 60` header. The response body follows the [RFC 7807](https://tools.ietf.org/html/rfc7807) problem details format:
+
+```json
+{
+  "title": "Too Many Requests",
+  "status": 429,
+  "detail": "Rate limit exceeded. Please reduce request frequency and retry after some time."
+}
+```
+
+**Notes:**
+- Rate limits are approximate due to the multi-replica deployment architecture
+- The `/health`, `/ping`, and `/metrics` endpoints are not rate limited
+- Clients should implement exponential backoff when receiving 429 responses
+
 ## Extensions
 
 The official registry implements the [Generic Registry API](./generic-registry-api.md) with the following specific configurations and extensions:
diff --git a/internal/api/ratelimit/ratelimit.go b/internal/api/ratelimit/ratelimit.go
new file mode 100644
index 00000000..45bafb89
--- /dev/null
+++ b/internal/api/ratelimit/ratelimit.go
@@ -0,0 +1,306 @@
+// Package ratelimit provides IP-based rate limiting middleware for HTTP servers.
+package ratelimit
+
+import (
+	"encoding/json"
+	"log"
+	"net"
+	"net/http"
+	"strings"
+	"sync"
+	"time"
+
+	"golang.org/x/time/rate"
+)
+
+// OnRateLimitedFunc is a callback invoked when a request is rate limited.
+// It receives the client IP that was blocked.
+type OnRateLimitedFunc func(ip string)
+
+// Config holds the rate limiting configuration
+type Config struct {
+	// RequestsPerMinute is the maximum number of requests allowed per minute per IP
+	RequestsPerMinute int
+	// RequestsPerHour is the maximum number of requests allowed per hour per IP
+	RequestsPerHour int
+	// CleanupInterval is how often to clean up stale entries.
+	// Non-positive values are replaced by New with the default of 10 minutes.
+	CleanupInterval time.Duration
+	// SkipPaths are paths that should not be rate limited
+	SkipPaths []string
+	// MaxVisitors is the maximum number of visitor entries to track (memory protection).
+	// When exceeded, oldest entries are evicted. Default: 100000.
+	MaxVisitors int
+	// OnRateLimited is an optional callback invoked when a request is rate limited.
+	// Used for recording metrics.
+	OnRateLimited OnRateLimitedFunc
+}
+
+// DefaultConfig returns the default rate limiting configuration
+func DefaultConfig() Config {
+	return Config{
+		RequestsPerMinute: 60,
+		RequestsPerHour:   1000,
+		CleanupInterval:   10 * time.Minute,
+		SkipPaths:         []string{"/health", "/ping", "/metrics"},
+		MaxVisitors:       100000,
+	}
+}
+
+// visitor tracks rate limiting state for a single IP address
+type visitor struct {
+	minuteLimiter *rate.Limiter
+	hourLimiter   *rate.Limiter
+	lastSeen      time.Time
+}
+
+// RateLimiter implements IP-based rate limiting
+type RateLimiter struct {
+	config   Config
+	visitors map[string]*visitor
+	mu       sync.RWMutex
+	stopCh   chan struct{}
+	stopOnce sync.Once // guards close(stopCh) so Stop can be called repeatedly
+}
+
+// New creates a new RateLimiter with the given configuration.
+// Non-positive MaxVisitors and CleanupInterval values fall back to their defaults;
+// the latter matters because time.NewTicker panics on a non-positive interval.
+// New starts a background cleanup goroutine; call Stop to terminate it.
+func New(cfg Config) *RateLimiter {
+	if cfg.MaxVisitors <= 0 {
+		cfg.MaxVisitors = 100000
+	}
+	if cfg.CleanupInterval <= 0 {
+		cfg.CleanupInterval = 10 * time.Minute
+	}
+
+	rl := &RateLimiter{
+		config:   cfg,
+		visitors: make(map[string]*visitor),
+		stopCh:   make(chan struct{}),
+	}
+
+	// Start background cleanup goroutine
+	go rl.cleanupLoop()
+
+	return rl
+}
+
+// Stop stops the background cleanup goroutine.
+// It is safe to call Stop more than once; only the first call closes stopCh.
+func (rl *RateLimiter) Stop() {
+	rl.stopOnce.Do(func() { close(rl.stopCh) })
+}
+
+// cleanupLoop periodically removes stale visitor entries
+func (rl *RateLimiter) cleanupLoop() {
+	ticker := time.NewTicker(rl.config.CleanupInterval)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-ticker.C:
+			rl.cleanup()
+		case <-rl.stopCh:
+			return
+		}
+	}
+}
+
+// cleanup removes visitors that haven't been seen in the last hour
+func (rl *RateLimiter) cleanup() {
+	rl.mu.Lock()
+	defer rl.mu.Unlock()
+
+	threshold := time.Now().Add(-time.Hour)
+	for ip, v := range rl.visitors {
+		if v.lastSeen.Before(threshold) {
+			delete(rl.visitors, ip)
+		}
+	}
+}
+
+// evictOldestLocked removes the oldest visitor entry. Must be called with lock held.
+func (rl *RateLimiter) evictOldestLocked() {
+	var oldestIP string
+	var oldestTime time.Time
+
+	for ip, v := range rl.visitors {
+		if oldestIP == "" || v.lastSeen.Before(oldestTime) {
+			oldestIP = ip
+			oldestTime = v.lastSeen
+		}
+	}
+
+	if oldestIP != "" {
+		delete(rl.visitors, oldestIP)
+	}
+}
+
+// getVisitor returns the visitor for the given IP, creating one if necessary.
+// Implements memory protection by evicting oldest entries when MaxVisitors is reached.
+//
+// The whole lookup runs under the exclusive lock. lastSeen is read by cleanup and
+// evictOldestLocked while they hold rl.mu, so it must be written under rl.mu as well;
+// updating it after releasing a read lock would be a data race between concurrent
+// requests from the same IP and with the cleanup goroutine.
+func (rl *RateLimiter) getVisitor(ip string) *visitor {
+	now := time.Now()
+
+	rl.mu.Lock()
+	defer rl.mu.Unlock()
+
+	if v, exists := rl.visitors[ip]; exists {
+		v.lastSeen = now
+		return v
+	}
+
+	// Enforce max visitors limit (memory protection). Eviction is a linear scan of
+	// the map, but it only runs when a new IP arrives while the table is full.
+	if len(rl.visitors) >= rl.config.MaxVisitors {
+		rl.evictOldestLocked()
+	}
+
+	// Create rate limiters:
+	// - Minute limiter: allows RequestsPerMinute requests per minute with burst of same
+	// - Hour limiter: allows RequestsPerHour requests per hour with burst of same
+	minuteRate := rate.Limit(float64(rl.config.RequestsPerMinute) / 60.0) // requests per second
+	hourRate := rate.Limit(float64(rl.config.RequestsPerHour) / 3600.0)   // requests per second
+
+	v := &visitor{
+		minuteLimiter: rate.NewLimiter(minuteRate, rl.config.RequestsPerMinute),
+		hourLimiter:   rate.NewLimiter(hourRate, rl.config.RequestsPerHour),
+		lastSeen:      now,
+	}
+	rl.visitors[ip] = v
+
+	return v
+}
+
+// Allow checks if a request from the given IP should be allowed
+func (rl *RateLimiter) Allow(ip string) bool {
+	v := rl.getVisitor(ip)
+
+	// Both limiters must allow the request. The minute limiter is consulted first,
+	// so a request rejected by the hour limiter has still consumed a minute token.
+	if !v.minuteLimiter.Allow() {
+		return false
+	}
+	if !v.hourLimiter.Allow() {
+		return false
+	}
+	return true
+}
+
+// shouldSkip returns true if the path should not be rate limited
+func (rl *RateLimiter) shouldSkip(path string) bool {
+	for _, skipPath := range rl.config.SkipPaths {
+		if path == skipPath || strings.HasPrefix(path, skipPath+"/") {
+			return true
+		}
+	}
+	return false
+}
+
+// getClientIP extracts the client IP from the request.
+// It considers X-Forwarded-For and X-Real-IP headers for reverse proxy scenarios,
+// as the registry is deployed behind NGINX ingress with use-forwarded-headers enabled.
+//
+// NOTE(review): both headers are client-controlled unless the proxy in front of the
+// service overwrites them. Taking the left-most X-Forwarded-For entry lets a caller
+// that can inject the header pick a fresh rate-limit bucket per request — confirm the
+// ingress strips or rewrites untrusted values before relying on this for abuse control.
+func getClientIP(r *http.Request) string {
+	// Check X-Forwarded-For header (can contain multiple IPs)
+	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
+		// Take the first IP (original client)
+		if idx := strings.Index(xff, ","); idx != -1 {
+			xff = xff[:idx]
+		}
+		xff = strings.TrimSpace(xff)
+		if ip := validateAndNormalizeIP(xff); ip != "" {
+			return ip
+		}
+	}
+
+	// Check X-Real-IP header
+	if xri := r.Header.Get("X-Real-IP"); xri != "" {
+		if ip := validateAndNormalizeIP(strings.TrimSpace(xri)); ip != "" {
+			return ip
+		}
+	}
+
+	// Fall back to RemoteAddr
+	ip, _, err := net.SplitHostPort(r.RemoteAddr)
+	if err != nil {
+		// RemoteAddr might not have a port
+		ip = r.RemoteAddr
+	}
+
+	// Validate and normalize the IP
+	if validIP := validateAndNormalizeIP(ip); validIP != "" {
+		return validIP
+	}
+
+	// If all else fails, use a fallback that won't cause issues
+	return "unknown"
+}
+
+// validateAndNormalizeIP validates the IP string and returns a normalized form.
+// Returns empty string if the IP is invalid.
+func validateAndNormalizeIP(ip string) string {
+	if ip == "" {
+		return ""
+	}
+
+	// net.ParseIP returns nil for anything that is not a valid IPv4/IPv6 literal
+	parsedIP := net.ParseIP(ip)
+	if parsedIP == nil {
+		return ""
+	}
+
+	// Return normalized string representation
+	return parsedIP.String()
+}
+
+// Middleware returns an HTTP middleware that enforces rate limiting
+func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		// Skipped paths go straight to the next handler without touching the limiter
+		if rl.shouldSkip(r.URL.Path) {
+			next.ServeHTTP(w, r)
+			return
+		}
+
+		ip := getClientIP(r)
+
+		if !rl.Allow(ip) {
+			// Record the rate-limited request if callback is configured
+			if rl.config.OnRateLimited != nil {
+				rl.config.OnRateLimited(ip)
+			}
+
+			w.Header().Set("Content-Type", "application/problem+json")
+			w.Header().Set("Retry-After", "60")
+			w.WriteHeader(http.StatusTooManyRequests)
+
+			errorBody := map[string]interface{}{
+				"title":  "Too Many Requests",
+				"status": http.StatusTooManyRequests,
+				"detail": "Rate limit exceeded. Please reduce request frequency and retry after some time.",
+			}
+
+			jsonData, err := json.Marshal(errorBody)
+			if err != nil {
+				log.Printf("Failed to marshal rate limit error response: %v", err)
+				_, _ = w.Write([]byte(`{"title":"Too Many Requests","status":429}`))
+				return
+			}
+			_, _ = w.Write(jsonData)
+			return
+		}
+
+		next.ServeHTTP(w, r)
+	})
+}
diff --git a/internal/api/ratelimit/ratelimit_test.go b/internal/api/ratelimit/ratelimit_test.go
new file mode 100644
index 00000000..0d6e2a56
--- /dev/null
+++ b/internal/api/ratelimit/ratelimit_test.go
@@ -0,0 +1,484 @@
+package ratelimit_test
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/modelcontextprotocol/registry/internal/api/ratelimit"
+)
+
+const (
+	testLocalAddr    = "127.0.0.1:12345"
+	testLocalAltAddr = "127.0.0.1:12346"
+)
+
+func TestRateLimiter_Allow(t *testing.T) {
+	cfg := ratelimit.Config{
+		RequestsPerMinute: 5,
+		RequestsPerHour:   10,
+		CleanupInterval:   time.Hour, // Long interval to avoid cleanup during test
+		SkipPaths:         []string{"/health"},
+		MaxVisitors:       1000,
+	}
+	rl := ratelimit.New(cfg)
+	defer rl.Stop()
+
+	ip := "192.168.1.1"
+
+	// Should allow the first 5 requests (minute limit)
+	for i := 0; i < 5; i++ {
+		if !rl.Allow(ip) {
+			t.Errorf("Request %d should be allowed", i+1)
+		}
+	}
+
+	// 6th request should be blocked (minute limit exceeded)
+	if rl.Allow(ip) {
+		t.Error("Request 6 should be blocked due to minute limit")
+	}
+
+	// Different IP should still be allowed
+	if !rl.Allow("192.168.1.2") {
+		t.Error("Request from different IP should be allowed")
+	}
+}
+
+func TestRateLimiter_HourlyLimit(t *testing.T) {
+	// Configure with high minute limit but low hour limit
+	cfg := ratelimit.Config{
+		RequestsPerMinute: 100,
+		RequestsPerHour:   5,
+		CleanupInterval:   time.Hour,
+		SkipPaths:         []string{"/health"},
+		MaxVisitors:       1000,
+	}
+	rl := ratelimit.New(cfg)
+	defer rl.Stop()
+
+	ip := "192.168.1.1"
+
+	// Should allow up to hour limit
+	for i := 0; i < 5; i++ {
+		if !rl.Allow(ip) {
+			t.Errorf("Request %d should be allowed", i+1)
+		}
+	}
+
+	// Next request should be blocked by hour limit
+	if rl.Allow(ip) {
+		t.Error("Request should be blocked due to hour limit")
+	}
+}
+
+func TestRateLimiter_Middleware(t *testing.T) {
+	cfg := ratelimit.Config{
+		RequestsPerMinute: 2,
+		RequestsPerHour:   100,
+		CleanupInterval:   time.Hour,
+		SkipPaths:         []string{"/health", "/ping"},
+		MaxVisitors:       1000,
+	}
+	rl := ratelimit.New(cfg)
+	defer rl.Stop()
+
+	// Create a simple handler that returns OK
+	handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		w.WriteHeader(http.StatusOK)
+		_, _ = w.Write([]byte("OK"))
+	})
+
+	middleware := rl.Middleware(handler)
+
+	t.Run("allows requests within limit", func(t *testing.T) {
+		req := httptest.NewRequest(http.MethodGet, "/v0/servers", nil)
+		req.RemoteAddr = "10.0.0.1:12345"
+		w := httptest.NewRecorder()
+
+		middleware.ServeHTTP(w, req)
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
+		}
+	})
+
+	t.Run("blocks requests over limit", func(t *testing.T) {
+		// Send 3 requests against a per-minute limit of 2 to exhaust this IP's budget
+		for i := 0; i < 3; i++ {
+			req := httptest.NewRequest(http.MethodGet, "/v0/servers", nil)
+			req.RemoteAddr = "10.0.0.2:12345"
+			w := httptest.NewRecorder()
+			middleware.ServeHTTP(w, req)
+		}
+
+		// This request should be blocked
+		req := httptest.NewRequest(http.MethodGet, "/v0/servers", nil)
+		req.RemoteAddr = "10.0.0.2:12345"
+		w := httptest.NewRecorder()
+		middleware.ServeHTTP(w, req)
+
+		if w.Code != http.StatusTooManyRequests {
+			t.Errorf("expected status %d, got %d", http.StatusTooManyRequests, w.Code)
+		}
+
+		// Check Retry-After header
+		if w.Header().Get("Retry-After") != "60" {
+			t.Errorf("expected Retry-After header to be 60, got %s", w.Header().Get("Retry-After"))
+		}
+
+		// Check Content-Type
+		if w.Header().Get("Content-Type") != "application/problem+json" {
+			t.Errorf("expected Content-Type application/problem+json, got %s", w.Header().Get("Content-Type"))
+		}
+	})
+
+	t.Run("skips health endpoint", func(t *testing.T) {
+		// Send more requests than the per-minute limit (2); /health must never be throttled
+		for i := 0; i < 5; i++ {
+			req := httptest.NewRequest(http.MethodGet, "/health", nil)
+			req.RemoteAddr = "10.0.0.3:12345"
+			w := httptest.NewRecorder()
+			middleware.ServeHTTP(w, req)
+
+			if w.Code != http.StatusOK {
+				t.Errorf("health endpoint should always be allowed, got status %d on request %d", w.Code, i+1)
+			}
+		}
+	})
+
+	t.Run("skips ping endpoint", func(t *testing.T) {
+		for i := 0; i < 5; i++ {
+			req := httptest.NewRequest(http.MethodGet, "/ping", nil)
+			req.RemoteAddr = "10.0.0.4:12345"
+			w := httptest.NewRecorder()
+			middleware.ServeHTTP(w, req)
+
+			if w.Code != http.StatusOK {
+				t.Errorf("ping endpoint should always be allowed, got status %d on request %d", w.Code, i+1)
+			}
+		}
+	})
+}
+
+func TestGetClientIP_WithHeaders(t *testing.T) {
+	cfg := ratelimit.Config{
+		RequestsPerMinute: 1,
+		RequestsPerHour:   100,
+		CleanupInterval:   time.Hour,
+		SkipPaths:         []string{},
+		MaxVisitors:       1000,
+	}
+	rl := ratelimit.New(cfg)
+	defer rl.Stop()
+
+	handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		w.WriteHeader(http.StatusOK)
+	})
+
+	middleware := rl.Middleware(handler)
+
+	t.Run("uses X-Forwarded-For header", func(t *testing.T) {
+		// First request from this forwarded IP
+		req := httptest.NewRequest(http.MethodGet, "/test", nil)
+		req.RemoteAddr = testLocalAddr
+		req.Header.Set("X-Forwarded-For", "203.0.113.1")
+		w := httptest.NewRecorder()
+		middleware.ServeHTTP(w, req)
+
+		if w.Code != http.StatusOK {
+			t.Errorf("first request should be allowed, got status %d", w.Code)
+		}
+
+		// Second request should be blocked (same forwarded IP)
+		req2 := httptest.NewRequest(http.MethodGet, "/test", nil)
+		req2.RemoteAddr = testLocalAltAddr
+		req2.Header.Set("X-Forwarded-For", "203.0.113.1")
+		w2 := httptest.NewRecorder()
+		middleware.ServeHTTP(w2, req2)
+
+		if w2.Code != http.StatusTooManyRequests {
+			t.Errorf("second request should be blocked, got status %d", w2.Code)
+		}
+	})
+
+	t.Run("uses first IP from X-Forwarded-For with multiple IPs", func(t *testing.T) {
+		req := httptest.NewRequest(http.MethodGet, "/test", nil)
+		req.RemoteAddr = testLocalAddr
+		req.Header.Set("X-Forwarded-For", "203.0.113.2, 10.0.0.1, 192.168.1.1")
+		w := httptest.NewRecorder()
+		middleware.ServeHTTP(w, req)
+
+		if w.Code != http.StatusOK {
+			t.Errorf("first request should be allowed, got status %d", w.Code)
+		}
+
+		// Second request with same first IP should be blocked
+		req2 := httptest.NewRequest(http.MethodGet, "/test", nil)
+		req2.RemoteAddr = testLocalAltAddr
+		req2.Header.Set("X-Forwarded-For", "203.0.113.2, 10.0.0.2")
+		w2 := httptest.NewRecorder()
+		middleware.ServeHTTP(w2, req2)
+
+		if w2.Code != http.StatusTooManyRequests {
+			t.Errorf("second request should be blocked, got status %d", w2.Code)
+		}
+	})
+
+	t.Run("uses X-Real-IP header", func(t *testing.T) {
+		req := httptest.NewRequest(http.MethodGet, "/test", nil)
+		req.RemoteAddr = testLocalAddr
+		req.Header.Set("X-Real-IP", "203.0.113.3")
+		w := httptest.NewRecorder()
+		middleware.ServeHTTP(w, req)
+
+		if w.Code != http.StatusOK {
+			t.Errorf("first request should be allowed, got status %d", w.Code)
+		}
+
+		// Second request should be blocked
+		req2 := httptest.NewRequest(http.MethodGet, "/test", nil)
+		req2.RemoteAddr = testLocalAltAddr
+		req2.Header.Set("X-Real-IP", "203.0.113.3")
+		w2 := httptest.NewRecorder()
+		middleware.ServeHTTP(w2, req2)
+
+		if w2.Code != http.StatusTooManyRequests {
+			t.Errorf("second request should be blocked, got status %d", w2.Code)
+		}
+	})
+
+	t.Run("falls back to RemoteAddr when no headers", func(t *testing.T) {
+		req := httptest.NewRequest(http.MethodGet, "/test", nil)
+		req.RemoteAddr = "203.0.113.4:12345"
+		w := httptest.NewRecorder()
+		middleware.ServeHTTP(w, req)
+
+		if w.Code != http.StatusOK {
+			t.Errorf("first request should be allowed, got status %d", w.Code)
+		}
+
+		// Second request from same IP should be blocked
+		req2 := httptest.NewRequest(http.MethodGet, "/test", nil)
+		req2.RemoteAddr = "203.0.113.4:12346"
+		w2 := httptest.NewRecorder()
+		middleware.ServeHTTP(w2, req2)
+
+		if w2.Code != http.StatusTooManyRequests {
+			t.Errorf("second request should be blocked, got status %d", w2.Code)
+		}
+	})
+}
+
+func TestGetClientIP_InvalidIPs(t *testing.T) {
+	cfg := ratelimit.Config{
+		RequestsPerMinute: 100,
+		RequestsPerHour:   1000,
+		CleanupInterval:   time.Hour,
+		SkipPaths:         []string{},
+		MaxVisitors:       1000,
+	}
+	rl := ratelimit.New(cfg)
+	defer rl.Stop()
+
+	handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		w.WriteHeader(http.StatusOK)
+	})
+
+	middleware := rl.Middleware(handler)
+
+	t.Run("handles empty X-Forwarded-For", func(t *testing.T) {
+		req := httptest.NewRequest(http.MethodGet, "/test", nil)
+		req.RemoteAddr = "10.0.0.100:12345"
+		req.Header.Set("X-Forwarded-For", "")
+		w := httptest.NewRecorder()
+		middleware.ServeHTTP(w, req)
+
+		// Should fall back to RemoteAddr
+		if w.Code != http.StatusOK {
+			t.Errorf("request should be allowed, got status %d", w.Code)
+		}
+	})
+
+	t.Run("handles malformed X-Forwarded-For", func(t *testing.T) {
+		req := httptest.NewRequest(http.MethodGet, "/test", nil)
+		req.RemoteAddr = "10.0.0.101:12345"
+		req.Header.Set("X-Forwarded-For", "not-an-ip")
+		w := httptest.NewRecorder()
+		middleware.ServeHTTP(w, req)
+
+		// Should fall back to RemoteAddr
+		if w.Code != http.StatusOK {
+			t.Errorf("request should be allowed, got status %d", w.Code)
+		}
+	})
+
+	t.Run("handles IPv6 addresses", func(t *testing.T) {
+		req := httptest.NewRequest(http.MethodGet, "/test", nil)
+		req.RemoteAddr = "[::1]:12345"
+		w := httptest.NewRecorder()
+		middleware.ServeHTTP(w, req)
+
+		if w.Code != http.StatusOK {
+			t.Errorf("request should be allowed, got status %d", w.Code)
+		}
+	})
+}
+
+func TestRateLimiter_MaxVisitors(t *testing.T) {
+	cfg := ratelimit.Config{
+		RequestsPerMinute: 100,
+		RequestsPerHour:   1000,
+		CleanupInterval:   time.Hour,
+		SkipPaths:         []string{},
+		MaxVisitors:       3, // Very low limit for testing
+	}
+	rl := ratelimit.New(cfg)
+	defer rl.Stop()
+
+	// Add visitors up to the limit
+	for i := 0; i < 3; i++ {
+		ip := "192.168.1." + string(rune('1'+i))
+		rl.Allow(ip)
+	}
+
+	// A 4th distinct IP exceeds MaxVisitors (3) and should evict the oldest entry
+	rl.Allow("192.168.2.1")
+
+	// The rate limiter should still function (not crash or hang)
+	if !rl.Allow("192.168.2.2") {
+		t.Error("rate limiter should still work after eviction")
+	}
+}
+
+func TestRateLimiter_Concurrency(_ *testing.T) {
+	cfg := ratelimit.Config{
+		RequestsPerMinute: 1000,
+		RequestsPerHour:   10000,
+		CleanupInterval:   time.Hour,
+		SkipPaths:         []string{},
+		MaxVisitors:       100,
+	}
+	rl := ratelimit.New(cfg)
+	defer rl.Stop()
+
+	// Test concurrent access from same IP
+	var wg sync.WaitGroup
+	const goroutines = 100
+	const requestsPerGoroutine = 10
+
+	for i := 0; i < goroutines; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			for j := 0; j < requestsPerGoroutine; j++ {
+				rl.Allow("192.168.1.1")
+			}
+		}()
+	}
+
+	wg.Wait()
+	// Reaching this point only rules out deadlocks and panics; run with -race for data races
+}
+
+func TestDefaultConfig(t *testing.T) {
+	cfg := ratelimit.DefaultConfig()
+
+	if cfg.RequestsPerMinute != 60 {
+		t.Errorf("expected RequestsPerMinute to be 60, got %d", cfg.RequestsPerMinute)
+	}
+
+	if cfg.RequestsPerHour != 1000 {
+		t.Errorf("expected RequestsPerHour to be 1000, got %d", cfg.RequestsPerHour)
+	}
+
+	if cfg.CleanupInterval != 10*time.Minute {
+		t.Errorf("expected CleanupInterval to be 10 minutes, got %v", cfg.CleanupInterval)
+	}
+
+	if cfg.MaxVisitors != 100000 {
+		t.Errorf("expected MaxVisitors to be 100000, got %d", cfg.MaxVisitors)
+	}
+
+	// Check skip paths
+	expectedSkipPaths := map[string]bool{
+		"/health":  true,
+		"/ping":    true,
+		"/metrics": true,
+	}
+
+	for _, path := range cfg.SkipPaths {
+		if !expectedSkipPaths[path] {
+			t.Errorf("unexpected skip path: %s", path)
+		}
+		delete(expectedSkipPaths, path)
+	}
+
+	if len(expectedSkipPaths) > 0 {
+		t.Errorf("missing skip paths: %v", expectedSkipPaths)
+	}
+}
+
+func TestOnRateLimitedCallback(t *testing.T) {
+	var callbackCount int
+	var lastIP string
+
+	cfg := ratelimit.Config{
+		RequestsPerMinute: 1,
+		RequestsPerHour:   100,
+		CleanupInterval:   time.Hour,
+		SkipPaths:         []string{},
+		MaxVisitors:       1000,
+		OnRateLimited: func(ip string) {
+			callbackCount++
+			lastIP = ip
+		},
+	}
+	rl := ratelimit.New(cfg)
+	defer rl.Stop()
+
+	handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		w.WriteHeader(http.StatusOK)
+	})
+
+	middleware := rl.Middleware(handler)
+
+	// First request should succeed, no callback
+	req := httptest.NewRequest(http.MethodGet, "/test", nil)
+	req.RemoteAddr = "192.168.1.100:12345"
+	w := httptest.NewRecorder()
+	middleware.ServeHTTP(w, req)
+
+	if callbackCount != 0 {
+		t.Errorf("callback should not be called on allowed request, got %d calls", callbackCount)
+	}
+
+	// Second request should be blocked, callback should fire
+	req2 := httptest.NewRequest(http.MethodGet, "/test", nil)
+	req2.RemoteAddr = "192.168.1.100:12346"
+	w2 := httptest.NewRecorder()
+	middleware.ServeHTTP(w2, req2)
+
+	if w2.Code != http.StatusTooManyRequests {
+		t.Errorf("expected status %d, got %d", http.StatusTooManyRequests, w2.Code)
+	}
+
+	if callbackCount != 1 {
+		t.Errorf("callback should be called once, got %d calls", callbackCount)
+	}
+
+	if lastIP != "192.168.1.100" {
+		t.Errorf("expected IP 192.168.1.100, got %s", lastIP)
+	}
+
+	// Third request also blocked, callback should fire again
+	req3 := httptest.NewRequest(http.MethodGet, "/test", nil)
+	req3.RemoteAddr = "192.168.1.100:12347"
+	w3 := httptest.NewRecorder()
+	middleware.ServeHTTP(w3, req3)
+
+	if callbackCount != 2 {
+		t.Errorf("callback should be called twice, got %d calls", callbackCount)
+	}
+}
diff --git a/internal/api/server.go b/internal/api/server.go
index 496282b6..08ca7a89 100644
--- a/internal/api/server.go
+++ b/internal/api/server.go
@@ -11,6 +11,7 @@ import (
 	"github.com/rs/cors"
 
 	v0 "github.com/modelcontextprotocol/registry/internal/api/handlers/v0"
+	"github.com/modelcontextprotocol/registry/internal/api/ratelimit"
 	"github.com/modelcontextprotocol/registry/internal/api/router"
 	"github.com/modelcontextprotocol/registry/internal/config"
 	"github.com/modelcontextprotocol/registry/internal/service"
@@ -37,10 +38,11 @@ func TrailingSlashMiddleware(next http.Handler) http.Handler {
 
 // Server represents the HTTP server
 type Server struct {
-	config   *config.Config
-	registry service.RegistryService
-	humaAPI  huma.API
-	server   *http.Server
+	config      *config.Config
+	registry    service.RegistryService
+	humaAPI     huma.API
+	server      *http.Server
+	rateLimiter *ratelimit.RateLimiter
 }
 
 // NewServer creates a new HTTP server
@@ -67,13 +69,42 @@ func NewServer(cfg *config.Config, registryService service.RegistryService, metr
 	})
 
 	// Wrap the mux with middleware stack
-	// Order: TrailingSlash -> CORS -> Mux
-	handler := TrailingSlashMiddleware(corsHandler.Handler(mux))
+	// Order: TrailingSlash -> CORS -> RateLimit -> Mux
+	// CORS wraps the rate limiter so that 429 responses still carry the CORS headers
+	// browsers need in order to read them, and so that preflight requests the CORS
+	// handler answers itself do not consume rate-limit tokens.
+	var handler http.Handler = mux
+
+	// Initialize rate limiter if enabled.
+	// Note: Rate limits are enforced per-pod, so in multi-replica deployments the
+	// effective limits are approximate (e.g., 2 replicas = up to 2x the configured rate).
+	var rateLimiter *ratelimit.RateLimiter
+	if cfg.RateLimitEnabled {
+		rateLimitConfig := ratelimit.Config{
+			RequestsPerMinute: cfg.RateLimitRequestsPerMinute,
+			RequestsPerHour:   cfg.RateLimitRequestsPerHour,
+			CleanupInterval:   10 * time.Minute,
+			SkipPaths:         []string{"/health", "/ping", "/metrics"},
+			MaxVisitors:       100000,
+			OnRateLimited: func(_ string) {
+				if metrics != nil {
+					metrics.RateLimitedRequests.Add(context.Background(), 1)
+				}
+			},
+		}
+		rateLimiter = ratelimit.New(rateLimitConfig)
+		handler = rateLimiter.Middleware(handler)
+		log.Printf("Rate limiting enabled: %d req/min, %d req/hour per IP",
+			cfg.RateLimitRequestsPerMinute, cfg.RateLimitRequestsPerHour)
+	}
+
+	handler = TrailingSlashMiddleware(corsHandler.Handler(handler))
 
 	server := &Server{
-		config:   cfg,
-		registry: registryService,
-		humaAPI:  api,
+		config:      cfg,
+		registry:    registryService,
+		humaAPI:     api,
+		rateLimiter: rateLimiter,
 		server: &http.Server{
 			Addr:    cfg.ServerAddress,
 			Handler: handler,
@@ -92,5 +123,9 @@ func (s *Server) Start() error {
 
 // Shutdown gracefully shuts down the server
 func (s *Server) Shutdown(ctx context.Context) error {
+	// Stop rate limiter cleanup goroutine
+	if s.rateLimiter != nil {
+		s.rateLimiter.Stop()
+	}
 	return s.server.Shutdown(ctx)
 }
diff --git a/internal/config/config.go b/internal/config/config.go
index 60e9a067..0b49ed19 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -24,6 +24,11 @@ type Config struct {
 	OIDCExtraClaims  string `env:"OIDC_EXTRA_CLAIMS" envDefault:""`
 	OIDCEditPerms    string `env:"OIDC_EDIT_PERMISSIONS" envDefault:""`
 	OIDCPublishPerms string `env:"OIDC_PUBLISH_PERMISSIONS" envDefault:""`
+
+	// Rate Limiting Configuration
+	RateLimitEnabled           bool `env:"RATE_LIMIT_ENABLED" envDefault:"true"`
+	RateLimitRequestsPerMinute int  `env:"RATE_LIMIT_REQUESTS_PER_MINUTE" envDefault:"60"`
+	RateLimitRequestsPerHour   int  `env:"RATE_LIMIT_REQUESTS_PER_HOUR" envDefault:"1000"`
 }
 
 // NewConfig creates a new configuration with default values
diff --git a/internal/telemetry/metrics.go b/internal/telemetry/metrics.go
index 8b24d653..829488da 100644
--- a/internal/telemetry/metrics.go
+++ b/internal/telemetry/metrics.go
@@ -32,6 +32,9 @@ type Metrics struct {
 
 	// Up tracks the health of the service
 	Up metric.Int64Gauge
+
+	// RateLimitedRequests tracks requests blocked by rate limiting
+	RateLimitedRequests metric.Int64Counter
 }
 
 // ShutdownFunc is a delegate that shuts down the OpenTelemetry components.
@@ -73,11 +76,20 @@ func NewMetrics(meter metric.Meter) (*Metrics, error) {
 		return nil, fmt.Errorf("failed to create service up gauge: %w", err)
 	}
 
+	rateLimited, err := meter.Int64Counter(
+		Namespace+".http.rate_limited",
+		metric.WithDescription("Total number of requests blocked by rate limiting"),
+	)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create rate limited counter: %w", err)
+	}
+
 	return &Metrics{
-		Requests:        req,
-		RequestDuration: reqDuration,
-		ErrorCount:      errCount,
-		Up:              up,
+		Requests:            req,
+		RequestDuration:     reqDuration,
+		ErrorCount:          errCount,
+		Up:                  up,
+		RateLimitedRequests: rateLimited,
 	}, nil
 }
 