Skip to content

Commit 3e7ad90

Browse files
authored
add standard http middleware support (#113)
1 parent d77d780 commit 3e7ad90

File tree

2 files changed

+161
-0
lines changed

2 files changed

+161
-0
lines changed

tollbooth.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,3 +346,30 @@ func LimitHandler(lmt *limiter.Limiter, next http.Handler) http.Handler {
346346
func LimitFuncHandler(lmt *limiter.Limiter, nextFunc func(http.ResponseWriter, *http.Request)) http.Handler {
347347
return LimitHandler(lmt, http.HandlerFunc(nextFunc))
348348
}
349+
350+
// HTTPMiddleware wraps http.Handler with tollbooth limiter
351+
func HTTPMiddleware(lmt *limiter.Limiter) func(http.Handler) http.Handler {
352+
// // set IP lookup only if not set
353+
if lmt.GetIPLookup().Name == "" {
354+
lmt.SetIPLookup(limiter.IPLookup{Name: "RemoteAddr"})
355+
}
356+
357+
return func(next http.Handler) http.Handler {
358+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
359+
select {
360+
case <-r.Context().Done():
361+
http.Error(w, "Context was canceled", http.StatusServiceUnavailable)
362+
return
363+
default:
364+
if httpError := LimitByRequest(lmt, w, r); httpError != nil {
365+
lmt.ExecOnLimitReached(w, r)
366+
w.Header().Add("Content-Type", lmt.GetMessageContentType())
367+
w.WriteHeader(httpError.StatusCode)
368+
w.Write([]byte(httpError.Message)) //nolint:gosec // not much we can do here with failed write
369+
return
370+
}
371+
next.ServeHTTP(w, r)
372+
}
373+
})
374+
}
375+
}

tollbooth_test.go

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -653,3 +653,137 @@ func TestLimitHandlerEmptyHeader(t *testing.T) {
653653

654654
wg.Wait() // Block until go func is done.
655655
}
656+
657+
func TestHTTPMiddleware(t *testing.T) {
658+
t.Run("basic request", func(t *testing.T) {
659+
lmt := NewLimiter(1, nil)
660+
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
661+
w.WriteHeader(http.StatusOK)
662+
})
663+
wrapped := HTTPMiddleware(lmt)(handler)
664+
w := httptest.NewRecorder()
665+
r := httptest.NewRequest(http.MethodGet, "/test", nil)
666+
r.RemoteAddr = "127.0.0.1:12345"
667+
wrapped.ServeHTTP(w, r)
668+
if w.Code != http.StatusOK {
669+
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
670+
}
671+
})
672+
673+
t.Run("rate limit exceeded", func(t *testing.T) {
674+
lmt := NewLimiter(0.1, nil) // only allow one request per 10 seconds
675+
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
676+
w.WriteHeader(http.StatusOK)
677+
})
678+
wrapped := HTTPMiddleware(lmt)(handler)
679+
680+
// first request
681+
w1 := httptest.NewRecorder()
682+
r1 := httptest.NewRequest(http.MethodGet, "/test", nil)
683+
r1.RemoteAddr = "127.0.0.1:12345"
684+
wrapped.ServeHTTP(w1, r1)
685+
if w1.Code != http.StatusOK {
686+
t.Errorf("first request: expected status %d, got %d", http.StatusOK, w1.Code)
687+
}
688+
689+
// immediate second request should fail
690+
w2 := httptest.NewRecorder()
691+
r2 := httptest.NewRequest(http.MethodGet, "/test", nil)
692+
r2.RemoteAddr = "127.0.0.1:12345"
693+
wrapped.ServeHTTP(w2, r2)
694+
if w2.Code != http.StatusTooManyRequests {
695+
t.Errorf("second request: expected status %d, got %d", http.StatusTooManyRequests, w2.Code)
696+
}
697+
if !strings.Contains(w2.Body.String(), "maximum request limit") {
698+
t.Errorf("expected error message containing 'maximum request limit', got %q", w2.Body.String())
699+
}
700+
})
701+
702+
t.Run("context cancelled", func(t *testing.T) {
703+
lmt := NewLimiter(1, nil)
704+
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
705+
w.WriteHeader(http.StatusOK)
706+
})
707+
wrapped := HTTPMiddleware(lmt)(handler)
708+
w := httptest.NewRecorder()
709+
r := httptest.NewRequest(http.MethodGet, "/test", nil)
710+
ctx, cancel := context.WithCancel(r.Context())
711+
cancel()
712+
r = r.WithContext(ctx)
713+
wrapped.ServeHTTP(w, r)
714+
if w.Code != http.StatusServiceUnavailable {
715+
t.Errorf("expected status %d, got %d", http.StatusServiceUnavailable, w.Code)
716+
}
717+
if !strings.Contains(w.Body.String(), "Context was canceled") {
718+
t.Errorf("expected error message containing 'Context was canceled', got %q", w.Body.String())
719+
}
720+
})
721+
722+
t.Run("custom error handler", func(t *testing.T) {
723+
lmt := NewLimiter(0.1, nil) // only allow one request per 10 seconds
724+
customMsg := "custom limit reached"
725+
lmt.SetMessage(customMsg)
726+
727+
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
728+
w.WriteHeader(http.StatusOK)
729+
})
730+
wrapped := HTTPMiddleware(lmt)(handler)
731+
732+
// first request
733+
w1 := httptest.NewRecorder()
734+
r1 := httptest.NewRequest(http.MethodGet, "/test", nil)
735+
r1.RemoteAddr = "127.0.0.1:12345"
736+
wrapped.ServeHTTP(w1, r1)
737+
if w1.Code != http.StatusOK {
738+
t.Errorf("first request: expected status %d, got %d", http.StatusOK, w1.Code)
739+
}
740+
741+
// immediate second request should fail
742+
w2 := httptest.NewRecorder()
743+
r2 := httptest.NewRequest(http.MethodGet, "/test", nil)
744+
r2.RemoteAddr = "127.0.0.1:12345"
745+
wrapped.ServeHTTP(w2, r2)
746+
if w2.Code != http.StatusTooManyRequests {
747+
t.Errorf("second request: expected status %d, got %d", http.StatusTooManyRequests, w2.Code)
748+
}
749+
if !strings.Contains(w2.Body.String(), customMsg) {
750+
t.Errorf("expected error message containing %q, got %q", customMsg, w2.Body.String())
751+
}
752+
})
753+
754+
t.Run("custom IP lookup", func(t *testing.T) {
755+
lmt := NewLimiter(0.1, nil)
756+
lmt.SetIPLookup(limiter.IPLookup{Name: "X-Real-IP"})
757+
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
758+
w.WriteHeader(http.StatusOK)
759+
})
760+
wrapped := HTTPMiddleware(lmt)(handler)
761+
762+
// first request with IP1
763+
w1 := httptest.NewRecorder()
764+
r1 := httptest.NewRequest(http.MethodGet, "/test", nil)
765+
r1.Header.Set("X-Real-IP", "5.5.5.5")
766+
wrapped.ServeHTTP(w1, r1)
767+
if w1.Code != http.StatusOK {
768+
t.Errorf("first request: expected status %d, got %d", http.StatusOK, w1.Code)
769+
}
770+
771+
// second request with IP1 should fail
772+
w2 := httptest.NewRecorder()
773+
r2 := httptest.NewRequest(http.MethodGet, "/test", nil)
774+
r2.Header.Set("X-Real-IP", "5.5.5.5")
775+
wrapped.ServeHTTP(w2, r2)
776+
if w2.Code != http.StatusTooManyRequests {
777+
t.Errorf("second request: expected status %d, got %d", http.StatusTooManyRequests, w2.Code)
778+
}
779+
780+
// request with IP2 should pass
781+
w3 := httptest.NewRecorder()
782+
r3 := httptest.NewRequest(http.MethodGet, "/test", nil)
783+
r3.Header.Set("X-Real-IP", "6.6.6.6")
784+
wrapped.ServeHTTP(w3, r3)
785+
if w3.Code != http.StatusOK {
786+
t.Errorf("third request: expected status %d, got %d", http.StatusOK, w3.Code)
787+
}
788+
})
789+
}

0 commit comments

Comments
 (0)