diff --git a/examples/rate-limiting/go.mod b/examples/rate-limiting/go.mod index 5ec49ddc..39a5af50 100644 --- a/examples/rate-limiting/go.mod +++ b/examples/rate-limiting/go.mod @@ -3,6 +3,6 @@ module github.com/modelcontextprotocol/go-sdk/examples/rate-limiting go 1.25 require ( - github.com/modelcontextprotocol/go-sdk v0.0.0-20250625185707-09181c2c2e89 + github.com/modelcontextprotocol/go-sdk v0.1.0 golang.org/x/time v0.12.0 ) diff --git a/examples/rate-limiting/go.sum b/examples/rate-limiting/go.sum index c7027682..d73f0a54 100644 --- a/examples/rate-limiting/go.sum +++ b/examples/rate-limiting/go.sum @@ -1,7 +1,7 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/modelcontextprotocol/go-sdk v0.0.0-20250625185707-09181c2c2e89 h1:kUGBYP25FTv3ZRBhLT4iQvtx4FDl7hPkWe3isYrMxyo= -github.com/modelcontextprotocol/go-sdk v0.0.0-20250625185707-09181c2c2e89/go.mod h1:DcXfbr7yl7e35oMpzHfKw2nUYRjhIGS2uou/6tdsTB0= +github.com/modelcontextprotocol/go-sdk v0.1.0 h1:ItzbFWYNt4EHcUrScX7P8JPASn1FVYb29G773Xkl+IU= +github.com/modelcontextprotocol/go-sdk v0.1.0/go.mod h1:DcXfbr7yl7e35oMpzHfKw2nUYRjhIGS2uou/6tdsTB0= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= diff --git a/examples/rate-limiting/main.go b/examples/rate-limiting/main.go index 7e91b79f..c3265c4c 100644 --- a/examples/rate-limiting/main.go +++ b/examples/rate-limiting/main.go @@ -7,6 +7,8 @@ package main import ( "context" "errors" + "log" + "sync" "time" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -43,6 +45,43 @@ func PerMethodRateLimiterMiddleware[S mcp.Session](limiters map[string]*rate.Lim } } +// PerSessionRateLimiterMiddleware creates a middleware that applies rate limiting +// on a per-session basis for receiving requests. +func PerSessionRateLimiterMiddleware[S mcp.Session](limit rate.Limit, burst int) mcp.Middleware[S] { + // A map to store limiters, keyed by the session ID. + var ( + sessionLimiters = make(map[string]*rate.Limiter) + mu sync.Mutex + ) + + return func(next mcp.MethodHandler[S]) mcp.MethodHandler[S] { + return func(ctx context.Context, session S, method string, params mcp.Params) (mcp.Result, error) { + // It's possible that session.ID() may be empty at this point in time + // for some transports (e.g., stdio) or until the MCP initialize handshake + // has completed. + sessionID := session.ID() + if sessionID == "" { + // In this situation, you could apply a single global identifier + // if session ID is empty or bypass the rate limiter. + // In this example, we bypass the rate limiter. + log.Printf("Warning: Session ID is empty for method %q. Skipping per-session rate limiting.", method) + return next(ctx, session, method, params) // Skip limiting if ID is unavailable + } + mu.Lock() + limiter, ok := sessionLimiters[sessionID] + if !ok { + limiter = rate.NewLimiter(limit, burst) + sessionLimiters[sessionID] = limiter + } + mu.Unlock() + if !limiter.Allow() { + return nil, errors.New("JSON RPC overloaded") + } + return next(ctx, session, method, params) + } + } +} + func main() { server := mcp.NewServer("greeter1", "v0.0.1", nil) server.AddReceivingMiddleware(GlobalRateLimiterMiddleware[*mcp.ServerSession](rate.NewLimiter(rate.Every(time.Second/5), 10))) @@ -50,5 +89,8 @@ func main() { "callTool": rate.NewLimiter(rate.Every(time.Second), 5), // once a second with a burst up to 5 "listTools": rate.NewLimiter(rate.Every(time.Minute), 20), // once a minute with a burst up to 20 })) + server.AddReceivingMiddleware(PerSessionRateLimiterMiddleware[*mcp.ServerSession](rate.Every(time.Second/5), 10)) // Run Server logic. + log.Println("MCP Server instance created with Middleware (but not running).") + log.Println("This example demonstrates configuration, not live interaction.") }