@@ -7,6 +7,8 @@ package main
77import (
88 "context"
99 "errors"
10+ "log"
11+ "sync"
1012 "time"
1113
1214 "github.com/modelcontextprotocol/go-sdk/mcp"
@@ -43,12 +45,38 @@ func PerMethodRateLimiterMiddleware[S mcp.Session](limiters map[string]*rate.Lim
4345 }
4446}
4547
48+ // PerSessionRateLimiterMiddleware creates a middleware that applies rate limiting
49+ // on a per-session basis for receiving requests.
50+ func PerSessionRateLimiterMiddleware [S mcp.Session ](limit rate.Limit , burst int ) mcp.Middleware [S ] {
51+ // A map to store limiters, keyed by the session ID.
52+ // Use sync.Map for concurrent access as sessions can connect/disconnect concurrently.
53+ var sessionLimiters sync.Map // map[string]*rate.Limiter
54+
55+ return func (next mcp.MethodHandler [S ]) mcp.MethodHandler [S ] {
56+ return func (ctx context.Context , session S , method string , params mcp.Params ) (mcp.Result , error ) {
57+ sessionID := session .ID ()
58+
59+ // Load or CREATE a new limiter for this session if it doesn't exist
60+ actualLimiter , _ := sessionLimiters .LoadOrStore (sessionID , rate .NewLimiter (limit , burst ))
61+ rateLimiter := actualLimiter .(* rate.Limiter )
62+
63+ if ! rateLimiter .Allow () {
64+ return nil , errors .New ("JSON RPC overloaded" )
65+ }
66+ return next (ctx , session , method , params )
67+ }
68+ }
69+ }
70+
4671func main () {
4772 server := mcp .NewServer ("greeter1" , "v0.0.1" , nil )
4873 server .AddReceivingMiddleware (GlobalRateLimiterMiddleware [* mcp.ServerSession ](rate .NewLimiter (rate .Every (time .Second / 5 ), 10 )))
4974 server .AddReceivingMiddleware (PerMethodRateLimiterMiddleware [* mcp.ServerSession ](map [string ]* rate.Limiter {
5075 "callTool" : rate .NewLimiter (rate .Every (time .Second ), 5 ), // once a second with a burst up to 5
5176 "listTools" : rate .NewLimiter (rate .Every (time .Minute ), 20 ), // once a minute with a burst up to 20
5277 }))
78+ server .AddReceivingMiddleware (PerSessionRateLimiterMiddleware [* mcp.ServerSession ](rate .Every (time .Second / 5 ), 10 ))
5379 // Run Server logic.
80+ log .Println ("MCP Server instance created with Middleware (but not running)." )
81+ log .Println ("This example demonstrates configuration, not live interaction." )
5482}
0 commit comments