@@ -653,3 +653,137 @@ func TestLimitHandlerEmptyHeader(t *testing.T) {
653653
654654 wg .Wait () // Block until go func is done.
655655}
656+
657+ func TestHTTPMiddleware (t * testing.T ) {
658+ t .Run ("basic request" , func (t * testing.T ) {
659+ lmt := NewLimiter (1 , nil )
660+ handler := http .HandlerFunc (func (w http.ResponseWriter , _ * http.Request ) {
661+ w .WriteHeader (http .StatusOK )
662+ })
663+ wrapped := HTTPMiddleware (lmt )(handler )
664+ w := httptest .NewRecorder ()
665+ r := httptest .NewRequest (http .MethodGet , "/test" , nil )
666+ r .RemoteAddr = "127.0.0.1:12345"
667+ wrapped .ServeHTTP (w , r )
668+ if w .Code != http .StatusOK {
669+ t .Errorf ("expected status %d, got %d" , http .StatusOK , w .Code )
670+ }
671+ })
672+
673+ t .Run ("rate limit exceeded" , func (t * testing.T ) {
674+ lmt := NewLimiter (0.1 , nil ) // only allow one request per 10 seconds
675+ handler := http .HandlerFunc (func (w http.ResponseWriter , _ * http.Request ) {
676+ w .WriteHeader (http .StatusOK )
677+ })
678+ wrapped := HTTPMiddleware (lmt )(handler )
679+
680+ // first request
681+ w1 := httptest .NewRecorder ()
682+ r1 := httptest .NewRequest (http .MethodGet , "/test" , nil )
683+ r1 .RemoteAddr = "127.0.0.1:12345"
684+ wrapped .ServeHTTP (w1 , r1 )
685+ if w1 .Code != http .StatusOK {
686+ t .Errorf ("first request: expected status %d, got %d" , http .StatusOK , w1 .Code )
687+ }
688+
689+ // immediate second request should fail
690+ w2 := httptest .NewRecorder ()
691+ r2 := httptest .NewRequest (http .MethodGet , "/test" , nil )
692+ r2 .RemoteAddr = "127.0.0.1:12345"
693+ wrapped .ServeHTTP (w2 , r2 )
694+ if w2 .Code != http .StatusTooManyRequests {
695+ t .Errorf ("second request: expected status %d, got %d" , http .StatusTooManyRequests , w2 .Code )
696+ }
697+ if ! strings .Contains (w2 .Body .String (), "maximum request limit" ) {
698+ t .Errorf ("expected error message containing 'maximum request limit', got %q" , w2 .Body .String ())
699+ }
700+ })
701+
702+ t .Run ("context cancelled" , func (t * testing.T ) {
703+ lmt := NewLimiter (1 , nil )
704+ handler := http .HandlerFunc (func (w http.ResponseWriter , _ * http.Request ) {
705+ w .WriteHeader (http .StatusOK )
706+ })
707+ wrapped := HTTPMiddleware (lmt )(handler )
708+ w := httptest .NewRecorder ()
709+ r := httptest .NewRequest (http .MethodGet , "/test" , nil )
710+ ctx , cancel := context .WithCancel (r .Context ())
711+ cancel ()
712+ r = r .WithContext (ctx )
713+ wrapped .ServeHTTP (w , r )
714+ if w .Code != http .StatusServiceUnavailable {
715+ t .Errorf ("expected status %d, got %d" , http .StatusServiceUnavailable , w .Code )
716+ }
717+ if ! strings .Contains (w .Body .String (), "Context was canceled" ) {
718+ t .Errorf ("expected error message containing 'Context was canceled', got %q" , w .Body .String ())
719+ }
720+ })
721+
722+ t .Run ("custom error handler" , func (t * testing.T ) {
723+ lmt := NewLimiter (0.1 , nil ) // only allow one request per 10 seconds
724+ customMsg := "custom limit reached"
725+ lmt .SetMessage (customMsg )
726+
727+ handler := http .HandlerFunc (func (w http.ResponseWriter , _ * http.Request ) {
728+ w .WriteHeader (http .StatusOK )
729+ })
730+ wrapped := HTTPMiddleware (lmt )(handler )
731+
732+ // first request
733+ w1 := httptest .NewRecorder ()
734+ r1 := httptest .NewRequest (http .MethodGet , "/test" , nil )
735+ r1 .RemoteAddr = "127.0.0.1:12345"
736+ wrapped .ServeHTTP (w1 , r1 )
737+ if w1 .Code != http .StatusOK {
738+ t .Errorf ("first request: expected status %d, got %d" , http .StatusOK , w1 .Code )
739+ }
740+
741+ // immediate second request should fail
742+ w2 := httptest .NewRecorder ()
743+ r2 := httptest .NewRequest (http .MethodGet , "/test" , nil )
744+ r2 .RemoteAddr = "127.0.0.1:12345"
745+ wrapped .ServeHTTP (w2 , r2 )
746+ if w2 .Code != http .StatusTooManyRequests {
747+ t .Errorf ("second request: expected status %d, got %d" , http .StatusTooManyRequests , w2 .Code )
748+ }
749+ if ! strings .Contains (w2 .Body .String (), customMsg ) {
750+ t .Errorf ("expected error message containing %q, got %q" , customMsg , w2 .Body .String ())
751+ }
752+ })
753+
754+ t .Run ("custom IP lookup" , func (t * testing.T ) {
755+ lmt := NewLimiter (0.1 , nil )
756+ lmt .SetIPLookup (limiter.IPLookup {Name : "X-Real-IP" })
757+ handler := http .HandlerFunc (func (w http.ResponseWriter , _ * http.Request ) {
758+ w .WriteHeader (http .StatusOK )
759+ })
760+ wrapped := HTTPMiddleware (lmt )(handler )
761+
762+ // first request with IP1
763+ w1 := httptest .NewRecorder ()
764+ r1 := httptest .NewRequest (http .MethodGet , "/test" , nil )
765+ r1 .Header .Set ("X-Real-IP" , "5.5.5.5" )
766+ wrapped .ServeHTTP (w1 , r1 )
767+ if w1 .Code != http .StatusOK {
768+ t .Errorf ("first request: expected status %d, got %d" , http .StatusOK , w1 .Code )
769+ }
770+
771+ // second request with IP1 should fail
772+ w2 := httptest .NewRecorder ()
773+ r2 := httptest .NewRequest (http .MethodGet , "/test" , nil )
774+ r2 .Header .Set ("X-Real-IP" , "5.5.5.5" )
775+ wrapped .ServeHTTP (w2 , r2 )
776+ if w2 .Code != http .StatusTooManyRequests {
777+ t .Errorf ("second request: expected status %d, got %d" , http .StatusTooManyRequests , w2 .Code )
778+ }
779+
780+ // request with IP2 should pass
781+ w3 := httptest .NewRecorder ()
782+ r3 := httptest .NewRequest (http .MethodGet , "/test" , nil )
783+ r3 .Header .Set ("X-Real-IP" , "6.6.6.6" )
784+ wrapped .ServeHTTP (w3 , r3 )
785+ if w3 .Code != http .StatusOK {
786+ t .Errorf ("third request: expected status %d, got %d" , http .StatusOK , w3 .Code )
787+ }
788+ })
789+ }
0 commit comments