@@ -3,6 +3,7 @@ package ratelimit
33
44import (
55 "encoding/json"
6+ "log"
67 "net"
78 "net/http"
89 "strings"
@@ -22,6 +23,13 @@ type Config struct {
2223 CleanupInterval time.Duration
2324 // SkipPaths are paths that should not be rate limited
2425 SkipPaths []string
26+ // TrustProxy determines if X-Forwarded-For and X-Real-IP headers should be trusted.
27+ // Set to true only when running behind a trusted reverse proxy (e.g., nginx, cloud load balancer).
28+ // When false, only the direct connection IP (RemoteAddr) is used, preventing IP spoofing.
29+ TrustProxy bool
30+ // MaxVisitors is the maximum number of visitor entries to track (memory protection).
31+ // When exceeded, oldest entries are evicted. Default: 100000.
32+ MaxVisitors int
2533}
2634
2735// DefaultConfig returns the default rate limiting configuration
@@ -31,6 +39,8 @@ func DefaultConfig() Config {
3139 RequestsPerHour : 1000 ,
3240 CleanupInterval : 10 * time .Minute ,
3341 SkipPaths : []string {"/health" , "/ping" , "/metrics" },
42+ TrustProxy : false , // Secure default: don't trust proxy headers
43+ MaxVisitors : 100000 ,
3444 }
3545}
3646
@@ -51,6 +61,10 @@ type RateLimiter struct {
5161
5262// New creates a new RateLimiter with the given configuration
5363func New (cfg Config ) * RateLimiter {
64+ if cfg .MaxVisitors <= 0 {
65+ cfg .MaxVisitors = 100000
66+ }
67+
5468 rl := & RateLimiter {
5569 config : cfg ,
5670 visitors : make (map [string ]* visitor ),
@@ -96,29 +110,66 @@ func (rl *RateLimiter) cleanup() {
96110 }
97111}
98112
99- // getVisitor returns the visitor for the given IP, creating one if necessary
113+ // evictOldestLocked removes the oldest visitor entry. Must be called with lock held.
114+ func (rl * RateLimiter ) evictOldestLocked () {
115+ var oldestIP string
116+ var oldestTime time.Time
117+
118+ for ip , v := range rl .visitors {
119+ if oldestIP == "" || v .lastSeen .Before (oldestTime ) {
120+ oldestIP = ip
121+ oldestTime = v .lastSeen
122+ }
123+ }
124+
125+ if oldestIP != "" {
126+ delete (rl .visitors , oldestIP )
127+ }
128+ }
129+
130+ // getVisitor returns the visitor for the given IP, creating one if necessary.
131+ // Implements memory protection by evicting oldest entries when MaxVisitors is reached.
100132func (rl * RateLimiter ) getVisitor (ip string ) * visitor {
133+ // Try read lock first for existing visitors (common case)
134+ rl .mu .RLock ()
135+ v , exists := rl .visitors [ip ]
136+ rl .mu .RUnlock ()
137+
138+ if exists {
139+ // Update timestamp - this is a minor race but acceptable for lastSeen
140+ v .lastSeen = time .Now ()
141+ return v
142+ }
143+
144+ // Need to create new visitor - acquire write lock
101145 rl .mu .Lock ()
102146 defer rl .mu .Unlock ()
103147
104- v , exists := rl .visitors [ip ]
105- if ! exists {
106- // Create rate limiters:
107- // - Minute limiter: allows RequestsPerMinute requests per minute with burst of same
108- // - Hour limiter: allows RequestsPerHour requests per hour with burst of same
109- minuteRate := rate .Limit (float64 (rl .config .RequestsPerMinute ) / 60.0 ) // requests per second
110- hourRate := rate .Limit (float64 (rl .config .RequestsPerHour ) / 3600.0 ) // requests per second
111-
112- v = & visitor {
113- minuteLimiter : rate .NewLimiter (minuteRate , rl .config .RequestsPerMinute ),
114- hourLimiter : rate .NewLimiter (hourRate , rl .config .RequestsPerHour ),
115- lastSeen : time .Now (),
116- }
117- rl .visitors [ip ] = v
118- } else {
148+ // Double-check after acquiring write lock
149+ v , exists = rl .visitors [ip ]
150+ if exists {
119151 v .lastSeen = time .Now ()
152+ return v
153+ }
154+
155+ // Enforce max visitors limit (memory protection)
156+ if len (rl .visitors ) >= rl .config .MaxVisitors {
157+ rl .evictOldestLocked ()
120158 }
121159
160+ // Create rate limiters:
161+ // - Minute limiter: allows RequestsPerMinute requests per minute with burst of same
162+ // - Hour limiter: allows RequestsPerHour requests per hour with burst of same
163+ minuteRate := rate .Limit (float64 (rl .config .RequestsPerMinute ) / 60.0 ) // requests per second
164+ hourRate := rate .Limit (float64 (rl .config .RequestsPerHour ) / 3600.0 ) // requests per second
165+
166+ v = & visitor {
167+ minuteLimiter : rate .NewLimiter (minuteRate , rl .config .RequestsPerMinute ),
168+ hourLimiter : rate .NewLimiter (hourRate , rl .config .RequestsPerHour ),
169+ lastSeen : time .Now (),
170+ }
171+ rl .visitors [ip ] = v
172+
122173 return v
123174}
124175
@@ -146,33 +197,63 @@ func (rl *RateLimiter) shouldSkip(path string) bool {
146197 return false
147198}
148199
149- // getClientIP extracts the client IP from the request
150- // It considers X-Forwarded-For and X-Real-IP headers for reverse proxy scenarios
151- func getClientIP (r * http.Request ) string {
152- // Check X-Forwarded-For header (can contain multiple IPs)
153- if xff := r .Header .Get ("X-Forwarded-For" ); xff != "" {
154- // Take the first IP (original client)
155- if idx := strings .Index (xff , "," ); idx != - 1 {
156- xff = xff [:idx ]
157- }
158- xff = strings .TrimSpace (xff )
159- if xff != "" {
160- return xff
200+ // getClientIP extracts the client IP from the request.
201+ // When TrustProxy is true, it considers X-Forwarded-For and X-Real-IP headers.
202+ // When TrustProxy is false, only RemoteAddr is used to prevent IP spoofing.
203+ func (rl * RateLimiter ) getClientIP (r * http.Request ) string {
204+ // Only trust proxy headers if explicitly configured
205+ if rl .config .TrustProxy {
206+ // Check X-Forwarded-For header (can contain multiple IPs)
207+ if xff := r .Header .Get ("X-Forwarded-For" ); xff != "" {
208+ // Take the first IP (original client)
209+ if idx := strings .Index (xff , "," ); idx != - 1 {
210+ xff = xff [:idx ]
211+ }
212+ xff = strings .TrimSpace (xff )
213+ if ip := validateAndNormalizeIP (xff ); ip != "" {
214+ return ip
215+ }
161216 }
162- }
163217
164- // Check X-Real-IP header
165- if xri := r .Header .Get ("X-Real-IP" ); xri != "" {
166- return strings .TrimSpace (xri )
218+ // Check X-Real-IP header
219+ if xri := r .Header .Get ("X-Real-IP" ); xri != "" {
220+ if ip := validateAndNormalizeIP (strings .TrimSpace (xri )); ip != "" {
221+ return ip
222+ }
223+ }
167224 }
168225
169- // Fall back to RemoteAddr
226+ // Fall back to RemoteAddr (always used when TrustProxy is false)
170227 ip , _ , err := net .SplitHostPort (r .RemoteAddr )
171228 if err != nil {
172229 // RemoteAddr might not have a port
173- return r .RemoteAddr
230+ ip = r .RemoteAddr
174231 }
175- return ip
232+
233+ // Validate and normalize the IP
234+ if validIP := validateAndNormalizeIP (ip ); validIP != "" {
235+ return validIP
236+ }
237+
238+ // If all else fails, use a fallback that won't cause issues
239+ return "unknown"
240+ }
241+
242+ // validateAndNormalizeIP validates the IP string and returns a normalized form.
243+ // Returns empty string if the IP is invalid.
244+ func validateAndNormalizeIP (ip string ) string {
245+ if ip == "" {
246+ return ""
247+ }
248+
249+ // Parse the IP to validate it
250+ parsedIP := net .ParseIP (ip )
251+ if parsedIP == nil {
252+ return ""
253+ }
254+
255+ // Return normalized string representation
256+ return parsedIP .String ()
176257}
177258
178259// Middleware returns an HTTP middleware that enforces rate limiting
@@ -184,7 +265,7 @@ func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
184265 return
185266 }
186267
187- ip := getClientIP (r )
268+ ip := rl . getClientIP (r )
188269
189270 if ! rl .Allow (ip ) {
190271 w .Header ().Set ("Content-Type" , "application/problem+json" )
@@ -199,7 +280,8 @@ func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
199280
200281 jsonData , err := json .Marshal (errorBody )
201282 if err != nil {
202- http .Error (w , "Internal Server Error" , http .StatusInternalServerError )
283+ log .Printf ("Failed to marshal rate limit error response: %v" , err )
284+ _ , _ = w .Write ([]byte (`{"title":"Too Many Requests","status":429}` ))
203285 return
204286 }
205287 _ , _ = w .Write (jsonData )
0 commit comments