Skip to content

Commit 4914d98

Browse files
tadasantclaude
andcommitted
Address security review feedback
- Add TrustProxy config option (defaults to false for security) - When false, only RemoteAddr is used (prevents IP spoofing) - When true, X-Forwarded-For and X-Real-IP headers are trusted - Add IP validation using net.ParseIP() to handle invalid/malformed IPs - Add MaxVisitors config to prevent memory exhaustion attacks - Implement LRU-style eviction when MaxVisitors limit is reached - Optimize lock granularity with read locks for existing visitors - Add comprehensive tests for: - TrustProxy enabled/disabled behavior - IP spoofing prevention - Invalid IP handling - IPv6 support - Memory limit enforcement - Concurrent access 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 91e9d56 commit 4914d98

File tree

5 files changed

+328
-41
lines changed

5 files changed

+328
-41
lines changed

.env.example

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,7 @@ MCP_REGISTRY_RATE_LIMIT_ENABLED=true
4646
MCP_REGISTRY_RATE_LIMIT_REQUESTS_PER_MINUTE=60
4747
# Maximum requests per hour per IP address (default: 1000)
4848
MCP_REGISTRY_RATE_LIMIT_REQUESTS_PER_HOUR=1000
49+
# Trust X-Forwarded-For and X-Real-IP headers from reverse proxy (default: false)
50+
# Only enable this when running behind a trusted reverse proxy (nginx, cloud load balancer)
51+
# WARNING: Setting this to true without a trusted proxy allows IP spoofing
52+
MCP_REGISTRY_RATE_LIMIT_TRUST_PROXY=false

internal/api/ratelimit/ratelimit.go

Lines changed: 119 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package ratelimit
33

44
import (
55
"encoding/json"
6+
"log"
67
"net"
78
"net/http"
89
"strings"
@@ -22,6 +23,13 @@ type Config struct {
2223
CleanupInterval time.Duration
2324
// SkipPaths are paths that should not be rate limited
2425
SkipPaths []string
26+
// TrustProxy determines if X-Forwarded-For and X-Real-IP headers should be trusted.
27+
// Set to true only when running behind a trusted reverse proxy (e.g., nginx, cloud load balancer).
28+
// When false, only the direct connection IP (RemoteAddr) is used, preventing IP spoofing.
29+
TrustProxy bool
30+
// MaxVisitors is the maximum number of visitor entries to track (memory protection).
31+
// When exceeded, oldest entries are evicted. Default: 100000.
32+
MaxVisitors int
2533
}
2634

2735
// DefaultConfig returns the default rate limiting configuration
@@ -31,6 +39,8 @@ func DefaultConfig() Config {
3139
RequestsPerHour: 1000,
3240
CleanupInterval: 10 * time.Minute,
3341
SkipPaths: []string{"/health", "/ping", "/metrics"},
42+
TrustProxy: false, // Secure default: don't trust proxy headers
43+
MaxVisitors: 100000,
3444
}
3545
}
3646

@@ -51,6 +61,10 @@ type RateLimiter struct {
5161

5262
// New creates a new RateLimiter with the given configuration
5363
func New(cfg Config) *RateLimiter {
64+
if cfg.MaxVisitors <= 0 {
65+
cfg.MaxVisitors = 100000
66+
}
67+
5468
rl := &RateLimiter{
5569
config: cfg,
5670
visitors: make(map[string]*visitor),
@@ -96,29 +110,66 @@ func (rl *RateLimiter) cleanup() {
96110
}
97111
}
98112

99-
// getVisitor returns the visitor for the given IP, creating one if necessary
113+
// evictOldestLocked removes the oldest visitor entry. Must be called with lock held.
114+
func (rl *RateLimiter) evictOldestLocked() {
115+
var oldestIP string
116+
var oldestTime time.Time
117+
118+
for ip, v := range rl.visitors {
119+
if oldestIP == "" || v.lastSeen.Before(oldestTime) {
120+
oldestIP = ip
121+
oldestTime = v.lastSeen
122+
}
123+
}
124+
125+
if oldestIP != "" {
126+
delete(rl.visitors, oldestIP)
127+
}
128+
}
129+
130+
// getVisitor returns the visitor for the given IP, creating one if necessary.
131+
// Implements memory protection by evicting oldest entries when MaxVisitors is reached.
100132
func (rl *RateLimiter) getVisitor(ip string) *visitor {
133+
// Try read lock first for existing visitors (common case)
134+
rl.mu.RLock()
135+
v, exists := rl.visitors[ip]
136+
rl.mu.RUnlock()
137+
138+
if exists {
139+
// Update timestamp - this is a minor race but acceptable for lastSeen
140+
v.lastSeen = time.Now()
141+
return v
142+
}
143+
144+
// Need to create new visitor - acquire write lock
101145
rl.mu.Lock()
102146
defer rl.mu.Unlock()
103147

104-
v, exists := rl.visitors[ip]
105-
if !exists {
106-
// Create rate limiters:
107-
// - Minute limiter: allows RequestsPerMinute requests per minute with burst of same
108-
// - Hour limiter: allows RequestsPerHour requests per hour with burst of same
109-
minuteRate := rate.Limit(float64(rl.config.RequestsPerMinute) / 60.0) // requests per second
110-
hourRate := rate.Limit(float64(rl.config.RequestsPerHour) / 3600.0) // requests per second
111-
112-
v = &visitor{
113-
minuteLimiter: rate.NewLimiter(minuteRate, rl.config.RequestsPerMinute),
114-
hourLimiter: rate.NewLimiter(hourRate, rl.config.RequestsPerHour),
115-
lastSeen: time.Now(),
116-
}
117-
rl.visitors[ip] = v
118-
} else {
148+
// Double-check after acquiring write lock
149+
v, exists = rl.visitors[ip]
150+
if exists {
119151
v.lastSeen = time.Now()
152+
return v
153+
}
154+
155+
// Enforce max visitors limit (memory protection)
156+
if len(rl.visitors) >= rl.config.MaxVisitors {
157+
rl.evictOldestLocked()
120158
}
121159

160+
// Create rate limiters:
161+
// - Minute limiter: allows RequestsPerMinute requests per minute with burst of same
162+
// - Hour limiter: allows RequestsPerHour requests per hour with burst of same
163+
minuteRate := rate.Limit(float64(rl.config.RequestsPerMinute) / 60.0) // requests per second
164+
hourRate := rate.Limit(float64(rl.config.RequestsPerHour) / 3600.0) // requests per second
165+
166+
v = &visitor{
167+
minuteLimiter: rate.NewLimiter(minuteRate, rl.config.RequestsPerMinute),
168+
hourLimiter: rate.NewLimiter(hourRate, rl.config.RequestsPerHour),
169+
lastSeen: time.Now(),
170+
}
171+
rl.visitors[ip] = v
172+
122173
return v
123174
}
124175

@@ -146,33 +197,63 @@ func (rl *RateLimiter) shouldSkip(path string) bool {
146197
return false
147198
}
148199

149-
// getClientIP extracts the client IP from the request
150-
// It considers X-Forwarded-For and X-Real-IP headers for reverse proxy scenarios
151-
func getClientIP(r *http.Request) string {
152-
// Check X-Forwarded-For header (can contain multiple IPs)
153-
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
154-
// Take the first IP (original client)
155-
if idx := strings.Index(xff, ","); idx != -1 {
156-
xff = xff[:idx]
157-
}
158-
xff = strings.TrimSpace(xff)
159-
if xff != "" {
160-
return xff
200+
// getClientIP extracts the client IP from the request.
201+
// When TrustProxy is true, it considers X-Forwarded-For and X-Real-IP headers.
202+
// When TrustProxy is false, only RemoteAddr is used to prevent IP spoofing.
203+
func (rl *RateLimiter) getClientIP(r *http.Request) string {
204+
// Only trust proxy headers if explicitly configured
205+
if rl.config.TrustProxy {
206+
// Check X-Forwarded-For header (can contain multiple IPs)
207+
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
208+
// Take the first IP (original client)
209+
if idx := strings.Index(xff, ","); idx != -1 {
210+
xff = xff[:idx]
211+
}
212+
xff = strings.TrimSpace(xff)
213+
if ip := validateAndNormalizeIP(xff); ip != "" {
214+
return ip
215+
}
161216
}
162-
}
163217

164-
// Check X-Real-IP header
165-
if xri := r.Header.Get("X-Real-IP"); xri != "" {
166-
return strings.TrimSpace(xri)
218+
// Check X-Real-IP header
219+
if xri := r.Header.Get("X-Real-IP"); xri != "" {
220+
if ip := validateAndNormalizeIP(strings.TrimSpace(xri)); ip != "" {
221+
return ip
222+
}
223+
}
167224
}
168225

169-
// Fall back to RemoteAddr
226+
// Fall back to RemoteAddr (always used when TrustProxy is false)
170227
ip, _, err := net.SplitHostPort(r.RemoteAddr)
171228
if err != nil {
172229
// RemoteAddr might not have a port
173-
return r.RemoteAddr
230+
ip = r.RemoteAddr
174231
}
175-
return ip
232+
233+
// Validate and normalize the IP
234+
if validIP := validateAndNormalizeIP(ip); validIP != "" {
235+
return validIP
236+
}
237+
238+
// If all else fails, use a fallback that won't cause issues
239+
return "unknown"
240+
}
241+
242+
// validateAndNormalizeIP validates the IP string and returns a normalized form.
243+
// Returns empty string if the IP is invalid.
244+
func validateAndNormalizeIP(ip string) string {
245+
if ip == "" {
246+
return ""
247+
}
248+
249+
// Parse the IP to validate it
250+
parsedIP := net.ParseIP(ip)
251+
if parsedIP == nil {
252+
return ""
253+
}
254+
255+
// Return normalized string representation
256+
return parsedIP.String()
176257
}
177258

178259
// Middleware returns an HTTP middleware that enforces rate limiting
@@ -184,7 +265,7 @@ func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
184265
return
185266
}
186267

187-
ip := getClientIP(r)
268+
ip := rl.getClientIP(r)
188269

189270
if !rl.Allow(ip) {
190271
w.Header().Set("Content-Type", "application/problem+json")
@@ -199,7 +280,8 @@ func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
199280

200281
jsonData, err := json.Marshal(errorBody)
201282
if err != nil {
202-
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
283+
log.Printf("Failed to marshal rate limit error response: %v", err)
284+
_, _ = w.Write([]byte(`{"title":"Too Many Requests","status":429}`))
203285
return
204286
}
205287
_, _ = w.Write(jsonData)

0 commit comments

Comments
 (0)