@@ -18,54 +18,54 @@ import (
1818// GlobalRateLimiterMiddleware creates a middleware that applies a global rate limit.
1919// Every request attempting to pass through will try to acquire a token.
2020// If a token cannot be acquired immediately, the request will be rejected.
21- func GlobalRateLimiterMiddleware [ S mcp. Session ] (limiter * rate.Limiter ) mcp.Middleware [ S ] {
22- return func (next mcp.MethodHandler [ S ] ) mcp.MethodHandler [ S ] {
23- return func (ctx context.Context , session S , method string , params mcp.Params ) (mcp.Result , error ) {
21+ func GlobalRateLimiterMiddleware (limiter * rate.Limiter ) mcp.Middleware {
22+ return func (next mcp.MethodHandler ) mcp.MethodHandler {
23+ return func (ctx context.Context , method string , req mcp.Request ) (mcp.Result , error ) {
2424 if ! limiter .Allow () {
2525 return nil , errors .New ("JSON RPC overloaded" )
2626 }
27- return next (ctx , session , method , params )
27+ return next (ctx , method , req )
2828 }
2929 }
3030}
3131
3232// PerMethodRateLimiterMiddleware creates a middleware that applies rate limiting
3333// on a per-method basis.
3434// Methods not specified in limiters will not be rate limited by this middleware.
35- func PerMethodRateLimiterMiddleware [ S mcp. Session ] (limiters map [string ]* rate.Limiter ) mcp.Middleware [ S ] {
36- return func (next mcp.MethodHandler [ S ] ) mcp.MethodHandler [ S ] {
37- return func (ctx context.Context , session S , method string , params mcp.Params ) (mcp.Result , error ) {
35+ func PerMethodRateLimiterMiddleware (limiters map [string ]* rate.Limiter ) mcp.Middleware {
36+ return func (next mcp.MethodHandler ) mcp.MethodHandler {
37+ return func (ctx context.Context , method string , req mcp.Request ) (mcp.Result , error ) {
3838 if limiter , ok := limiters [method ]; ok {
3939 if ! limiter .Allow () {
4040 return nil , errors .New ("JSON RPC overloaded" )
4141 }
4242 }
43- return next (ctx , session , method , params )
43+ return next (ctx , method , req )
4444 }
4545 }
4646}
4747
4848// PerSessionRateLimiterMiddleware creates a middleware that applies rate limiting
4949// on a per-session basis for receiving requests.
50- func PerSessionRateLimiterMiddleware [ S mcp. Session ] (limit rate.Limit , burst int ) mcp.Middleware [ S ] {
50+ func PerSessionRateLimiterMiddleware (limit rate.Limit , burst int ) mcp.Middleware {
5151 // A map to store limiters, keyed by the session ID.
5252 var (
5353 sessionLimiters = make (map [string ]* rate.Limiter )
5454 mu sync.Mutex
5555 )
5656
57- return func (next mcp.MethodHandler [ S ] ) mcp.MethodHandler [ S ] {
58- return func (ctx context.Context , session S , method string , params mcp.Params ) (mcp.Result , error ) {
57+ return func (next mcp.MethodHandler ) mcp.MethodHandler {
58+ return func (ctx context.Context , method string , req mcp.Request ) (mcp.Result , error ) {
5959 // It's possible that session.ID() may be empty at this point in time
6060 // for some transports (e.g., stdio) or until the MCP initialize handshake
6161 // has completed.
62- sessionID := session .ID ()
62+ sessionID := req . GetSession () .ID ()
6363 if sessionID == "" {
6464 // In this situation, you could apply a single global identifier
6565 // if session ID is empty or bypass the rate limiter.
6666 // In this example, we bypass the rate limiter.
6767 log .Printf ("Warning: Session ID is empty for method %q. Skipping per-session rate limiting." , method )
68- return next (ctx , session , method , params ) // Skip limiting if ID is unavailable
68+ return next (ctx , method , req ) // Skip limiting if ID is unavailable
6969 }
7070 mu .Lock ()
7171 limiter , ok := sessionLimiters [sessionID ]
@@ -77,19 +77,19 @@ func PerSessionRateLimiterMiddleware[S mcp.Session](limit rate.Limit, burst int)
7777 if ! limiter .Allow () {
7878 return nil , errors .New ("JSON RPC overloaded" )
7979 }
80- return next (ctx , session , method , params )
80+ return next (ctx , method , req )
8181 }
8282 }
8383}
8484
8585func main () {
86- server := mcp .NewServer ("greeter1" , "v0.0.1" , nil )
87- server .AddReceivingMiddleware (GlobalRateLimiterMiddleware [ * mcp. ServerSession ] (rate .NewLimiter (rate .Every (time .Second / 5 ), 10 )))
88- server .AddReceivingMiddleware (PerMethodRateLimiterMiddleware [ * mcp. ServerSession ] (map [string ]* rate.Limiter {
86+ server := mcp .NewServer (& mcp. Implementation { Name : "greeter1" , Version : "v0.0.1" } , nil )
87+ server .AddReceivingMiddleware (GlobalRateLimiterMiddleware (rate .NewLimiter (rate .Every (time .Second / 5 ), 10 )))
88+ server .AddReceivingMiddleware (PerMethodRateLimiterMiddleware (map [string ]* rate.Limiter {
8989 "callTool" : rate .NewLimiter (rate .Every (time .Second ), 5 ), // once a second with a burst up to 5
9090 "listTools" : rate .NewLimiter (rate .Every (time .Minute ), 20 ), // once a minute with a burst up to 20
9191 }))
92- server .AddReceivingMiddleware (PerSessionRateLimiterMiddleware [ * mcp. ServerSession ] (rate .Every (time .Second / 5 ), 10 ))
92+ server .AddReceivingMiddleware (PerSessionRateLimiterMiddleware (rate .Every (time .Second / 5 ), 10 ))
9393 // Run Server logic.
9494 log .Println ("MCP Server instance created with Middleware (but not running)." )
9595 log .Println ("This example demonstrates configuration, not live interaction." )
0 commit comments