Skip to content

Commit 44d7bf7

Browse files
committed
fix: improve robustness and error handling across codebase
Security and stability improvements: - Fix potential panic in isIPBlacklisted: use netip.ParseAddr instead of MustParseAddr - Fix type assertion panic in processRuleMatch: use getLogID helper function - Add HTTP client timeout (30s) in tor.go to prevent hanging requests - Improve extractIP to handle empty/malformed input gracefully - Improve getClientIP to validate X-Forwarded-For IPs before use Input validation improvements: - Add comprehensive validation in Validate() method for: - Negative anomaly_threshold - Invalid rate limit configuration - Negative max_request_body_size - Negative log_buffer - Add validation in NewRateLimiter for requests, window, cleanup_interval All tests pass with race detector enabled.
1 parent 05d7f8a commit 44d7bf7

File tree

6 files changed

+70
-6
lines changed

6 files changed

+70
-6
lines changed

blacklist.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,17 @@ func (m *Middleware) isIPBlacklisted(addr string) bool {
6868
ip := extractIP(addr)
6969

7070
if m.ipBlacklist == nil {
71-
m.logger.Error("blacklist", zap.String("IP blacklist", "is nil"))
71+
m.logger.Debug("IP blacklist not initialized, skipping check")
72+
return false
73+
}
74+
75+
parsedIP, err := netip.ParseAddr(ip)
76+
if err != nil {
77+
m.logger.Debug("Failed to parse IP address for blacklist check", zap.String("ip", ip), zap.Error(err))
78+
return false
7279
}
7380

74-
if m.ipBlacklist.Contains(netip.MustParseAddr(ip)) {
81+
if m.ipBlacklist.Contains(parsedIP) {
7582
m.muIPBlacklistMetrics.Lock() // Acquire lock before accessing shared counter
7683
m.IPBlacklistBlockCount++ // Increment the counter
7784
m.muIPBlacklistMetrics.Unlock() // Release lock after accessing counter

caddywaf.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -608,5 +608,31 @@ func (m *Middleware) Validate() error {
608608
if m.logLevel == 0 {
609609
m.logLevel = zapcore.InfoLevel // Default log level
610610
}
611+
612+
// Validate anomaly threshold
613+
if m.AnomalyThreshold < 0 {
614+
return fmt.Errorf("anomaly_threshold cannot be negative: %d", m.AnomalyThreshold)
615+
}
616+
617+
// Validate rate limit configuration if enabled
618+
if m.RateLimit.Requests > 0 {
619+
if m.RateLimit.Window <= 0 {
620+
return fmt.Errorf("rate_limit window must be positive when rate limiting is enabled")
621+
}
622+
if m.RateLimit.CleanupInterval <= 0 {
623+
return fmt.Errorf("rate_limit cleanup_interval must be positive when rate limiting is enabled")
624+
}
625+
}
626+
627+
// Validate max request body size
628+
if m.MaxRequestBodySize < 0 {
629+
return fmt.Errorf("max_request_body_size cannot be negative: %d", m.MaxRequestBodySize)
630+
}
631+
632+
// Validate log buffer
633+
if m.LogBuffer < 0 {
634+
return fmt.Errorf("log_buffer cannot be negative: %d", m.LogBuffer)
635+
}
636+
611637
return nil
612638
}

helpers.go

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,20 +39,35 @@ func appendCIDR(ip string) string {
3939
}
4040

4141
// extractIP extracts the IP address from a remote address string.
42+
// Returns the original input if parsing fails, which allows upstream
43+
// code to handle invalid IPs gracefully.
4244
func extractIP(remoteAddr string) string {
45+
if remoteAddr == "" {
46+
return ""
47+
}
4348
host, _, err := net.SplitHostPort(remoteAddr)
4449
if err != nil {
45-
return remoteAddr // Assume the input is already an IP address
50+
// Could be an IP without port, validate it
51+
if ip := net.ParseIP(remoteAddr); ip != nil {
52+
return remoteAddr
53+
}
54+
// Return as-is for upstream handling
55+
return remoteAddr
4656
}
4757
return host
4858
}
4959

5060
// getClientIP returns the real client IP, checking X-Forwarded-For header first.
61+
// Falls back to RemoteAddr if X-Forwarded-For is empty or contains invalid data.
5162
func getClientIP(r *http.Request) string {
5263
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
5364
ips := strings.Split(xff, ",")
5465
if len(ips) > 0 {
55-
return strings.TrimSpace(ips[0])
66+
clientIP := strings.TrimSpace(ips[0])
67+
// Validate the IP before returning
68+
if clientIP != "" && net.ParseIP(extractIP(clientIP)) != nil {
69+
return clientIP
70+
}
5671
}
5772
}
5873
return r.RemoteAddr

ratelimiter.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,17 @@ type RateLimiter struct {
3737

3838
// NewRateLimiter creates a new RateLimiter instance.
3939
func NewRateLimiter(config RateLimit) (*RateLimiter, error) {
40+
// Validate configuration
41+
if config.Requests <= 0 {
42+
return nil, fmt.Errorf("rate limit requests must be positive, got %d", config.Requests)
43+
}
44+
if config.Window <= 0 {
45+
return nil, fmt.Errorf("rate limit window must be positive, got %v", config.Window)
46+
}
47+
if config.CleanupInterval <= 0 {
48+
return nil, fmt.Errorf("rate limit cleanup interval must be positive, got %v", config.CleanupInterval)
49+
}
50+
4051
// Compile path regexes if paths are provided
4152
if len(config.Paths) > 0 {
4253
config.PathRegexes = make([]*regexp.Regexp, len(config.Paths))

rules.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import (
1515
)
1616

1717
func (m *Middleware) processRuleMatch(w http.ResponseWriter, r *http.Request, rule *Rule, target, value string, state *WAFState) bool {
18-
logID := r.Context().Value(ContextKeyLogId("logID")).(string)
18+
logID := getLogID(r.Context())
1919

2020
redactedValue := m.requestValueExtractor.RedactValueIfSensitive(target, value)
2121

tor.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,12 @@ func (t *TorConfig) updateTorExitNodes() error {
4848
url = t.CustomTORExitNodeURL
4949
}
5050

51-
resp, err := http.Get(url)
51+
// Create HTTP client with timeout to avoid hanging requests
52+
client := &http.Client{
53+
Timeout: 30 * time.Second,
54+
}
55+
56+
resp, err := client.Get(url)
5257
if err != nil {
5358
return fmt.Errorf("http get failed for %s: %w", url, err) // Improved error message with URL
5459
}

0 commit comments

Comments
 (0)