diff --git a/design/design.md b/design/design.md index 4ade34ca..7273fcf3 100644 --- a/design/design.md +++ b/design/design.md @@ -470,6 +470,10 @@ server.AddReceivingMiddleware(withLogging) **Differences from mcp-go**: Version 0.26.0 of mcp-go defines 24 server hooks. Each hook consists of a field in the `Hooks` struct, a `Hooks.Add` method, and a type for the hook function. These are rarely used. The most common is `OnError`, which occurs fewer than ten times in open-source code. +#### Rate Limiting + +Rate limiting can be configured using middleware. Please see [examples/rate-limiting](] for an example on how to implement this. + ### Errors With the exception of tool handler errors, protocol errors are handled transparently as Go errors: errors in server-side feature handlers are propagated as errors from calls from the `ClientSession`, and vice-versa. diff --git a/examples/rate-limiting/go.mod b/examples/rate-limiting/go.mod new file mode 100644 index 00000000..5ec49ddc --- /dev/null +++ b/examples/rate-limiting/go.mod @@ -0,0 +1,8 @@ +module github.com/modelcontextprotocol/go-sdk/examples/rate-limiting + +go 1.25 + +require ( + github.com/modelcontextprotocol/go-sdk v0.0.0-20250625185707-09181c2c2e89 + golang.org/x/time v0.12.0 +) diff --git a/examples/rate-limiting/go.sum b/examples/rate-limiting/go.sum new file mode 100644 index 00000000..c7027682 --- /dev/null +++ b/examples/rate-limiting/go.sum @@ -0,0 +1,8 @@ +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/modelcontextprotocol/go-sdk v0.0.0-20250625185707-09181c2c2e89 h1:kUGBYP25FTv3ZRBhLT4iQvtx4FDl7hPkWe3isYrMxyo= +github.com/modelcontextprotocol/go-sdk v0.0.0-20250625185707-09181c2c2e89/go.mod h1:DcXfbr7yl7e35oMpzHfKw2nUYRjhIGS2uou/6tdsTB0= +golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= +golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= +golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= +golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg= diff --git a/examples/rate-limiting/main.go b/examples/rate-limiting/main.go new file mode 100644 index 00000000..7e91b79f --- /dev/null +++ b/examples/rate-limiting/main.go @@ -0,0 +1,54 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package main + +import ( + "context" + "errors" + "time" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "golang.org/x/time/rate" +) + +// GlobalRateLimiterMiddleware creates a middleware that applies a global rate limit. +// Every request attempting to pass through will try to acquire a token. +// If a token cannot be acquired immediately, the request will be rejected. +func GlobalRateLimiterMiddleware[S mcp.Session](limiter *rate.Limiter) mcp.Middleware[S] { + return func(next mcp.MethodHandler[S]) mcp.MethodHandler[S] { + return func(ctx context.Context, session S, method string, params mcp.Params) (mcp.Result, error) { + if !limiter.Allow() { + return nil, errors.New("JSON RPC overloaded") + } + return next(ctx, session, method, params) + } + } +} + +// PerMethodRateLimiterMiddleware creates a middleware that applies rate limiting +// on a per-method basis. +// Methods not specified in limiters will not be rate limited by this middleware. +func PerMethodRateLimiterMiddleware[S mcp.Session](limiters map[string]*rate.Limiter) mcp.Middleware[S] { + return func(next mcp.MethodHandler[S]) mcp.MethodHandler[S] { + return func(ctx context.Context, session S, method string, params mcp.Params) (mcp.Result, error) { + if limiter, ok := limiters[method]; ok { + if !limiter.Allow() { + return nil, errors.New("JSON RPC overloaded") + } + } + return next(ctx, session, method, params) + } + } +} + +func main() { + server := mcp.NewServer("greeter1", "v0.0.1", nil) + server.AddReceivingMiddleware(GlobalRateLimiterMiddleware[*mcp.ServerSession](rate.NewLimiter(rate.Every(time.Second/5), 10))) + server.AddReceivingMiddleware(PerMethodRateLimiterMiddleware[*mcp.ServerSession](map[string]*rate.Limiter{ + "callTool": rate.NewLimiter(rate.Every(time.Second), 5), // once a second with a burst up to 5 + "listTools": rate.NewLimiter(rate.Every(time.Minute), 20), // once a minute with a burst up to 20 + })) + // Run Server logic. +}