Skip to content

Commit a829633

Browse files
committed
Add new configuration option to allow specific ip addresses
1 parent c95af23 commit a829633

File tree

3 files changed

+88
-0
lines changed

3 files changed

+88
-0
lines changed

geoblock.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ type Config struct {
4141
UnknownCountryAPIResponse string `yaml:"unknownCountryApiResponse"`
4242
BlackListMode bool `yaml:"blacklist"`
4343
Countries []string `yaml:"countries,omitempty"`
44+
AllowedIPAddresses []string `yaml:"allowedIPAddresses,omitempty"`
4445
}
4546

4647
type ipEntry struct {
@@ -67,6 +68,7 @@ type GeoBlock struct {
6768
unknownCountryCode string
6869
blackListMode bool
6970
countries []string
71+
allowedIPAddresses []net.IP
7072
privateIPRanges []*net.IPNet
7173
database *lru.LRUCache
7274
name string
@@ -86,6 +88,15 @@ func New(_ context.Context, next http.Handler, config *Config, name string) (htt
8688
config.APITimeoutMs = 750
8789
}
8890

91+
allowedIPAddresses := make([]net.IP, len(config.AllowedIPAddresses))
92+
for idx, ipAddressEntry := range config.AllowedIPAddresses {
93+
ipAddress := net.ParseIP(ipAddressEntry)
94+
if ipAddress == nil {
95+
infoLogger.Fatal("Invalid ip address given!")
96+
}
97+
allowedIPAddresses[idx] = ipAddress
98+
}
99+
89100
infoLogger.SetOutput(os.Stdout)
90101

91102
infoLogger.Printf("allow local IPs: %t", config.AllowLocalRequests)
@@ -119,6 +130,7 @@ func New(_ context.Context, next http.Handler, config *Config, name string) (htt
119130
unknownCountryCode: config.UnknownCountryAPIResponse,
120131
blackListMode: config.BlackListMode,
121132
countries: config.Countries,
133+
allowedIPAddresses: allowedIPAddresses,
122134
privateIPRanges: initPrivateIPBlocks(),
123135
database: cache,
124136
name: name,
@@ -155,6 +167,14 @@ func (a *GeoBlock) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
155167
return
156168
}
157169

170+
if ipInSlice(*ipAddress, a.allowedIPAddresses) {
171+
if a.logLocalRequests {
172+
infoLogger.Println("Allow explicitly allowed ip: ", ipAddress)
173+
}
174+
a.next.ServeHTTP(rw, req)
175+
return
176+
}
177+
158178
cacheEntry, ok := a.database.Get(ipAddressString)
159179

160180
if !ok {
@@ -307,6 +327,15 @@ func stringInSlice(a string, list []string) bool {
307327
return false
308328
}
309329

330+
func ipInSlice(a net.IP, list []net.IP) bool {
331+
for _, b := range list {
332+
if b.Equal(a) {
333+
return true
334+
}
335+
}
336+
return false
337+
}
338+
310339
func parseIP(addr string) (net.IP, error) {
311340
ipAddress := net.ParseIP(addr)
312341

geoblock_test.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,61 @@ func TestInvalidApiResponse(t *testing.T) {
392392
assertStatusCode(t, recorder.Result(), http.StatusForbidden)
393393
}
394394

395+
func TestExplicitlyAllowedIP(t *testing.T) {
396+
cfg := createTesterConfig()
397+
cfg.Countries = append(cfg.Countries, "CH")
398+
cfg.AllowedIPAddresses = append(cfg.AllowedIPAddresses, caExampleIP)
399+
cfg.LogLocalRequests = true
400+
401+
ctx := context.Background()
402+
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {})
403+
404+
handler, err := geoblock.New(ctx, next, cfg, "GeoBlock")
405+
if err != nil {
406+
t.Fatal(err)
407+
}
408+
409+
recorder := httptest.NewRecorder()
410+
411+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost", nil)
412+
if err != nil {
413+
t.Fatal(err)
414+
}
415+
416+
req.Header.Add(xForwardedFor, caExampleIP)
417+
418+
handler.ServeHTTP(recorder, req)
419+
420+
assertStatusCode(t, recorder.Result(), http.StatusOK)
421+
}
422+
423+
func TestExplicitlyAllowedIPNoMatch(t *testing.T) {
424+
cfg := createTesterConfig()
425+
cfg.Countries = append(cfg.Countries, "CA")
426+
cfg.AllowedIPAddresses = append(cfg.AllowedIPAddresses, caExampleIP)
427+
428+
ctx := context.Background()
429+
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {})
430+
431+
handler, err := geoblock.New(ctx, next, cfg, "GeoBlock")
432+
if err != nil {
433+
t.Fatal(err)
434+
}
435+
436+
recorder := httptest.NewRecorder()
437+
438+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost", nil)
439+
if err != nil {
440+
t.Fatal(err)
441+
}
442+
443+
req.Header.Add(xForwardedFor, chExampleIP)
444+
445+
handler.ServeHTTP(recorder, req)
446+
447+
assertStatusCode(t, recorder.Result(), http.StatusForbidden)
448+
}
449+
395450
func assertStatusCode(t *testing.T, req *http.Response, expected int) {
396451
t.Helper()
397452

readme.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,3 +479,7 @@ When set to `true` the filter logic is inverted, i.e. requests originating from
479479
### Countries `countries`
480480

481481
A list of country codes from which connections to the service should be allowed. Logic can be inverted by using the [`blackListMode`](#back-list-mode-blacklistmode).
482+
483+
### Allowed IP addresses `allowedIPAddresses`
484+
485+
A list of explicitly allowed IP addresses. IP addresses added to this list will always be allowed.

0 commit comments

Comments
 (0)