@@ -737,7 +737,6 @@ func TestInvalidApiResponse(t *testing.T) {
737737 var apiStub = httptest .NewServer (http .HandlerFunc (apiHandlerInvalid ))
738738
739739 cfg := createTesterConfig ()
740- fmt .Println (apiStub .URL )
741740 cfg .API = apiStub .URL + "/{ip}"
742741 cfg .Countries = append (cfg .Countries , "CH" )
743742
@@ -770,7 +769,6 @@ func TestApiResponseTimeoutAllowed(t *testing.T) {
770769 var apiStub = httptest .NewServer (http .HandlerFunc (apiTimeout ))
771770
772771 cfg := createTesterConfig ()
773- fmt .Println (apiStub .URL )
774772 cfg .API = apiStub .URL + "/{ip}"
775773 cfg .Countries = append (cfg .Countries , "CH" )
776774 cfg .APITimeoutMs = 5
@@ -805,7 +803,6 @@ func TestApiResponseTimeoutNotAllowed(t *testing.T) {
805803 var apiStub = httptest .NewServer (http .HandlerFunc (apiTimeout ))
806804
807805 cfg := createTesterConfig ()
808- fmt .Println (apiStub .URL )
809806 cfg .API = apiStub .URL + "/{ip}"
810807 cfg .Countries = append (cfg .Countries , "CH" )
811808 cfg .APITimeoutMs = 5
@@ -863,6 +860,41 @@ func TestExplicitlyAllowedIP(t *testing.T) {
863860 assertStatusCode (t , recorder .Result (), http .StatusOK )
864861}
865862
863+ func TestExplicitlyAllowedIPWithIPCountryHeader (t * testing.T ) {
864+ // set up our fake api server
865+ apiHandler := & CountryCodeHandler {ResponseCountryCode : "CA" }
866+ var apiStub = httptest .NewServer (apiHandler )
867+
868+ cfg := createTesterConfig ()
869+ cfg .API = apiStub .URL + "/{ip}"
870+ cfg .Countries = append (cfg .Countries , "CH" )
871+ cfg .AllowedIPAddresses = append (cfg .AllowedIPAddresses , caExampleIP )
872+ cfg .LogLocalRequests = true
873+ cfg .AddCountryHeader = true
874+
875+ ctx := context .Background ()
876+ next := http .HandlerFunc (func (_ http.ResponseWriter , _ * http.Request ) {})
877+
878+ handler , err := geoblock .New (ctx , next , cfg , "GeoBlock" )
879+ if err != nil {
880+ t .Fatal (err )
881+ }
882+
883+ recorder := httptest .NewRecorder ()
884+
885+ req , err := http .NewRequestWithContext (ctx , http .MethodGet , "http://localhost" , nil )
886+ if err != nil {
887+ t .Fatal (err )
888+ }
889+
890+ req .Header .Add (xForwardedFor , caExampleIP )
891+
892+ handler .ServeHTTP (recorder , req )
893+
894+ assertStatusCode (t , recorder .Result (), http .StatusOK )
895+ assertRequestHeader (t , req , CountryHeader , "CA" )
896+ }
897+
866898func TestExplicitlyAllowedIPNoMatch (t * testing.T ) {
867899 cfg := createTesterConfig ()
868900 cfg .Countries = append (cfg .Countries , "CA" )
@@ -1156,8 +1188,6 @@ func assertStatusCode(t *testing.T, req *http.Response, expected int) {
11561188func assertRequestHeader (t * testing.T , req * http.Request , key string , expected string ) {
11571189 t .Helper ()
11581190
1159- fmt .Println (req .Header .Get (key ))
1160-
11611191 if received := req .Header .Get (key ); received != expected {
11621192 t .Errorf ("header value mismatch: %s: %s <> %s" , key , expected , received )
11631193 }
0 commit comments