Skip to content

Commit 7c04a3d

Browse files
committed
refactor(auth-middleware): rename functions for clarity and update comments
This commit renames several functions in the authentication middleware to improve clarity, changing `jwtVerifier` to `verifyJWT` and `apiKeyVerifier` to `verifyAPIKey`. Additionally, comments throughout the code have been updated for consistency and clarity, ensuring they accurately describe the functionality. The changes also include adjustments to the handling of user information extraction and scope checks in the MCP tools.
1 parent d26d381 commit 7c04a3d

File tree

2 files changed

+75
-89
lines changed

2 files changed

+75
-89
lines changed

examples/server/auth-middleware/README.md

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -195,29 +195,21 @@ func jwtVerifier(ctx context.Context, tokenString string) (*auth.TokenInfo, erro
195195

196196
```go
197197
// Get authentication information in MCP tool
198-
func MyTool(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[MyArgs]]) (*mcp.CallToolResultFor[struct{}], error) {
199-
// Extract authentication info from context
200-
userInfo := ctx.Value("user_info").(*auth.TokenInfo)
198+
func MyTool(ctx context.Context, req *mcp.CallToolRequest, args MyArgs) (*mcp.CallToolResult, any, error) {
199+
// Extract authentication info from request
200+
userInfo := req.Extra.TokenInfo
201201

202202
// Check scopes
203-
hasReadScope := false
204-
for _, scope := range userInfo.Scopes {
205-
if scope == "read" {
206-
hasReadScope = true
207-
break
208-
}
209-
}
210-
211-
if !hasReadScope {
212-
return nil, fmt.Errorf("insufficient permissions: read scope required")
203+
if !slices.Contains(userInfo.Scopes, "read") {
204+
return nil, nil, fmt.Errorf("insufficient permissions: read scope required")
213205
}
214206

215207
// Execute tool logic
216-
return &mcp.CallToolResultFor[struct{}]{
208+
return &mcp.CallToolResult{
217209
Content: []mcp.Content{
218210
&mcp.TextContent{Text: "Tool executed successfully"},
219211
},
220-
}, nil
212+
}, nil, nil
221213
}
222214
```
223215

examples/server/auth-middleware/main.go

Lines changed: 68 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"fmt"
1414
"log"
1515
"net/http"
16+
"slices"
1617
"strings"
1718
"time"
1819

@@ -64,7 +65,7 @@ var jwtSecret = []byte("your-secret-key")
6465
// generateToken creates a JWT token for testing purposes.
6566
// In a real application, this would be handled by your authentication service.
6667
func generateToken(userID string, scopes []string, expiresIn time.Duration) (string, error) {
67-
// Create JWT claims with user information and scopes
68+
// Create JWT claims with user information and scopes.
6869
claims := JWTClaims{
6970
UserID: userID,
7071
Scopes: scopes,
@@ -75,71 +76,71 @@ func generateToken(userID string, scopes []string, expiresIn time.Duration) (str
7576
},
7677
}
7778

78-
// Create and sign the JWT token
79+
// Create and sign the JWT token.
7980
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
8081
return token.SignedString(jwtSecret)
8182
}
8283

83-
// jwtVerifier verifies JWT tokens and returns TokenInfo for the auth middleware.
84+
// verifyJWT verifies JWT tokens and returns TokenInfo for the auth middleware.
8485
// This function implements the TokenVerifier interface required by auth.RequireBearerToken.
85-
func jwtVerifier(ctx context.Context, tokenString string) (*auth.TokenInfo, error) {
86-
// Parse and validate the JWT token
87-
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) {
88-
// Verify the signing method is HMAC
86+
func verifyJWT(ctx context.Context, tokenString string) (*auth.TokenInfo, error) {
87+
// Parse and validate the JWT token.
88+
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (any, error) {
89+
// Verify the signing method is HMAC.
8990
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
9091
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
9192
}
9293
return jwtSecret, nil
9394
})
9495

9596
if err != nil {
96-
// Return standard error for invalid tokens
97-
return nil, auth.ErrInvalidToken
97+
// Return standard error for invalid tokens.
98+
return nil, fmt.Errorf("%w: %v", auth.ErrInvalidToken, err)
9899
}
99100

100-
// Extract claims and verify token validity
101+
// Extract claims and verify token validity.
101102
if claims, ok := token.Claims.(*JWTClaims); ok && token.Valid {
102103
return &auth.TokenInfo{
103104
Scopes: claims.Scopes, // User permissions
104105
Expiration: claims.ExpiresAt.Time, // Token expiration time
105106
}, nil
106107
}
107108

108-
return nil, auth.ErrInvalidToken
109+
return nil, fmt.Errorf("%w: invalid token claims", auth.ErrInvalidToken)
109110
}
110111

111-
// apiKeyVerifier verifies API keys and returns TokenInfo for the auth middleware.
112+
// verifyAPIKey verifies API keys and returns TokenInfo for the auth middleware.
112113
// This function implements the TokenVerifier interface required by auth.RequireBearerToken.
113-
func apiKeyVerifier(ctx context.Context, apiKey string) (*auth.TokenInfo, error) {
114-
// Look up the API key in our storage
114+
func verifyAPIKey(ctx context.Context, apiKey string) (*auth.TokenInfo, error) {
115+
// Look up the API key in our storage.
115116
key, exists := apiKeys[apiKey]
116117
if !exists {
117-
return nil, auth.ErrInvalidToken
118+
return nil, fmt.Errorf("%w: API key not found", auth.ErrInvalidToken)
118119
}
119120

120-
// API keys don't expire in this example, but you could add expiration logic here
121-
// For demonstration, we set a 24-hour expiration
121+
// API keys don't expire in this example, but you could add expiration logic here.
122+
// For demonstration, we set a 24-hour expiration.
122123
return &auth.TokenInfo{
123124
Scopes: key.Scopes, // User permissions
124125
Expiration: time.Now().Add(24 * time.Hour), // 24 hour expiration
125126
}, nil
126127
}
127128

128129
// MCP Tool Arguments
129-
type GetUserInfoArgs struct {
130+
type getUserInfoArgs struct {
130131
UserID string `json:"user_id" jsonschema:"the user ID to get information for"`
131132
}
132133

133-
type CreateResourceArgs struct {
134+
type createResourceArgs struct {
134135
Name string `json:"name" jsonschema:"the name of the resource"`
135136
Description string `json:"description" jsonschema:"the description of the resource"`
136137
Content string `json:"content" jsonschema:"the content of the resource"`
137138
}
138139

139140
// SayHi is a simple MCP tool that requires authentication
140141
func SayHi(ctx context.Context, req *mcp.CallToolRequest, args struct{}) (*mcp.CallToolResult, any, error) {
141-
// Extract user information from context (set by auth middleware)
142-
userInfo := ctx.Value("user_info").(*auth.TokenInfo)
142+
// Extract user information from request (v0.3.0+)
143+
userInfo := req.Extra.TokenInfo
143144

144145
return &mcp.CallToolResult{
145146
Content: []mcp.Content{
@@ -149,30 +150,25 @@ func SayHi(ctx context.Context, req *mcp.CallToolRequest, args struct{}) (*mcp.C
149150
}
150151

151152
// GetUserInfo is an MCP tool that requires read scope
152-
func GetUserInfo(ctx context.Context, req *mcp.CallToolRequest, args GetUserInfoArgs) (*mcp.CallToolResult, any, error) {
153-
// Extract user information from context (set by auth middleware)
154-
userInfo := ctx.Value("user_info").(*auth.TokenInfo)
155-
156-
// Check if user has read scope
157-
hasReadScope := false
158-
for _, scope := range userInfo.Scopes {
159-
if scope == "read" {
160-
hasReadScope = true
161-
break
162-
}
163-
}
153+
func GetUserInfo(ctx context.Context, req *mcp.CallToolRequest, args getUserInfoArgs) (*mcp.CallToolResult, any, error) {
154+
// Extract user information from request (v0.3.0+)
155+
userInfo := req.Extra.TokenInfo
164156

165-
if !hasReadScope {
157+
// Check if user has read scope.
158+
if !slices.Contains(userInfo.Scopes, "read") {
166159
return nil, nil, fmt.Errorf("insufficient permissions: read scope required")
167160
}
168161

169-
userData := map[string]interface{}{
162+
userData := map[string]any{
170163
"requested_user_id": args.UserID,
171164
"your_scopes": userInfo.Scopes,
172165
"message": "User information retrieved successfully",
173166
}
174167

175-
userDataJSON, _ := json.Marshal(userData)
168+
userDataJSON, err := json.Marshal(userData)
169+
if err != nil {
170+
return nil, nil, fmt.Errorf("failed to marshal user data: %w", err)
171+
}
176172

177173
return &mcp.CallToolResult{
178174
Content: []mcp.Content{
@@ -182,32 +178,27 @@ func GetUserInfo(ctx context.Context, req *mcp.CallToolRequest, args GetUserInfo
182178
}
183179

184180
// CreateResource is an MCP tool that requires write scope
185-
func CreateResource(ctx context.Context, req *mcp.CallToolRequest, args CreateResourceArgs) (*mcp.CallToolResult, any, error) {
186-
// Extract user information from context (set by auth middleware)
187-
userInfo := ctx.Value("user_info").(*auth.TokenInfo)
188-
189-
// Check if user has write scope
190-
hasWriteScope := false
191-
for _, scope := range userInfo.Scopes {
192-
if scope == "write" {
193-
hasWriteScope = true
194-
break
195-
}
196-
}
181+
func CreateResource(ctx context.Context, req *mcp.CallToolRequest, args createResourceArgs) (*mcp.CallToolResult, any, error) {
182+
// Extract user information from request (v0.3.0+)
183+
userInfo := req.Extra.TokenInfo
197184

198-
if !hasWriteScope {
185+
// Check if user has write scope.
186+
if !slices.Contains(userInfo.Scopes, "write") {
199187
return nil, nil, fmt.Errorf("insufficient permissions: write scope required")
200188
}
201189

202-
resourceInfo := map[string]interface{}{
190+
resourceInfo := map[string]any{
203191
"name": args.Name,
204192
"description": args.Description,
205193
"content": args.Content,
206194
"created_by": "authenticated_user",
207195
"created_at": time.Now().Format(time.RFC3339),
208196
}
209197

210-
resourceInfoJSON, _ := json.Marshal(resourceInfo)
198+
resourceInfoJSON, err := json.Marshal(resourceInfo)
199+
if err != nil {
200+
return nil, nil, fmt.Errorf("failed to marshal resource info: %w", err)
201+
}
211202

212203
return &mcp.CallToolResult{
213204
Content: []mcp.Content{
@@ -233,7 +224,7 @@ func authMiddleware(next http.Handler) http.Handler {
233224
func createMCPServer() *mcp.Server {
234225
server := mcp.NewServer(&mcp.Implementation{Name: "authenticated-mcp-server"}, nil)
235226

236-
// Add tools that require authentication
227+
// Add tools that require authentication.
237228
mcp.AddTool(server, &mcp.Tool{
238229
Name: "say_hi",
239230
Description: "A simple greeting tool that requires authentication",
@@ -255,61 +246,61 @@ func createMCPServer() *mcp.Server {
255246
func main() {
256247
flag.Parse()
257248

258-
// Create the MCP server
249+
// Create the MCP server.
259250
server := createMCPServer()
260251

261-
// Create authentication middleware
262-
jwtAuth := auth.RequireBearerToken(jwtVerifier, &auth.RequireBearerTokenOptions{
252+
// Create authentication middleware.
253+
jwtAuth := auth.RequireBearerToken(verifyJWT, &auth.RequireBearerTokenOptions{
263254
Scopes: []string{"read"}, // Require "read" permission
264255
})
265256

266-
apiKeyAuth := auth.RequireBearerToken(apiKeyVerifier, &auth.RequireBearerTokenOptions{
257+
apiKeyAuth := auth.RequireBearerToken(verifyAPIKey, &auth.RequireBearerTokenOptions{
267258
Scopes: []string{"read"}, // Require "read" permission
268259
})
269260

270-
// Create HTTP handler with authentication
261+
// Create HTTP handler with authentication.
271262
handler := mcp.NewStreamableHTTPHandler(func(r *http.Request) *mcp.Server {
272263
return server
273264
}, nil)
274265

275-
// Apply authentication middleware to the MCP handler
266+
// Apply authentication middleware to the MCP handler.
276267
authenticatedHandler := jwtAuth(authMiddleware(handler))
277268
apiKeyHandler := apiKeyAuth(authMiddleware(handler))
278269

279-
// Create router for different authentication methods
270+
// Create router for different authentication methods.
280271
http.HandleFunc("/mcp/jwt", authenticatedHandler.ServeHTTP)
281272
http.HandleFunc("/mcp/apikey", apiKeyHandler.ServeHTTP)
282273

283-
// Add utility endpoints for token generation
274+
// Add utility endpoints for token generation.
284275
http.HandleFunc("/generate-token", func(w http.ResponseWriter, r *http.Request) {
285-
// Get user ID from query parameters (default: "test-user")
276+
// Get user ID from query parameters (default: "test-user").
286277
userID := r.URL.Query().Get("user_id")
287278
if userID == "" {
288279
userID = "test-user"
289280
}
290281

291-
// Get scopes from query parameters (default: ["read", "write"])
282+
// Get scopes from query parameters (default: ["read", "write"]).
292283
scopes := strings.Split(r.URL.Query().Get("scopes"), ",")
293284
if len(scopes) == 1 && scopes[0] == "" {
294285
scopes = []string{"read", "write"}
295286
}
296287

297-
// Get expiration time from query parameters (default: 1 hour)
288+
// Get expiration time from query parameters (default: 1 hour).
298289
expiresIn := 1 * time.Hour
299290
if expStr := r.URL.Query().Get("expires_in"); expStr != "" {
300291
if exp, err := time.ParseDuration(expStr); err == nil {
301292
expiresIn = exp
302293
}
303294
}
304295

305-
// Generate the JWT token
296+
// Generate the JWT token.
306297
token, err := generateToken(userID, scopes, expiresIn)
307298
if err != nil {
308299
http.Error(w, "Failed to generate token", http.StatusInternalServerError)
309300
return
310301
}
311302

312-
// Return the generated token
303+
// Return the generated token.
313304
w.Header().Set("Content-Type", "application/json")
314305
json.NewEncoder(w).Encode(map[string]string{
315306
"token": token,
@@ -318,40 +309,43 @@ func main() {
318309
})
319310

320311
http.HandleFunc("/generate-api-key", func(w http.ResponseWriter, r *http.Request) {
321-
// Generate a random API key using cryptographically secure random bytes
312+
// Generate a random API key using cryptographically secure random bytes.
322313
bytes := make([]byte, 16)
323-
rand.Read(bytes)
314+
if _, err := rand.Read(bytes); err != nil {
315+
http.Error(w, "Failed to generate random bytes", http.StatusInternalServerError)
316+
return
317+
}
324318
apiKey := "sk-" + base64.URLEncoding.EncodeToString(bytes)
325319

326-
// Get user ID from query parameters (default: "test-user")
320+
// Get user ID from query parameters (default: "test-user").
327321
userID := r.URL.Query().Get("user_id")
328322
if userID == "" {
329323
userID = "test-user"
330324
}
331325

332-
// Get scopes from query parameters (default: ["read"])
326+
// Get scopes from query parameters (default: ["read"]).
333327
scopes := strings.Split(r.URL.Query().Get("scopes"), ",")
334328
if len(scopes) == 1 && scopes[0] == "" {
335329
scopes = []string{"read"}
336330
}
337331

338-
// Store the new API key in our in-memory storage
339-
// In production, this would be stored in a database
332+
// Store the new API key in our in-memory storage.
333+
// In production, this would be stored in a database.
340334
apiKeys[apiKey] = &APIKey{
341335
Key: apiKey,
342336
UserID: userID,
343337
Scopes: scopes,
344338
}
345339

346-
// Return the generated API key
340+
// Return the generated API key.
347341
w.Header().Set("Content-Type", "application/json")
348342
json.NewEncoder(w).Encode(map[string]string{
349343
"api_key": apiKey,
350344
"type": "Bearer",
351345
})
352346
})
353347

354-
// Health check endpoint
348+
// Health check endpoint.
355349
http.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
356350
w.Header().Set("Content-Type", "application/json")
357351
json.NewEncoder(w).Encode(map[string]string{
@@ -360,7 +354,7 @@ func main() {
360354
})
361355
})
362356

363-
// Start the HTTP server
357+
// Start the HTTP server.
364358
log.Println("Authenticated MCP Server")
365359
log.Println("========================")
366360
log.Println("Server starting on", *httpAddr)

0 commit comments

Comments
 (0)