@@ -7,6 +7,8 @@ package main
77import (
88 "context"
99 "errors"
10+ "log"
11+ "sync"
1012 "time"
1113
1214 "github.com/modelcontextprotocol/go-sdk/mcp"
@@ -43,12 +45,52 @@ func PerMethodRateLimiterMiddleware[S mcp.Session](limiters map[string]*rate.Lim
4345 }
4446}
4547
48+ // PerSessionRateLimiterMiddleware creates a middleware that applies rate limiting
49+ // on a per-session basis for receiving requests.
50+ func PerSessionRateLimiterMiddleware [S mcp.Session ](limit rate.Limit , burst int ) mcp.Middleware [S ] {
51+ // A map to store limiters, keyed by the session ID.
52+ var (
53+ sessionLimiters = make (map [string ]* rate.Limiter )
54+ mu sync.Mutex
55+ )
56+
57+ return func (next mcp.MethodHandler [S ]) mcp.MethodHandler [S ] {
58+ return func (ctx context.Context , session S , method string , params mcp.Params ) (mcp.Result , error ) {
59+ // It's possible that session.ID() may be empty at this point in time
60+ // for some transports (e.g., stdio) or until the MCP initialize handshake
61+ // has completed.
62+ sessionID := session .ID ()
63+ if sessionID == "" {
64+ // In this situation, you could apply a single global identifier
65+ // if session ID is empty or bypass the rate limiter.
66+ // In this example, we bypass the rate limiter.
67+ log .Printf ("Warning: Session ID is empty for method %q. Skipping per-session rate limiting." , method )
68+ return next (ctx , session , method , params ) // Skip limiting if ID is unavailable
69+ }
70+ mu .Lock ()
71+ limiter , ok := sessionLimiters [sessionID ]
72+ if ! ok {
73+ limiter = rate .NewLimiter (limit , burst )
74+ sessionLimiters [sessionID ] = limiter
75+ }
76+ mu .Unlock ()
77+ if ! limiter .Allow () {
78+ return nil , errors .New ("JSON RPC overloaded" )
79+ }
80+ return next (ctx , session , method , params )
81+ }
82+ }
83+ }
84+
4685func main () {
4786 server := mcp .NewServer ("greeter1" , "v0.0.1" , nil )
4887 server .AddReceivingMiddleware (GlobalRateLimiterMiddleware [* mcp.ServerSession ](rate .NewLimiter (rate .Every (time .Second / 5 ), 10 )))
4988 server .AddReceivingMiddleware (PerMethodRateLimiterMiddleware [* mcp.ServerSession ](map [string ]* rate.Limiter {
5089 "callTool" : rate .NewLimiter (rate .Every (time .Second ), 5 ), // once a second with a burst up to 5
5190 "listTools" : rate .NewLimiter (rate .Every (time .Minute ), 20 ), // once a minute with a burst up to 20
5291 }))
92+ server .AddReceivingMiddleware (PerSessionRateLimiterMiddleware [* mcp.ServerSession ](rate .Every (time .Second / 5 ), 10 ))
5393 // Run Server logic.
94+ log .Println ("MCP Server instance created with Middleware (but not running)." )
95+ log .Println ("This example demonstrates configuration, not live interaction." )
5496}
0 commit comments