Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions mcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,16 +287,16 @@ func (c *Client) AddReceivingMiddleware(middleware ...Middleware[*ClientSession]

// clientMethodInfos maps from the RPC method name to serverMethodInfos.
var clientMethodInfos = map[string]methodInfo{
methodComplete: newMethodInfo(sessionMethod((*ClientSession).Complete)),
methodPing: newMethodInfo(sessionMethod((*ClientSession).ping)),
methodListRoots: newMethodInfo(clientMethod((*Client).listRoots)),
methodCreateMessage: newMethodInfo(clientMethod((*Client).createMessage)),
notificationToolListChanged: newMethodInfo(clientMethod((*Client).callToolChangedHandler)),
notificationPromptListChanged: newMethodInfo(clientMethod((*Client).callPromptChangedHandler)),
notificationResourceListChanged: newMethodInfo(clientMethod((*Client).callResourceChangedHandler)),
notificationResourceUpdated: newMethodInfo(clientMethod((*Client).callResourceUpdatedHandler)),
notificationLoggingMessage: newMethodInfo(clientMethod((*Client).callLoggingHandler)),
notificationProgress: newMethodInfo(sessionMethod((*ClientSession).callProgressNotificationHandler)),
methodComplete: newMethodInfo(sessionMethod((*ClientSession).Complete), true),
methodPing: newMethodInfo(sessionMethod((*ClientSession).ping), true),
methodListRoots: newMethodInfo(clientMethod((*Client).listRoots), true),
methodCreateMessage: newMethodInfo(clientMethod((*Client).createMessage), true),
notificationToolListChanged: newMethodInfo(clientMethod((*Client).callToolChangedHandler), false),
notificationPromptListChanged: newMethodInfo(clientMethod((*Client).callPromptChangedHandler), false),
notificationResourceListChanged: newMethodInfo(clientMethod((*Client).callResourceChangedHandler), false),
notificationResourceUpdated: newMethodInfo(clientMethod((*Client).callResourceUpdatedHandler), false),
notificationLoggingMessage: newMethodInfo(clientMethod((*Client).callLoggingHandler), false),
notificationProgress: newMethodInfo(sessionMethod((*ClientSession).callProgressNotificationHandler), false),
}

func (cs *ClientSession) sendingMethodInfos() map[string]methodInfo {
Expand All @@ -323,7 +323,7 @@ func (cs *ClientSession) receivingMethodHandler() methodHandler {
return cs.client.receivingMethodHandler_
}

// getConn implements [session.getConn].
// getConn implements [Session.getConn].
func (cs *ClientSession) getConn() *jsonrpc2.Connection { return cs.conn }

func (*ClientSession) ping(context.Context, *PingParams) (*emptyResult, error) {
Expand Down
32 changes: 16 additions & 16 deletions mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -688,22 +688,22 @@ func (s *Server) AddReceivingMiddleware(middleware ...Middleware[*ServerSession]

// serverMethodInfos maps from the RPC method name to serverMethodInfos.
var serverMethodInfos = map[string]methodInfo{
methodComplete: newMethodInfo(serverMethod((*Server).complete)),
methodInitialize: newMethodInfo(sessionMethod((*ServerSession).initialize)),
methodPing: newMethodInfo(sessionMethod((*ServerSession).ping)),
methodListPrompts: newMethodInfo(serverMethod((*Server).listPrompts)),
methodGetPrompt: newMethodInfo(serverMethod((*Server).getPrompt)),
methodListTools: newMethodInfo(serverMethod((*Server).listTools)),
methodCallTool: newMethodInfo(serverMethod((*Server).callTool)),
methodListResources: newMethodInfo(serverMethod((*Server).listResources)),
methodListResourceTemplates: newMethodInfo(serverMethod((*Server).listResourceTemplates)),
methodReadResource: newMethodInfo(serverMethod((*Server).readResource)),
methodSetLevel: newMethodInfo(sessionMethod((*ServerSession).setLevel)),
methodSubscribe: newMethodInfo(serverMethod((*Server).subscribe)),
methodUnsubscribe: newMethodInfo(serverMethod((*Server).unsubscribe)),
notificationInitialized: newMethodInfo(serverMethod((*Server).callInitializedHandler)),
notificationRootsListChanged: newMethodInfo(serverMethod((*Server).callRootsListChangedHandler)),
notificationProgress: newMethodInfo(sessionMethod((*ServerSession).callProgressNotificationHandler)),
methodComplete: newMethodInfo(serverMethod((*Server).complete), true),
methodInitialize: newMethodInfo(sessionMethod((*ServerSession).initialize), true),
methodPing: newMethodInfo(sessionMethod((*ServerSession).ping), true),
methodListPrompts: newMethodInfo(serverMethod((*Server).listPrompts), true),
methodGetPrompt: newMethodInfo(serverMethod((*Server).getPrompt), true),
methodListTools: newMethodInfo(serverMethod((*Server).listTools), true),
methodCallTool: newMethodInfo(serverMethod((*Server).callTool), true),
methodListResources: newMethodInfo(serverMethod((*Server).listResources), true),
methodListResourceTemplates: newMethodInfo(serverMethod((*Server).listResourceTemplates), true),
methodReadResource: newMethodInfo(serverMethod((*Server).readResource), true),
methodSetLevel: newMethodInfo(sessionMethod((*ServerSession).setLevel), true),
methodSubscribe: newMethodInfo(serverMethod((*Server).subscribe), true),
methodUnsubscribe: newMethodInfo(serverMethod((*Server).unsubscribe), true),
notificationInitialized: newMethodInfo(serverMethod((*Server).callInitializedHandler), false),
notificationRootsListChanged: newMethodInfo(serverMethod((*Server).callRootsListChangedHandler), false),
notificationProgress: newMethodInfo(sessionMethod((*ServerSession).callProgressNotificationHandler), false),
}

func (ss *ServerSession) sendingMethodInfos() map[string]methodInfo { return clientMethodInfos }
Expand Down
34 changes: 30 additions & 4 deletions mcp/shared.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,9 @@ func defaultReceivingMethodHandler[S Session](ctx context.Context, session S, me
}

func handleReceive[S Session](ctx context.Context, session S, req *jsonrpc.Request) (Result, error) {
info, ok := session.receivingMethodInfos()[req.Method]
if !ok {
return nil, jsonrpc2.ErrNotHandled
info, err := checkRequest(req, session.receivingMethodInfos())
if err != nil {
return nil, err
}
params, err := info.unmarshalParams(req.Params)
if err != nil {
Expand All @@ -141,8 +141,30 @@ func handleReceive[S Session](ctx context.Context, session S, req *jsonrpc.Reque
return res, nil
}

// checkRequest checks the given request against the provided method info, to
// ensure it is a valid MCP request.
//
// If valid, the relevant method info is returned. Otherwise, a non-nil error
// is returned describing why the request is invalid.
//
// This is extracted from request handling so that it can be called in the
// transport layer to preemptively reject bad requests.
func checkRequest(req *jsonrpc.Request, infos map[string]methodInfo) (methodInfo, error) {
info, ok := infos[req.Method]
if !ok {
return methodInfo{}, fmt.Errorf("%w: %q unsupported", jsonrpc2.ErrNotHandled, req.Method)
}
if info.isRequest && !req.ID.IsValid() {
return methodInfo{}, fmt.Errorf("%w: %q missing ID", jsonrpc2.ErrInvalidRequest, req.Method)
}
return info, nil
}

// methodInfo is information about sending and receiving a method.
type methodInfo struct {
// isRequest reports whether the method is a JSON-RPC request.
// Otherwise, the method is treated as a notification.
isRequest bool
// Unmarshal params from the wire into a Params struct.
// Used on the receive side.
unmarshalParams func(json.RawMessage) (Params, error)
Expand All @@ -169,8 +191,12 @@ type paramsPtr[T any] interface {
}

// newMethodInfo creates a methodInfo from a typedMethodHandler.
func newMethodInfo[S Session, P paramsPtr[T], R Result, T any](d typedMethodHandler[S, P, R]) methodInfo {
//
// If isRequest is set, the method is treated as a request rather than a
// notification.
func newMethodInfo[S Session, P paramsPtr[T], R Result, T any](d typedMethodHandler[S, P, R], isRequest bool) methodInfo {
return methodInfo{
isRequest: isRequest,
unmarshalParams: func(m json.RawMessage) (Params, error) {
var p P
if m != nil {
Expand Down
6 changes: 6 additions & 0 deletions mcp/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,12 @@ func (t *SSEServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request)
http.Error(w, "failed to parse body", http.StatusBadRequest)
return
}
if req, ok := msg.(*jsonrpc.Request); ok {
if _, err := checkRequest(req, serverMethodInfos); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
}
select {
case t.incoming <- msg:
w.WriteHeader(http.StatusAccepted)
Expand Down
43 changes: 39 additions & 4 deletions mcp/sse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
package mcp

import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"sync/atomic"
Expand All @@ -24,10 +26,10 @@ func TestSSEServer(t *testing.T) {

sseHandler := NewSSEHandler(func(*http.Request) *Server { return server })

conns := make(chan *ServerSession, 1)
sseHandler.onConnection = func(cc *ServerSession) {
serverSessions := make(chan *ServerSession, 1)
sseHandler.onConnection = func(ss *ServerSession) {
select {
case conns <- cc:
case serverSessions <- ss:
default:
}
}
Expand All @@ -54,7 +56,7 @@ func TestSSEServer(t *testing.T) {
if err := cs.Ping(ctx, nil); err != nil {
t.Fatal(err)
}
ss := <-conns
ss := <-serverSessions
gotHi, err := cs.CallTool(ctx, &CallToolParams{
Name: "greet",
Arguments: map[string]any{"Name": "user"},
Expand All @@ -76,6 +78,39 @@ func TestSSEServer(t *testing.T) {
t.Error("Expected custom HTTP client to be used, but it wasn't")
}

t.Run("badrequests", func(t *testing.T) {
msgEndpoint := cs.mcpConn.(*sseClientConn).msgEndpoint.String()

// Test some invalid data, and verify that we get 400s.
badRequests := []struct {
name string
body string
responseContains string
}{
{"not a method", `{"jsonrpc":"2.0", "method":"notamethod"}`, "not handled"},
{"missing ID", `{"jsonrpc":"2.0", "method":"ping"}`, "missing ID"},
}
for _, r := range badRequests {
t.Run(r.name, func(t *testing.T) {
resp, err := http.Post(msgEndpoint, "application/json", bytes.NewReader([]byte(r.body)))
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if got, want := resp.StatusCode, http.StatusBadRequest; got != want {
t.Errorf("Sending bad request %q: got status %d, want %d", r.body, got, want)
}
result, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("Reading response: %v", err)
}
if !bytes.Contains(result, []byte(r.responseContains)) {
t.Errorf("Response body does not contain %q:\n%s", r.responseContains, string(result))
}
})
}
})

// Test that closing either end of the connection terminates the other
// end.
if closeServerFirst {
Expand Down
12 changes: 10 additions & 2 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -387,8 +387,16 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R
}
requests := make(map[jsonrpc.ID]struct{})
for _, msg := range incoming {
if req, ok := msg.(*jsonrpc.Request); ok && req.ID.IsValid() {
requests[req.ID] = struct{}{}
if req, ok := msg.(*jsonrpc.Request); ok {
// Preemptively check that this is a valid request, so that we can fail
// the HTTP request. If we didn't do this, a request with a bad method or
// missing ID could be silently swallowed.
if _, err := checkRequest(req, serverMethodInfos); err != nil {
return http.StatusBadRequest, err.Error()
}
if req.ID.IsValid() {
requests[req.ID] = struct{}{}
}
}
}

Expand Down
14 changes: 12 additions & 2 deletions mcp/streamable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ func TestStreamableServerTransport(t *testing.T) {
}

// Predefined steps, to avoid repetition below.
initReq := req(1, "initialize", &InitializeParams{})
initReq := req(1, methodInitialize, &InitializeParams{})
initResp := resp(1, &InitializeResult{
Capabilities: &serverCapabilities{
Completions: &completionCapabilities{},
Expand All @@ -290,7 +290,7 @@ func TestStreamableServerTransport(t *testing.T) {
ProtocolVersion: latestProtocolVersion,
ServerInfo: &Implementation{Name: "testServer", Version: "v1.0.0"},
}, nil)
initializedMsg := req(0, "initialized", &InitializedParams{})
initializedMsg := req(0, notificationInitialized, &InitializedParams{})
initialize := step{
Method: "POST",
Send: []jsonrpc.Message{initReq},
Expand Down Expand Up @@ -438,6 +438,16 @@ func TestStreamableServerTransport(t *testing.T) {
Method: "DELETE",
StatusCode: http.StatusBadRequest,
},
{
Method: "POST",
Send: []jsonrpc.Message{req(1, "notamethod", nil)},
StatusCode: http.StatusBadRequest, // notamethod is an invalid method
},
{
Method: "POST",
Send: []jsonrpc.Message{req(0, "tools/call", &CallToolParams{Name: "tool"})},
StatusCode: http.StatusBadRequest, // tools/call must have an ID
},
{
Method: "POST",
Send: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})},
Expand Down
54 changes: 54 additions & 0 deletions mcp/testdata/conformance/server/missing_fields.txtar
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
Check robustness to missing fields: servers should reject and otherwise ignore
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why shouldn't we return -32600 Invalid Request?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How can we return anything? There is no ID.

bad requests.

Fixed bugs:
- No id in 'initialize' should not panic (#197).
- No id in 'ping' should not panic (#194).

TODO:
- No params in 'initialize' should not panic (#195).

-- prompts --
code_review

-- client --
{
"jsonrpc": "2.0",
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": { "name": "ExampleClient", "version": "1.0.0" }
}
}
{
"jsonrpc": "2.0",
"id": 2,
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": { "name": "ExampleClient", "version": "1.0.0" }
}
}
{"jsonrpc":"2.0", "method":"ping"}

-- server --
{
"jsonrpc": "2.0",
"id": 2,
"result": {
"capabilities": {
"completions": {},
"logging": {},
"prompts": {
"listChanged": true
}
},
"protocolVersion": "2024-11-05",
"serverInfo": {
"name": "testServer",
"version": "v1.0.0"
}
}
}
Loading