diff --git a/internal/jsonrpc2/conn.go b/internal/jsonrpc2/conn.go index 6f48c9ba..963350e7 100644 --- a/internal/jsonrpc2/conn.go +++ b/internal/jsonrpc2/conn.go @@ -374,6 +374,46 @@ func (c *Connection) Call(ctx context.Context, method string, params any) *Async return ac } +// Async, signals that the current jsonrpc2 request may be handled +// asynchronously to subsequent requests, when ctx is the request context. +// +// Async must be called at most once on each request's context (and its +// descendants). +func Async(ctx context.Context) { + if r, ok := ctx.Value(asyncKey).(*releaser); ok { + r.release(false) + } +} + +type asyncKeyType struct{} + +var asyncKey = asyncKeyType{} + +// A releaser implements concurrency safe 'releasing' of async requests. (A +// request is released when it is allowed to run concurrent with other +// requests, via a call to [Async].) +type releaser struct { + mu sync.Mutex + ch chan struct{} + released bool +} + +// release closes the associated channel. If soft is set, multiple calls to +// release are allowed. +func (r *releaser) release(soft bool) { + r.mu.Lock() + defer r.mu.Unlock() + + if r.released { + if !soft { + panic("jsonrpc2.Async called multiple times") + } + } else { + close(r.ch) + r.released = true + } +} + type AsyncCall struct { id ID ready chan struct{} // closed after response has been set @@ -425,28 +465,6 @@ func (ac *AsyncCall) Await(ctx context.Context, result any) error { return json.Unmarshal(ac.response.Result, result) } -// Respond delivers a response to an incoming Call. -// -// Respond must be called exactly once for any message for which a handler -// returns ErrAsyncResponse. It must not be called for any other message. -func (c *Connection) Respond(id ID, result any, err error) error { - var req *incomingRequest - c.updateInFlight(func(s *inFlightState) { - req = s.incomingByID[id] - }) - if req == nil { - return c.internalErrorf("Request not found for ID %v", id) - } - - if err == ErrAsyncResponse { - // Respond is supposed to supply the asynchronous response, so it would be - // confusing to call Respond with an error that promises to call Respond - // again. - err = c.internalErrorf("Respond called with ErrAsyncResponse for %q", req.Method) - } - return c.processResult("Respond", req, result, err) -} - // Cancel cancels the Context passed to the Handle call for the inbound message // with the given ID. // @@ -576,11 +594,6 @@ func (c *Connection) acceptRequest(ctx context.Context, msg *Request, preempter if preempter != nil { result, err := preempter.Preempt(req.ctx, req.Request) - if req.IsCall() && errors.Is(err, ErrAsyncResponse) { - // This request will remain in flight until Respond is called for it. - return - } - if !errors.Is(err, ErrNotHandled) { c.processResult("Preempt", req, result, err) return @@ -655,19 +668,20 @@ func (c *Connection) handleAsync() { continue } - result, err := c.handler.Handle(req.ctx, req.Request) - c.processResult(c.handler, req, result, err) + releaser := &releaser{ch: make(chan struct{})} + ctx := context.WithValue(req.ctx, asyncKey, releaser) + go func() { + defer releaser.release(true) + result, err := c.handler.Handle(ctx, req.Request) + c.processResult(c.handler, req, result, err) + }() + <-releaser.ch } } // processResult processes the result of a request and, if appropriate, sends a response. func (c *Connection) processResult(from any, req *incomingRequest, result any, err error) error { switch err { - case ErrAsyncResponse: - if !req.IsCall() { - return c.internalErrorf("%#v returned ErrAsyncResponse for a %q Request without an ID", from, req.Method) - } - return nil // This request is still in flight, so don't record the result yet. case ErrNotHandled, ErrMethodNotFound: // Add detail describing the unhandled method. err = fmt.Errorf("%w: %q", ErrMethodNotFound, req.Method) diff --git a/internal/jsonrpc2/jsonrpc2.go b/internal/jsonrpc2/jsonrpc2.go index b9c320c8..234e6ee3 100644 --- a/internal/jsonrpc2/jsonrpc2.go +++ b/internal/jsonrpc2/jsonrpc2.go @@ -22,13 +22,6 @@ var ( // If a Handler returns ErrNotHandled, the server replies with // ErrMethodNotFound. ErrNotHandled = errors.New("JSON RPC not handled") - - // ErrAsyncResponse is returned from a handler to indicate it will generate a - // response asynchronously. - // - // ErrAsyncResponse must not be returned for notifications, - // which do not receive responses. - ErrAsyncResponse = errors.New("JSON RPC asynchronous response") ) // Preempter handles messages on a connection before they are queued to the main diff --git a/internal/jsonrpc2/jsonrpc2_test.go b/internal/jsonrpc2/jsonrpc2_test.go index 16a5039b..8c79300c 100644 --- a/internal/jsonrpc2/jsonrpc2_test.go +++ b/internal/jsonrpc2/jsonrpc2_test.go @@ -371,16 +371,14 @@ func (h *handler) Handle(ctx context.Context, req *jsonrpc2.Request) (any, error if err := json.Unmarshal(req.Params, &name); err != nil { return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err) } + jsonrpc2.Async(ctx) waitFor := h.waiter(name) - go func() { - select { - case <-waitFor: - h.conn.Respond(req.ID, true, nil) - case <-ctx.Done(): - h.conn.Respond(req.ID, nil, ctx.Err()) - } - }() - return nil, jsonrpc2.ErrAsyncResponse + select { + case <-waitFor: + return true, nil + case <-ctx.Done(): + return nil, ctx.Err() + } default: return nil, jsonrpc2.ErrNotHandled } diff --git a/mcp/client.go b/mcp/client.go index 65a7a954..33530e05 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -329,16 +329,19 @@ func (cs *ClientSession) receivingMethodInfos() map[string]methodInfo { } func (cs *ClientSession) handle(ctx context.Context, req *jsonrpc.Request) (any, error) { + if req.IsCall() { + jsonrpc2.Async(ctx) + } return handleReceive(ctx, cs, req) } -func (cs *ClientSession) sendingMethodHandler() methodHandler { +func (cs *ClientSession) sendingMethodHandler() MethodHandler { cs.client.mu.Lock() defer cs.client.mu.Unlock() return cs.client.sendingMethodHandler_ } -func (cs *ClientSession) receivingMethodHandler() methodHandler { +func (cs *ClientSession) receivingMethodHandler() MethodHandler { cs.client.mu.Lock() defer cs.client.mu.Unlock() return cs.client.receivingMethodHandler_ diff --git a/mcp/conformance_test.go b/mcp/conformance_test.go index 8e6ea1be..9bd8b8f6 100644 --- a/mcp/conformance_test.go +++ b/mcp/conformance_test.go @@ -183,7 +183,7 @@ func runServerTest(t *testing.T, test *conformanceTest) { return nil, err, false } serverMessages = append(serverMessages, msg) - if req, ok := msg.(*jsonrpc.Request); ok && req.ID.IsValid() { + if req, ok := msg.(*jsonrpc.Request); ok && req.IsCall() { // Pair up the next outgoing response with this request. // We assume requests arrive in the same order every time. if len(outResponses) == 0 { @@ -201,8 +201,8 @@ func runServerTest(t *testing.T, test *conformanceTest) { // Synthetic peer interacts with real peer. for _, req := range outRequests { writeMsg(req) - if req.ID.IsValid() { - // A request (as opposed to a notification). Wait for the response. + if req.IsCall() { + // A call (as opposed to a notification). Wait for the response. res, err, ok := nextResponse() if err != nil { t.Fatalf("reading server messages failed: %v", err) diff --git a/mcp/content.go b/mcp/content.go index f8777154..108b0271 100644 --- a/mcp/content.go +++ b/mcp/content.go @@ -253,7 +253,7 @@ func contentsFromWire(wires []*wireContent, allow map[string]bool) ([]Content, e func contentFromWire(wire *wireContent, allow map[string]bool) (Content, error) { if wire == nil { - return nil, fmt.Errorf("content wire is nil") + return nil, fmt.Errorf("nil content") } if allow != nil && !allow[wire.Type] { return nil, fmt.Errorf("invalid content type %q", wire.Type) diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index d04235b8..91d350e8 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -17,6 +17,7 @@ import ( "slices" "strings" "sync" + "sync/atomic" "testing" "time" @@ -549,31 +550,47 @@ func errorCode(err error) int64 { // // The caller should cancel either the client connection or server connection // when the connections are no longer needed. -func basicConnection(t *testing.T, config func(*Server)) (*ServerSession, *ClientSession) { +func basicConnection(t *testing.T, config func(*Server)) (*ClientSession, *ServerSession) { + return basicClientServerConnection(t, nil, nil, config) +} + +// basicClientServerConnection creates a basic connection between client and +// server. If either client or server is nil, empty implementations are used. +// +// The provided function may be used to configure features on the resulting +// server, prior to connection. +// +// The caller should cancel either the client connection or server connection +// when the connections are no longer needed. +func basicClientServerConnection(t *testing.T, client *Client, server *Server, config func(*Server)) (*ClientSession, *ServerSession) { t.Helper() ctx := context.Background() ct, st := NewInMemoryTransports() - s := NewServer(testImpl, nil) + if server == nil { + server = NewServer(testImpl, nil) + } if config != nil { - config(s) + config(server) } - ss, err := s.Connect(ctx, st, nil) + ss, err := server.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } - c := NewClient(testImpl, nil) - cs, err := c.Connect(ctx, ct, nil) + if client == nil { + client = NewClient(testImpl, nil) + } + cs, err := client.Connect(ctx, ct, nil) if err != nil { t.Fatal(err) } - return ss, cs + return cs, ss } func TestServerClosing(t *testing.T) { - cc, cs := basicConnection(t, func(s *Server) { + cs, ss := basicConnection(t, func(s *Server) { AddTool(s, greetTool(), sayHi) }) defer cs.Close() @@ -593,7 +610,7 @@ func TestServerClosing(t *testing.T) { }); err != nil { t.Fatalf("after connecting: %v", err) } - cc.Close() + ss.Close() wg.Wait() if _, err := cs.CallTool(ctx, &CallToolParams{ Name: "greet", @@ -656,7 +673,7 @@ func TestCancellation(t *testing.T) { } return nil, nil } - _, cs := basicConnection(t, func(s *Server) { + cs, _ := basicConnection(t, func(s *Server) { AddTool(s, &Tool{Name: "slow"}, slowRequest) }) defer cs.Close() @@ -940,7 +957,7 @@ func TestKeepAliveFailure(t *testing.T) { func TestAddTool_DuplicateNoPanicAndNoDuplicate(t *testing.T) { // Adding the same tool pointer twice should not panic and should not // produce duplicates in the server's tool list. - _, cs := basicConnection(t, func(s *Server) { + cs, _ := basicConnection(t, func(s *Server) { // Use two distinct Tool instances with the same name but different // descriptions to ensure the second replaces the first // This case was written specifically to reproduce a bug where duplicate tools where causing jsonschema errors @@ -972,4 +989,98 @@ func TestAddTool_DuplicateNoPanicAndNoDuplicate(t *testing.T) { } } +func TestSynchronousNotifications(t *testing.T) { + var toolsChanged atomic.Bool + clientOpts := &ClientOptions{ + ToolListChangedHandler: func(ctx context.Context, req *ClientRequest[*ToolListChangedParams]) { + toolsChanged.Store(true) + }, + CreateMessageHandler: func(ctx context.Context, req *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) { + if !toolsChanged.Load() { + return nil, fmt.Errorf("didn't get a tools changed notification") + } + // TODO(rfindley): investigate the error returned from this test if + // CreateMessageResult is new(CreateMessageResult): it's a mysterious + // unmarshalling error that we should improve. + return &CreateMessageResult{Content: &TextContent{}}, nil + }, + } + client := NewClient(testImpl, clientOpts) + + var rootsChanged atomic.Bool + serverOpts := &ServerOptions{ + RootsListChangedHandler: func(_ context.Context, req *ServerRequest[*RootsListChangedParams]) { + rootsChanged.Store(true) + }, + } + server := NewServer(testImpl, serverOpts) + cs, ss := basicClientServerConnection(t, client, server, func(s *Server) { + AddTool(s, &Tool{Name: "tool"}, func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) { + if !rootsChanged.Load() { + return nil, fmt.Errorf("didn't get root change notification") + } + return new(CallToolResult), nil + }) + }) + + t.Run("from client", func(t *testing.T) { + client.AddRoots(&Root{Name: "myroot", URI: "file://foo"}) + res, err := cs.CallTool(context.Background(), &CallToolParams{Name: "tool"}) + if err != nil { + t.Fatalf("CallTool failed: %v", err) + } + if res.IsError { + t.Errorf("tool error: %v", res.Content[0].(*TextContent).Text) + } + }) + + t.Run("from server", func(t *testing.T) { + server.RemoveTools("tool") + if _, err := ss.CreateMessage(context.Background(), new(CreateMessageParams)); err != nil { + t.Errorf("CreateMessage failed: %v", err) + } + }) +} + +func TestNoDistributedDeadlock(t *testing.T) { + // This test verifies that calls are asynchronous, and so it's not possible + // to have a distributed deadlock. + // + // The setup creates potential deadlock for both the client and server: the + // client sends a call to tool1, which itself calls createMessage, which in + // turn calls tool2, which calls ping. + // + // If the server were not asynchronous, the call to tool2 would hang. If the + // client were not asynchronous, the call to ping would hang. + // + // Such a scenario is unlikely in practice, but is still theoretically + // possible, and in any case making tool calls asynchronous by default + // delegates synchronization to the user. + clientOpts := &ClientOptions{ + CreateMessageHandler: func(ctx context.Context, req *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) { + req.Session.CallTool(ctx, &CallToolParams{Name: "tool2"}) + return &CreateMessageResult{Content: &TextContent{}}, nil + }, + } + client := NewClient(testImpl, clientOpts) + cs, _ := basicClientServerConnection(t, client, nil, func(s *Server) { + AddTool(s, &Tool{Name: "tool1"}, func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) { + req.Session.CreateMessage(ctx, new(CreateMessageParams)) + return new(CallToolResult), nil + }) + AddTool(s, &Tool{Name: "tool2"}, func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) { + req.Session.Ping(ctx, nil) + return new(CallToolResult), nil + }) + }) + defer cs.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if _, err := cs.CallTool(ctx, &CallToolParams{Name: "tool1"}); err != nil { + // should not deadlock + t.Fatalf("CallTool failed: %v", err) + } +} + var testImpl = &Implementation{Name: "test", Version: "v1.0.0"} diff --git a/mcp/server.go b/mcp/server.go index 5bc626b3..85c2b5ed 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -779,14 +779,14 @@ func (ss *ServerSession) sendingMethodInfos() map[string]methodInfo { return cli func (ss *ServerSession) receivingMethodInfos() map[string]methodInfo { return serverMethodInfos } -func (ss *ServerSession) sendingMethodHandler() methodHandler { +func (ss *ServerSession) sendingMethodHandler() MethodHandler { s := ss.server s.mu.Lock() defer s.mu.Unlock() return s.sendingMethodHandler_ } -func (ss *ServerSession) receivingMethodHandler() methodHandler { +func (ss *ServerSession) receivingMethodHandler() MethodHandler { s := ss.server s.mu.Lock() defer s.mu.Unlock() @@ -801,6 +801,7 @@ func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, ss.mu.Lock() initialized := ss.state.InitializedParams != nil ss.mu.Unlock() + // From the spec: // "The client SHOULD NOT send requests other than pings before the server // has responded to the initialize request." @@ -811,6 +812,14 @@ func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, return nil, fmt.Errorf("method %q is invalid during session initialization", req.Method) } } + + // modelcontextprotocol/go-sdk#26: handle calls asynchronously, and + // notifications synchronously, except for 'initialize' which shouldn't be + // asynchronous to other + if req.IsCall() && req.Method != methodInitialize { + jsonrpc2.Async(ctx) + } + // For the streamable transport, we need the request ID to correlate // server->client calls and notifications to the incoming request from which // they originated. See [idContextKey] for details. diff --git a/mcp/shared.go b/mcp/shared.go index ca062214..608e2aaf 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -38,12 +38,6 @@ var supportedProtocolVersions = []string{ // For notifications, both must be nil. type MethodHandler func(ctx context.Context, method string, req Request) (result Result, err error) -// A methodHandler is a MethodHandler[Session] for some session. -// We need to give up type safety here, or we will end up with a type cycle somewhere -// else. For example, if Session.methodHandler returned a MethodHandler[Session], -// the compiler would complain. -type methodHandler any // MethodHandler[*ClientSession] | MethodHandler[*ServerSession] - // A Session is either a [ClientSession] or a [ServerSession]. type Session interface { // ID returns the session ID, or the empty string if there is none. @@ -51,8 +45,8 @@ type Session interface { sendingMethodInfos() map[string]methodInfo receivingMethodInfos() map[string]methodInfo - sendingMethodHandler() methodHandler - receivingMethodHandler() methodHandler + sendingMethodHandler() MethodHandler + receivingMethodHandler() MethodHandler getConn() *jsonrpc2.Connection } @@ -95,13 +89,13 @@ func orZero[T any, P *U, U any](p P) T { } func handleNotify(ctx context.Context, method string, req Request) error { - mh := req.GetSession().sendingMethodHandler().(MethodHandler) + mh := req.GetSession().sendingMethodHandler() _, err := mh(ctx, method, req) return err } func handleSend[R Result](ctx context.Context, method string, req Request) (R, error) { - mh := req.GetSession().sendingMethodHandler().(MethodHandler) + mh := req.GetSession().sendingMethodHandler() // mh might be user code, so ensure that it returns the right values for the jsonrpc2 protocol. res, err := mh(ctx, method, req) if err != nil { @@ -118,7 +112,7 @@ func defaultReceivingMethodHandler[S Session](ctx context.Context, method string // This can be called from user code, with an arbitrary value for method. return nil, jsonrpc2.ErrNotHandled } - return info.handleMethod.(MethodHandler)(ctx, method, req) + return info.handleMethod(ctx, method, req) } func handleReceive[S Session](ctx context.Context, session S, jreq *jsonrpc.Request) (Result, error) { @@ -131,7 +125,7 @@ func handleReceive[S Session](ctx context.Context, session S, jreq *jsonrpc.Requ return nil, fmt.Errorf("handling '%s': %w", jreq.Method, err) } - mh := session.receivingMethodHandler().(MethodHandler) + mh := session.receivingMethodHandler() req := info.newRequest(session, params) // mh might be user code, so ensure that it returns the right values for the jsonrpc2 protocol. res, err := mh(ctx, jreq.Method, req) @@ -154,10 +148,10 @@ func checkRequest(req *jsonrpc.Request, infos map[string]methodInfo) (methodInfo if !ok { return methodInfo{}, fmt.Errorf("%w: %q unsupported", jsonrpc2.ErrNotHandled, req.Method) } - if info.flags¬ification != 0 && req.ID.IsValid() { + if info.flags¬ification != 0 && req.IsCall() { return methodInfo{}, fmt.Errorf("%w: unexpected id for %q", jsonrpc2.ErrInvalidRequest, req.Method) } - if info.flags¬ification == 0 && !req.ID.IsValid() { + if info.flags¬ification == 0 && !req.IsCall() { return methodInfo{}, fmt.Errorf("%w: missing id for %q", jsonrpc2.ErrInvalidRequest, req.Method) } // missingParamsOK is checked here to catch the common case where "params" is @@ -182,7 +176,7 @@ type methodInfo struct { newRequest func(Session, Params) Request // Run the code when a call to the method is received. // Used on the receive side. - handleMethod methodHandler + handleMethod MethodHandler // Create a pointer to a Result struct. // Used on the send side. newResult func() Result diff --git a/mcp/streamable.go b/mcp/streamable.go index 9ae20c02..1ecf201f 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -509,7 +509,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques http.Error(w, err.Error(), http.StatusBadRequest) return } - if req.ID.IsValid() { + if req.IsCall() { requests[req.ID] = struct{}{} } } diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 4181303f..e0b00cc6 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -689,7 +689,7 @@ func TestStreamableServerTransport(t *testing.T) { defer wg.Done() for m := range out { - if req, ok := m.(*jsonrpc.Request); ok && req.ID.IsValid() { + if req, ok := m.(*jsonrpc.Request); ok && req.IsCall() { // Encountered a server->client request. We should have a // response queued. Otherwise, we may deadlock. mu.Lock()