Skip to content

Commit f923e4e

Browse files
committed
Allow adding ip address ranges
1 parent a829633 commit f923e4e

File tree

3 files changed

+143
-4
lines changed

3 files changed

+143
-4
lines changed

geoblock.go

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ type GeoBlock struct {
6969
blackListMode bool
7070
countries []string
7171
allowedIPAddresses []net.IP
72+
allowedIPRanges []*net.IPNet
7273
privateIPRanges []*net.IPNet
7374
database *lru.LRUCache
7475
name string
@@ -88,13 +89,21 @@ func New(_ context.Context, next http.Handler, config *Config, name string) (htt
8889
config.APITimeoutMs = 750
8990
}
9091

91-
allowedIPAddresses := make([]net.IP, len(config.AllowedIPAddresses))
92-
for idx, ipAddressEntry := range config.AllowedIPAddresses {
92+
var allowedIPAddresses []net.IP
93+
var allowedIPRanges []*net.IPNet
94+
for _, ipAddressEntry := range config.AllowedIPAddresses {
95+
ip, ipBlock, err := net.ParseCIDR(ipAddressEntry)
96+
if err == nil {
97+
allowedIPAddresses = append(allowedIPAddresses, ip)
98+
allowedIPRanges = append(allowedIPRanges, ipBlock)
99+
continue
100+
}
101+
93102
ipAddress := net.ParseIP(ipAddressEntry)
94103
if ipAddress == nil {
95104
infoLogger.Fatal("Invalid ip address given!")
96105
}
97-
allowedIPAddresses[idx] = ipAddress
106+
allowedIPAddresses = append(allowedIPAddresses, ipAddress)
98107
}
99108

100109
infoLogger.SetOutput(os.Stdout)
@@ -131,6 +140,7 @@ func New(_ context.Context, next http.Handler, config *Config, name string) (htt
131140
blackListMode: config.BlackListMode,
132141
countries: config.Countries,
133142
allowedIPAddresses: allowedIPAddresses,
143+
allowedIPRanges: allowedIPRanges,
134144
privateIPRanges: initPrivateIPBlocks(),
135145
database: cache,
136146
name: name,
@@ -175,6 +185,16 @@ func (a *GeoBlock) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
175185
return
176186
}
177187

188+
for _, ipRange := range a.allowedIPRanges {
189+
if ipRange.Contains(*ipAddress) {
190+
if a.logLocalRequests {
191+
infoLogger.Println("Allow explicitly allowed ip: ", ipAddress)
192+
}
193+
a.next.ServeHTTP(rw, req)
194+
return
195+
}
196+
}
197+
178198
cacheEntry, ok := a.database.Get(ipAddressString)
179199

180200
if !ok {

geoblock_test.go

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,118 @@ func TestExplicitlyAllowedIPNoMatch(t *testing.T) {
447447
assertStatusCode(t, recorder.Result(), http.StatusForbidden)
448448
}
449449

450+
func TestExplicitlyAllowedIPRangeIPV6(t *testing.T) {
451+
cfg := createTesterConfig()
452+
cfg.Countries = append(cfg.Countries, "CA")
453+
cfg.AllowedIPAddresses = append(cfg.AllowedIPAddresses, "2a00:00c0:2:3::567:8001/128")
454+
cfg.AllowedIPAddresses = append(cfg.AllowedIPAddresses, "8.8.8.8")
455+
456+
ctx := context.Background()
457+
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {})
458+
459+
handler, err := geoblock.New(ctx, next, cfg, "GeoBlock")
460+
if err != nil {
461+
t.Fatal(err)
462+
}
463+
464+
recorder := httptest.NewRecorder()
465+
466+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost", nil)
467+
if err != nil {
468+
t.Fatal(err)
469+
}
470+
471+
req.Header.Add(xForwardedFor, "2a00:00c0:2:3::567:8001")
472+
473+
handler.ServeHTTP(recorder, req)
474+
475+
assertStatusCode(t, recorder.Result(), http.StatusOK)
476+
}
477+
478+
func TestExplicitlyAllowedIPRangeIPV6NoMatch(t *testing.T) {
479+
cfg := createTesterConfig()
480+
cfg.Countries = append(cfg.Countries, "CA")
481+
cfg.AllowedIPAddresses = append(cfg.AllowedIPAddresses, "2a00:00c0:2:3::567:8001/128")
482+
cfg.AllowedIPAddresses = append(cfg.AllowedIPAddresses, "8.8.8.8")
483+
484+
ctx := context.Background()
485+
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {})
486+
487+
handler, err := geoblock.New(ctx, next, cfg, "GeoBlock")
488+
if err != nil {
489+
t.Fatal(err)
490+
}
491+
492+
recorder := httptest.NewRecorder()
493+
494+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost", nil)
495+
if err != nil {
496+
t.Fatal(err)
497+
}
498+
499+
req.Header.Add(xForwardedFor, "2a00:00c0:2:3::567:8002")
500+
501+
handler.ServeHTTP(recorder, req)
502+
503+
assertStatusCode(t, recorder.Result(), http.StatusForbidden)
504+
}
505+
506+
func TestExplicitlyAllowedIPRangeIPV4(t *testing.T) {
507+
cfg := createTesterConfig()
508+
cfg.Countries = append(cfg.Countries, "CA")
509+
cfg.AllowedIPAddresses = append(cfg.AllowedIPAddresses, "178.90.234.0/27")
510+
cfg.AllowedIPAddresses = append(cfg.AllowedIPAddresses, "8.8.8.8")
511+
512+
ctx := context.Background()
513+
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {})
514+
515+
handler, err := geoblock.New(ctx, next, cfg, "GeoBlock")
516+
if err != nil {
517+
t.Fatal(err)
518+
}
519+
520+
recorder := httptest.NewRecorder()
521+
522+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost", nil)
523+
if err != nil {
524+
t.Fatal(err)
525+
}
526+
527+
req.Header.Add(xForwardedFor, "178.90.234.30")
528+
529+
handler.ServeHTTP(recorder, req)
530+
531+
assertStatusCode(t, recorder.Result(), http.StatusOK)
532+
}
533+
534+
func TestExplicitlyAllowedIPRangeIPV4NoMatch(t *testing.T) {
535+
cfg := createTesterConfig()
536+
cfg.Countries = append(cfg.Countries, "CA")
537+
cfg.AllowedIPAddresses = append(cfg.AllowedIPAddresses, "178.90.234.0/27")
538+
cfg.AllowedIPAddresses = append(cfg.AllowedIPAddresses, "8.8.8.8")
539+
540+
ctx := context.Background()
541+
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {})
542+
543+
handler, err := geoblock.New(ctx, next, cfg, "GeoBlock")
544+
if err != nil {
545+
t.Fatal(err)
546+
}
547+
548+
recorder := httptest.NewRecorder()
549+
550+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost", nil)
551+
if err != nil {
552+
t.Fatal(err)
553+
}
554+
555+
req.Header.Add(xForwardedFor, "178.90.234.55")
556+
557+
handler.ServeHTTP(recorder, req)
558+
559+
assertStatusCode(t, recorder.Result(), http.StatusForbidden)
560+
}
561+
450562
func assertStatusCode(t *testing.T, req *http.Response, expected int) {
451563
t.Helper()
452564

readme.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -482,4 +482,11 @@ A list of country codes from which connections to the service should be allowed.
482482

483483
### Allowed IP addresses `allowedIPAddresses`
484484

485-
A list of explicitly allowed IP addresses. IP addresses added to this list will always be allowed.
485+
A list of explicitly allowed IP addresses or IP address ranges. IP addresses and ranges added to this list will always be allowed.
486+
487+
```yaml
488+
allowedIPAddresses:
489+
- 192.0.2.10 # single IPv4 address
490+
- 203.0.113.0/24 # IPv4 range in CIDR format
491+
- 2001:db8:1234:/48 # IPv6 range in CIDR format
492+
```

0 commit comments

Comments
 (0)