Skip to content

Commit 05e72f4

Browse files
committed
feat: Add bearer token http middleware
This middleware adds support for Bearer token authentication in your API. It extracts the Bearer token from the Authorization header and passes it to an `Authenticator` interface for verification. The implementation of the `Authenticator` interface is provided separately. Key features * Extracts Bearer token: The middleware extracts the Bearer token from the Authorization header if present. * Passes token to authenticator: It passes the extracted token to an `Authenticator` interface, which is responsible for verifying the token and returning a user object if the authentication is successful. * Handles missing or invalid tokens: The middleware handles cases where the Authorization header is missing or malformed by returning appropriate error responses. * Customizable error responses: The `errorFn` parameter allows customization of the error response returned when authentication fails. How to use it 1. Implement the `Authenticator` interface to define how the Bearer token is verified and a user object is returned. 2. Create an instance of the `BearerTokenAuthenticationMiddleware` using your `Authenticator` implementation and an `errorFn` to handle authentication errors. 3. Apply the middleware to your API routes. Additional resources * OpenAPI specification: https://swagger.io/docs/specification/v3_0/authentication/bearer-authentication/ * Google Cloud Endpoints documentation: https://cloud.google.com/endpoints/docs/openapi/authenticating-users-custom Important notes * This middleware assumes that the route requires authentication if the `BearerAuthScopes` field is set in the request context. * You are responsible for setting the `ctxKey` in the request context when authentication is needed.
1 parent 528738a commit 05e72f4

File tree

2 files changed

+314
-0
lines changed

2 files changed

+314
-0
lines changed
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package httpmiddlewares
16+
17+
import (
18+
"context"
19+
"errors"
20+
"net/http"
21+
"strings"
22+
23+
"github.com/GoogleChrome/webstatus.dev/lib/auth"
24+
)
25+
26+
var (
27+
ErrMissingAuthHeader = errors.New("missing authorization header")
28+
ErrInvalidAuthHeader = errors.New("authorization header is malformed")
29+
)
30+
31+
type authenticatedUserCtxKey struct{}
32+
33+
type BearerTokenAuthenticator interface {
34+
Authenticate(ctx context.Context, token string) (*auth.User, error)
35+
}
36+
37+
// NewBearerTokenAuthenticationMiddleware returns a middleware that can be used to authenticate requests.
38+
// It detects if a route requires authentication by checking if a field is set in the request context.
39+
// If the field is set, the middleware verifies the Authorization header and sets the authenticated user in the context.
40+
//
41+
// The errorFn parameter allows the caller to customize the error response returned when authentication fails.
42+
// This makes the middleware more generic and adaptable to different error handling requirements.
43+
//
44+
// It is the responsibility of the caller of this middleware to ensure that the `ctxKey` is set in the request context
45+
// whenever authentication is needed. This can be done using a wrapper middleware that knows about the OpenAPI
46+
// generator's security semantics.
47+
//
48+
// See https://github.com/oapi-codegen/oapi-codegen/issues/518 for details on the lack of per-endpoint middleware
49+
// support.
50+
func NewBearerTokenAuthenticationMiddleware(authenticator BearerTokenAuthenticator, ctxKey any,
51+
errorFn func(context.Context, int, http.ResponseWriter, error)) func(http.Handler) http.Handler {
52+
return func(next http.Handler) http.Handler {
53+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
54+
value := r.Context().Value(ctxKey)
55+
if value == nil {
56+
// The route does not have any security requirements set for it.
57+
next.ServeHTTP(w, r)
58+
59+
return
60+
}
61+
authHdr := r.Header.Get("Authorization")
62+
// Check for the Authorization header.
63+
if authHdr == "" {
64+
errorFn(r.Context(), http.StatusUnauthorized, w, ErrMissingAuthHeader)
65+
66+
return
67+
}
68+
prefix := "Bearer "
69+
if !strings.HasPrefix(authHdr, prefix) {
70+
errorFn(r.Context(), http.StatusUnauthorized, w, ErrInvalidAuthHeader)
71+
72+
return
73+
}
74+
75+
u, err := authenticator.Authenticate(r.Context(), strings.TrimPrefix(authHdr, prefix))
76+
if err != nil {
77+
errorFn(r.Context(), http.StatusUnauthorized, w, err)
78+
79+
return
80+
}
81+
82+
ctx := r.Context()
83+
84+
ctx = AuthenticatedUserToContext(ctx, u)
85+
86+
r = r.WithContext(ctx)
87+
88+
next.ServeHTTP(w, r)
89+
})
90+
}
91+
}
92+
93+
// AuthenticatedUserFromContext attempts to get the user from the given context.
94+
func AuthenticatedUserFromContext(ctx context.Context) (u *auth.User, ok bool) {
95+
u, ok = ctx.Value(authenticatedUserCtxKey{}).(*auth.User)
96+
97+
return
98+
}
99+
100+
// AuthenticatedUserToContext creates a new context with the user added to it.
101+
func AuthenticatedUserToContext(ctx context.Context, u *auth.User) context.Context {
102+
return context.WithValue(ctx, authenticatedUserCtxKey{}, u)
103+
}
Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package httpmiddlewares
16+
17+
import (
18+
"context"
19+
"errors"
20+
"net/http"
21+
"net/http/httptest"
22+
"reflect"
23+
"testing"
24+
25+
"github.com/GoogleChrome/webstatus.dev/lib/auth"
26+
)
27+
28+
type authCtxKey struct{}
29+
30+
func TestBearerTokenAuthenticationMiddleware(t *testing.T) {
31+
const testID = "id"
32+
tests := []struct {
33+
name string
34+
ctxKey any
35+
authHeader string
36+
mockAuthenticator func(ctx context.Context, token string) (*auth.User, error)
37+
mockErrorFn func(context.Context, int, http.ResponseWriter, error)
38+
expectedStatusCode int
39+
expectedBody string
40+
expectedUser *auth.User
41+
}{
42+
{
43+
name: "No security requirements",
44+
ctxKey: nil,
45+
authHeader: "",
46+
mockAuthenticator: func(_ context.Context, _ string) (*auth.User, error) {
47+
t.Fatal("authenticate should not have been called")
48+
49+
// nolint:nilnil // WONTFIX - should not reach this.
50+
return nil, nil
51+
},
52+
mockErrorFn: func(_ context.Context, _ int, _ http.ResponseWriter, _ error) {
53+
t.Fatal("errorFn should not have been called")
54+
},
55+
expectedStatusCode: http.StatusOK,
56+
expectedBody: "next handler was called",
57+
expectedUser: nil,
58+
},
59+
{
60+
name: "Missing Authorization header",
61+
ctxKey: authCtxKey{},
62+
authHeader: "",
63+
mockAuthenticator: func(_ context.Context, _ string) (*auth.User, error) {
64+
t.Fatal("authenticate should not have been called")
65+
66+
// nolint:nilnil // WONTFIX - should not reach this.
67+
return nil, nil
68+
},
69+
mockErrorFn: func(_ context.Context, code int, w http.ResponseWriter, err error) {
70+
if code != http.StatusUnauthorized {
71+
t.Errorf("expected status code %d, got %d", http.StatusUnauthorized, code)
72+
}
73+
if !errors.Is(err, ErrMissingAuthHeader) {
74+
t.Errorf("expected error %v, got %v", ErrMissingAuthHeader, err)
75+
}
76+
w.WriteHeader(code)
77+
},
78+
expectedStatusCode: http.StatusUnauthorized,
79+
expectedUser: nil,
80+
expectedBody: "",
81+
},
82+
{
83+
name: "Invalid Authorization header",
84+
ctxKey: authCtxKey{},
85+
authHeader: "Invalid Auth",
86+
mockAuthenticator: func(_ context.Context, _ string) (*auth.User, error) {
87+
t.Fatal("authenticate should not have been called")
88+
89+
// nolint:nilnil // WONTFIX - should not reach this.
90+
return nil, nil
91+
},
92+
mockErrorFn: func(_ context.Context, code int, w http.ResponseWriter, err error) {
93+
if code != http.StatusUnauthorized {
94+
t.Errorf("expected status code %d, got %d", http.StatusUnauthorized, code)
95+
}
96+
if !errors.Is(err, ErrInvalidAuthHeader) {
97+
t.Errorf("expected error %v, got %v", ErrInvalidAuthHeader, err)
98+
}
99+
w.WriteHeader(code)
100+
},
101+
expectedStatusCode: http.StatusUnauthorized,
102+
expectedUser: nil,
103+
expectedBody: "",
104+
},
105+
{
106+
name: "Authentication failure",
107+
ctxKey: authCtxKey{},
108+
authHeader: "Bearer my-token",
109+
mockAuthenticator: func(_ context.Context, _ string) (*auth.User, error) {
110+
return nil, errors.New("authentication failed")
111+
},
112+
mockErrorFn: func(_ context.Context, code int, w http.ResponseWriter, err error) {
113+
if code != http.StatusUnauthorized {
114+
t.Errorf("expected status code %d, got %d", http.StatusUnauthorized, code)
115+
}
116+
if err == nil || err.Error() != "authentication failed" {
117+
t.Errorf("expected error 'authentication failed', got %v", err)
118+
}
119+
w.WriteHeader(code)
120+
},
121+
expectedStatusCode: http.StatusUnauthorized,
122+
expectedUser: nil,
123+
expectedBody: "",
124+
},
125+
{
126+
name: "Successful authentication",
127+
ctxKey: authCtxKey{},
128+
authHeader: "Bearer my-token",
129+
mockAuthenticator: func(_ context.Context, token string) (*auth.User, error) {
130+
if token != "my-token" {
131+
t.Errorf("expected token 'my-token', got %s", token)
132+
}
133+
134+
return &auth.User{
135+
ID: testID,
136+
}, nil
137+
},
138+
mockErrorFn: func(_ context.Context, _ int, _ http.ResponseWriter, _ error) {
139+
t.Fatal("errorFn should not have been called")
140+
},
141+
expectedStatusCode: http.StatusOK,
142+
expectedBody: "next handler was called",
143+
expectedUser: &auth.User{
144+
ID: testID,
145+
},
146+
},
147+
}
148+
149+
for _, tc := range tests {
150+
t.Run(tc.name, func(t *testing.T) {
151+
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
152+
u, _ := AuthenticatedUserFromContext(r.Context())
153+
if !reflect.DeepEqual(u, tc.expectedUser) {
154+
t.Errorf("expected user %+v, received user %+v", tc.expectedUser, u)
155+
}
156+
_, err := w.Write([]byte("next handler was called"))
157+
if err != nil {
158+
t.Fatal(err)
159+
}
160+
})
161+
162+
middleware := NewBearerTokenAuthenticationMiddleware(
163+
&mockBearerTokenAuthenticator{tc.mockAuthenticator},
164+
tc.ctxKey,
165+
tc.mockErrorFn,
166+
)
167+
168+
handler := middleware(nextHandler)
169+
170+
req := httptest.NewRequest(http.MethodGet, "/", nil)
171+
if tc.authHeader != "" {
172+
req.Header.Set("Authorization", tc.authHeader)
173+
}
174+
if tc.ctxKey != nil {
175+
req = req.WithContext(context.WithValue(req.Context(), tc.ctxKey, "authCtxValue"))
176+
}
177+
178+
rr := httptest.NewRecorder()
179+
handler.ServeHTTP(rr, req)
180+
181+
assertStatusCode(t, rr, tc.expectedStatusCode)
182+
assertResponseBody(t, rr, tc.expectedBody)
183+
})
184+
}
185+
}
186+
187+
type mockBearerTokenAuthenticator struct {
188+
authenticateFn func(ctx context.Context, token string) (*auth.User, error)
189+
}
190+
191+
func (m *mockBearerTokenAuthenticator) Authenticate(ctx context.Context, token string) (*auth.User, error) {
192+
if m.authenticateFn == nil {
193+
panic("authenticateFn not set")
194+
}
195+
196+
return m.authenticateFn(ctx, token)
197+
}
198+
199+
func assertStatusCode(t *testing.T, rr *httptest.ResponseRecorder, expectedCode int) {
200+
t.Helper()
201+
if rr.Code != expectedCode {
202+
t.Errorf("expected status code %d, got %d", expectedCode, rr.Code)
203+
}
204+
}
205+
206+
func assertResponseBody(t *testing.T, rr *httptest.ResponseRecorder, expectedBody string) {
207+
t.Helper()
208+
if expectedBody != "" && rr.Body.String() != expectedBody {
209+
t.Errorf("expected body '%s', got '%s'", expectedBody, rr.Body.String())
210+
}
211+
}

0 commit comments

Comments
 (0)