Skip to content

Commit c2fd6a3

Browse files
Merge pull request #83 from Bestigor89/fix/xff-country-asn-blocking
fix: use X-Forwarded-For for country/ASN blocking behind proxies
2 parents cf45542 + 4c18165 commit c2fd6a3

File tree

3 files changed

+74
-4
lines changed

3 files changed

+74
-4
lines changed

handler.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,8 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
328328
// Whitelisting
329329
if m.CountryWhitelist.Enabled {
330330
m.logger.Debug("Starting country whitelisting phase")
331-
allowed, err := m.isCountryInList(r.RemoteAddr, m.CountryWhitelist.CountryList, m.CountryWhitelist.geoIP)
331+
clientIP := getClientIP(r)
332+
allowed, err := m.isCountryInList(clientIP, m.CountryWhitelist.CountryList, m.CountryWhitelist.geoIP)
332333
if err != nil {
333334
m.logRequest(zapcore.ErrorLevel, "Failed to check country whitelist",
334335
r,
@@ -360,7 +361,8 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
360361
// ASN Blocking
361362
if m.BlockASNs.Enabled {
362363
m.logger.Debug("Starting ASN blocking phase")
363-
blocked, err := m.geoIPHandler.IsASNInList(r.RemoteAddr, m.BlockASNs.BlockedASNs, m.BlockASNs.geoIP)
364+
clientIP := getClientIP(r)
365+
blocked, err := m.geoIPHandler.IsASNInList(clientIP, m.BlockASNs.BlockedASNs, m.BlockASNs.geoIP)
364366
if err != nil {
365367
m.logRequest(zapcore.ErrorLevel, "Failed to check ASN blocking",
366368
r,
@@ -377,7 +379,7 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
377379
return
378380
}
379381
} else if blocked {
380-
asnInfo := m.geoIPHandler.GetASN(r.RemoteAddr, m.BlockASNs.geoIP)
382+
asnInfo := m.geoIPHandler.GetASN(clientIP, m.BlockASNs.geoIP)
381383
m.blockRequest(w, r, state, http.StatusForbidden, "asn_block", "asn_block_rule",
382384
zap.String("message", "Request blocked by ASN"),
383385
zap.String("asn", asnInfo),
@@ -394,7 +396,8 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
394396
// Blacklisting
395397
if m.CountryBlacklist.Enabled {
396398
m.logger.Debug("Starting country blacklisting phase")
397-
blocked, err := m.isCountryInList(r.RemoteAddr, m.CountryBlacklist.CountryList, m.CountryBlacklist.geoIP)
399+
clientIP := getClientIP(r)
400+
blocked, err := m.isCountryInList(clientIP, m.CountryBlacklist.CountryList, m.CountryBlacklist.geoIP)
398401
if err != nil {
399402
m.logRequest(zapcore.ErrorLevel, "Failed to check country blacklisting",
400403
r,

helpers.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package caddywaf
22

33
import (
44
"net"
5+
"net/http"
56
"os"
67
"strings"
78
)
@@ -45,3 +46,14 @@ func extractIP(remoteAddr string) string {
4546
}
4647
return host
4748
}
49+
50+
// getClientIP returns the real client IP, checking X-Forwarded-For header first.
51+
func getClientIP(r *http.Request) string {
52+
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
53+
ips := strings.Split(xff, ",")
54+
if len(ips) > 0 {
55+
return strings.TrimSpace(ips[0])
56+
}
57+
}
58+
return r.RemoteAddr
59+
}

helpers_test.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package caddywaf
22

33
import (
4+
"net/http"
45
"os"
56
"testing"
67
)
@@ -48,3 +49,57 @@ func TestFileExists(t *testing.T) {
4849
})
4950
}
5051
}
52+
53+
func TestGetClientIP(t *testing.T) {
54+
tests := []struct {
55+
name string
56+
remoteAddr string
57+
xff string
58+
want string
59+
}{
60+
{
61+
name: "no X-Forwarded-For, use RemoteAddr",
62+
remoteAddr: "192.168.1.1:12345",
63+
xff: "",
64+
want: "192.168.1.1:12345",
65+
},
66+
{
67+
name: "single IP in X-Forwarded-For",
68+
remoteAddr: "10.0.0.1:12345",
69+
xff: "203.0.113.50",
70+
want: "203.0.113.50",
71+
},
72+
{
73+
name: "multiple IPs in X-Forwarded-For",
74+
remoteAddr: "10.0.0.1:12345",
75+
xff: "203.0.113.50, 70.41.3.18, 150.172.238.178",
76+
want: "203.0.113.50",
77+
},
78+
{
79+
name: "X-Forwarded-For with spaces",
80+
remoteAddr: "10.0.0.1:12345",
81+
xff: " 203.0.113.50 , 70.41.3.18 ",
82+
want: "203.0.113.50",
83+
},
84+
{
85+
name: "IPv6 in X-Forwarded-For",
86+
remoteAddr: "10.0.0.1:12345",
87+
xff: "2001:db8::1",
88+
want: "2001:db8::1",
89+
},
90+
}
91+
92+
for _, tt := range tests {
93+
t.Run(tt.name, func(t *testing.T) {
94+
req, _ := http.NewRequest("GET", "/", nil)
95+
req.RemoteAddr = tt.remoteAddr
96+
if tt.xff != "" {
97+
req.Header.Set("X-Forwarded-For", tt.xff)
98+
}
99+
100+
if got := getClientIP(req); got != tt.want {
101+
t.Errorf("getClientIP() = %v, want %v", got, tt.want)
102+
}
103+
})
104+
}
105+
}

0 commit comments

Comments
 (0)