Skip to content

Commit f4169f8

Browse files
maciej-tatarskiPascalMinder
authored andcommitted
add option to IgnoreAPIFailures
1 parent a45bf42 commit f4169f8

File tree

2 files changed

+83
-7
lines changed

2 files changed

+83
-7
lines changed

geoblock.go

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ type Config struct {
3838
API string `yaml:"api"`
3939
APITimeoutMs int `yaml:"apiTimeoutMs"`
4040
IgnoreAPITimeout bool `yaml:"ignoreApiTimeout"`
41+
IgnoreAPIFailures bool `yaml:"ignoreApiFailures"`
4142
IPGeolocationHTTPHeaderField string `yaml:"ipGeolocationHttpHeaderField"`
4243
XForwardedForReverseProxy bool `yaml:"xForwardedForReverseProxy"`
4344
CacheSize int `yaml:"cacheSize"`
@@ -74,6 +75,7 @@ type GeoBlock struct {
7475
apiURI string
7576
apiTimeoutMs int
7677
ignoreAPITimeout bool
78+
ignoreAPIFailures bool
7779
iPGeolocationHTTPHeaderField string
7880
xForwardedForReverseProxy bool
7981
forceMonthlyUpdate bool
@@ -162,6 +164,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
162164
apiURI: config.API,
163165
apiTimeoutMs: config.APITimeoutMs,
164166
ignoreAPITimeout: config.IgnoreAPITimeout,
167+
ignoreAPIFailures: config.IgnoreAPIFailures,
165168
iPGeolocationHTTPHeaderField: config.IPGeolocationHTTPHeaderField,
166169
xForwardedForReverseProxy: config.XForwardedForReverseProxy,
167170
forceMonthlyUpdate: config.ForceMonthlyUpdate,
@@ -270,19 +273,26 @@ func (a *GeoBlock) allowDenyIPAddress(requestIPAddr *net.IP, req *http.Request)
270273

271274
func (a *GeoBlock) allowDenyCachedRequestIP(requestIPAddr *net.IP, req *http.Request) (bool, string) {
272275
ipAddressString := requestIPAddr.String()
273-
cacheEntry, ok := a.database.Get(ipAddressString)
276+
cacheEntry, cacheHit := a.database.Get(ipAddressString)
274277

275278
var entry ipEntry
276279
var err error
277-
if !ok {
280+
if !cacheHit {
278281
entry, err = a.createNewIPEntry(req, ipAddressString)
279-
if err != nil && !(os.IsTimeout(err) && a.ignoreAPITimeout) {
282+
if err != nil {
283+
if a.ignoreAPIFailures {
284+
a.infoLogger.Printf("%s: request allowed [%s] due to API failure", a.name, requestIPAddr)
285+
return true, ""
286+
}
287+
288+
if os.IsTimeout(err) && a.ignoreAPITimeout {
289+
a.infoLogger.Printf("%s: request allowed [%s] due to API timeout", a.name, requestIPAddr)
290+
// TODO: this was previously an immediate response to the client
291+
return true, ""
292+
}
293+
280294
a.infoLogger.Printf("%s: request denied [%s] due to error: %s", a.name, requestIPAddr, err)
281295
return false, ""
282-
} else if os.IsTimeout(err) && a.ignoreAPITimeout {
283-
a.infoLogger.Printf("%s: request allowed [%s] due to API timeout", a.name, requestIPAddr)
284-
// TODO: this was previously an immediate response to the client
285-
return true, ""
286296
}
287297
} else {
288298
entry = cacheEntry.(ipEntry)
@@ -296,6 +306,10 @@ func (a *GeoBlock) allowDenyCachedRequestIP(requestIPAddr *net.IP, req *http.Req
296306
if time.Since(entry.Timestamp).Hours() >= numberOfHoursInMonth && a.forceMonthlyUpdate {
297307
entry, err = a.createNewIPEntry(req, ipAddressString)
298308
if err != nil {
309+
if a.ignoreAPIFailures {
310+
a.infoLogger.Printf("%s: request allowed [%s] due to API failure", a.name, requestIPAddr)
311+
return true, ""
312+
}
299313
a.infoLogger.Printf("%s: request denied [%s] due to error: %s", a.name, requestIPAddr, err)
300314
return false, ""
301315
}
@@ -464,6 +478,10 @@ func (a *GeoBlock) callGeoJS(ipAddress string) (string, error) {
464478
return "", err
465479
}
466480

481+
if res.StatusCode != http.StatusOK {
482+
return "", fmt.Errorf("API response status code: %d", res.StatusCode)
483+
}
484+
467485
if res.Body != nil {
468486
defer res.Body.Close()
469487
}

geoblock_test.go

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1316,6 +1316,64 @@ func TestTimeoutOnApiResponse_AllowWhenIgnoreTimeoutTrue(t *testing.T) {
13161316
assertStatusCode(t, rec.Result(), http.StatusOK)
13171317
}
13181318

1319+
func TestErrorOnApiResponse_AllowWhenIgnoreAPIFailuresTrue(t *testing.T) {
1320+
// Stub server that fails to respond correctly.
1321+
apiStub := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
1322+
w.WriteHeader(http.StatusInternalServerError)
1323+
}))
1324+
defer apiStub.Close()
1325+
1326+
cfg := createTesterConfig()
1327+
cfg.API = apiStub.URL + "/{ip}"
1328+
cfg.Countries = append(cfg.Countries, "CH")
1329+
cfg.IgnoreAPIFailures = true // API failures should ALLOW
1330+
1331+
ctx := context.Background()
1332+
next := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {})
1333+
1334+
handler, err := geoblock.New(ctx, next, cfg, "GeoBlock")
1335+
if err != nil {
1336+
t.Fatal(err)
1337+
}
1338+
1339+
rec := httptest.NewRecorder()
1340+
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
1341+
req.Header.Add(xForwardedFor, chExampleIP)
1342+
1343+
handler.ServeHTTP(rec, req)
1344+
1345+
assertStatusCode(t, rec.Result(), http.StatusOK)
1346+
}
1347+
1348+
func TestErrorOnApiResponse_AllowWhenIgnoreAPIFailuresFalse(t *testing.T) {
1349+
// Stub server that fails to respond correctly.
1350+
apiStub := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
1351+
w.WriteHeader(http.StatusInternalServerError)
1352+
}))
1353+
defer apiStub.Close()
1354+
1355+
cfg := createTesterConfig()
1356+
cfg.API = apiStub.URL + "/{ip}"
1357+
cfg.Countries = append(cfg.Countries, "CH")
1358+
cfg.IgnoreAPIFailures = false // API failures should DENY
1359+
1360+
ctx := context.Background()
1361+
next := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {})
1362+
1363+
handler, err := geoblock.New(ctx, next, cfg, "GeoBlock")
1364+
if err != nil {
1365+
t.Fatal(err)
1366+
}
1367+
1368+
rec := httptest.NewRecorder()
1369+
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
1370+
req.Header.Add(xForwardedFor, chExampleIP)
1371+
1372+
handler.ServeHTTP(rec, req)
1373+
1374+
assertStatusCode(t, rec.Result(), http.StatusForbidden)
1375+
}
1376+
13191377
func assertStatusCode(t *testing.T, req *http.Response, expected int) {
13201378
t.Helper()
13211379

0 commit comments

Comments
 (0)