Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 48 additions & 34 deletions internal/jsonrpc2/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,46 @@ func (c *Connection) Call(ctx context.Context, method string, params any) *Async
return ac
}

// Async, signals that the current jsonrpc2 request may be handled
// asynchronously to subsequent requests, when ctx is the request context.
//
// Async must be called at most once on each request's context (and its
// descendants).
func Async(ctx context.Context) {
if r, ok := ctx.Value(asyncKey).(*releaser); ok {
r.release(false)
}
}

type asyncKeyType struct{}

var asyncKey = asyncKeyType{}

// A releaser implements concurrency safe 'releasing' of async requests. (A
// request is released when it is allowed to run concurrent with other
// requests, via a call to [Async].)
type releaser struct {
mu sync.Mutex
ch chan struct{}
released bool
}

// release closes the associated channel. If soft is set, multiple calls to
// release are allowed.
func (r *releaser) release(soft bool) {
r.mu.Lock()
defer r.mu.Unlock()

if r.released {
if !soft {
panic("jsonrpc2.Async called multiple times")
}
} else {
close(r.ch)
r.released = true
}
}

type AsyncCall struct {
id ID
ready chan struct{} // closed after response has been set
Expand Down Expand Up @@ -425,28 +465,6 @@ func (ac *AsyncCall) Await(ctx context.Context, result any) error {
return json.Unmarshal(ac.response.Result, result)
}

// Respond delivers a response to an incoming Call.
//
// Respond must be called exactly once for any message for which a handler
// returns ErrAsyncResponse. It must not be called for any other message.
func (c *Connection) Respond(id ID, result any, err error) error {
var req *incomingRequest
c.updateInFlight(func(s *inFlightState) {
req = s.incomingByID[id]
})
if req == nil {
return c.internalErrorf("Request not found for ID %v", id)
}

if err == ErrAsyncResponse {
// Respond is supposed to supply the asynchronous response, so it would be
// confusing to call Respond with an error that promises to call Respond
// again.
err = c.internalErrorf("Respond called with ErrAsyncResponse for %q", req.Method)
}
return c.processResult("Respond", req, result, err)
}

// Cancel cancels the Context passed to the Handle call for the inbound message
// with the given ID.
//
Expand Down Expand Up @@ -576,11 +594,6 @@ func (c *Connection) acceptRequest(ctx context.Context, msg *Request, preempter
if preempter != nil {
result, err := preempter.Preempt(req.ctx, req.Request)

if req.IsCall() && errors.Is(err, ErrAsyncResponse) {
// This request will remain in flight until Respond is called for it.
return
}

if !errors.Is(err, ErrNotHandled) {
c.processResult("Preempt", req, result, err)
return
Expand Down Expand Up @@ -655,19 +668,20 @@ func (c *Connection) handleAsync() {
continue
}

result, err := c.handler.Handle(req.ctx, req.Request)
c.processResult(c.handler, req, result, err)
releaser := &releaser{ch: make(chan struct{})}
ctx := context.WithValue(req.ctx, asyncKey, releaser)
go func() {
defer releaser.release(true)
result, err := c.handler.Handle(ctx, req.Request)
c.processResult(c.handler, req, result, err)
}()
<-releaser.ch
}
}

// processResult processes the result of a request and, if appropriate, sends a response.
func (c *Connection) processResult(from any, req *incomingRequest, result any, err error) error {
switch err {
case ErrAsyncResponse:
if !req.IsCall() {
return c.internalErrorf("%#v returned ErrAsyncResponse for a %q Request without an ID", from, req.Method)
}
return nil // This request is still in flight, so don't record the result yet.
case ErrNotHandled, ErrMethodNotFound:
// Add detail describing the unhandled method.
err = fmt.Errorf("%w: %q", ErrMethodNotFound, req.Method)
Expand Down
7 changes: 0 additions & 7 deletions internal/jsonrpc2/jsonrpc2.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,6 @@ var (
// If a Handler returns ErrNotHandled, the server replies with
// ErrMethodNotFound.
ErrNotHandled = errors.New("JSON RPC not handled")

// ErrAsyncResponse is returned from a handler to indicate it will generate a
// response asynchronously.
//
// ErrAsyncResponse must not be returned for notifications,
// which do not receive responses.
ErrAsyncResponse = errors.New("JSON RPC asynchronous response")
)

// Preempter handles messages on a connection before they are queued to the main
Expand Down
16 changes: 7 additions & 9 deletions internal/jsonrpc2/jsonrpc2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,16 +371,14 @@ func (h *handler) Handle(ctx context.Context, req *jsonrpc2.Request) (any, error
if err := json.Unmarshal(req.Params, &name); err != nil {
return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
}
jsonrpc2.Async(ctx)
waitFor := h.waiter(name)
go func() {
select {
case <-waitFor:
h.conn.Respond(req.ID, true, nil)
case <-ctx.Done():
h.conn.Respond(req.ID, nil, ctx.Err())
}
}()
return nil, jsonrpc2.ErrAsyncResponse
select {
case <-waitFor:
return true, nil
case <-ctx.Done():
return nil, ctx.Err()
}
default:
return nil, jsonrpc2.ErrNotHandled
}
Expand Down
7 changes: 5 additions & 2 deletions mcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -329,16 +329,19 @@ func (cs *ClientSession) receivingMethodInfos() map[string]methodInfo {
}

func (cs *ClientSession) handle(ctx context.Context, req *jsonrpc.Request) (any, error) {
if req.IsCall() {
jsonrpc2.Async(ctx)
}
return handleReceive(ctx, cs, req)
}

func (cs *ClientSession) sendingMethodHandler() methodHandler {
func (cs *ClientSession) sendingMethodHandler() MethodHandler {
cs.client.mu.Lock()
defer cs.client.mu.Unlock()
return cs.client.sendingMethodHandler_
}

func (cs *ClientSession) receivingMethodHandler() methodHandler {
func (cs *ClientSession) receivingMethodHandler() MethodHandler {
cs.client.mu.Lock()
defer cs.client.mu.Unlock()
return cs.client.receivingMethodHandler_
Expand Down
6 changes: 3 additions & 3 deletions mcp/conformance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func runServerTest(t *testing.T, test *conformanceTest) {
return nil, err, false
}
serverMessages = append(serverMessages, msg)
if req, ok := msg.(*jsonrpc.Request); ok && req.ID.IsValid() {
if req, ok := msg.(*jsonrpc.Request); ok && req.IsCall() {
// Pair up the next outgoing response with this request.
// We assume requests arrive in the same order every time.
if len(outResponses) == 0 {
Expand All @@ -201,8 +201,8 @@ func runServerTest(t *testing.T, test *conformanceTest) {
// Synthetic peer interacts with real peer.
for _, req := range outRequests {
writeMsg(req)
if req.ID.IsValid() {
// A request (as opposed to a notification). Wait for the response.
if req.IsCall() {
// A call (as opposed to a notification). Wait for the response.
res, err, ok := nextResponse()
if err != nil {
t.Fatalf("reading server messages failed: %v", err)
Expand Down
2 changes: 1 addition & 1 deletion mcp/content.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ func contentsFromWire(wires []*wireContent, allow map[string]bool) ([]Content, e

func contentFromWire(wire *wireContent, allow map[string]bool) (Content, error) {
if wire == nil {
return nil, fmt.Errorf("content wire is nil")
return nil, fmt.Errorf("nil content")
}
if allow != nil && !allow[wire.Type] {
return nil, fmt.Errorf("invalid content type %q", wire.Type)
Expand Down
133 changes: 122 additions & 11 deletions mcp/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"slices"
"strings"
"sync"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -549,31 +550,47 @@ func errorCode(err error) int64 {
//
// The caller should cancel either the client connection or server connection
// when the connections are no longer needed.
func basicConnection(t *testing.T, config func(*Server)) (*ServerSession, *ClientSession) {
func basicConnection(t *testing.T, config func(*Server)) (*ClientSession, *ServerSession) {
return basicClientServerConnection(t, nil, nil, config)
}

// basicClientServerConnection creates a basic connection between client and
// server. If either client or server is nil, empty implementations are used.
//
// The provided function may be used to configure features on the resulting
// server, prior to connection.
//
// The caller should cancel either the client connection or server connection
// when the connections are no longer needed.
func basicClientServerConnection(t *testing.T, client *Client, server *Server, config func(*Server)) (*ClientSession, *ServerSession) {
t.Helper()

ctx := context.Background()
ct, st := NewInMemoryTransports()

s := NewServer(testImpl, nil)
if server == nil {
server = NewServer(testImpl, nil)
}
if config != nil {
config(s)
config(server)
}
ss, err := s.Connect(ctx, st, nil)
ss, err := server.Connect(ctx, st, nil)
if err != nil {
t.Fatal(err)
}

c := NewClient(testImpl, nil)
cs, err := c.Connect(ctx, ct, nil)
if client == nil {
client = NewClient(testImpl, nil)
}
cs, err := client.Connect(ctx, ct, nil)
if err != nil {
t.Fatal(err)
}
return ss, cs
return cs, ss
}

func TestServerClosing(t *testing.T) {
cc, cs := basicConnection(t, func(s *Server) {
cs, ss := basicConnection(t, func(s *Server) {
AddTool(s, greetTool(), sayHi)
})
defer cs.Close()
Expand All @@ -593,7 +610,7 @@ func TestServerClosing(t *testing.T) {
}); err != nil {
t.Fatalf("after connecting: %v", err)
}
cc.Close()
ss.Close()
wg.Wait()
if _, err := cs.CallTool(ctx, &CallToolParams{
Name: "greet",
Expand Down Expand Up @@ -656,7 +673,7 @@ func TestCancellation(t *testing.T) {
}
return nil, nil
}
_, cs := basicConnection(t, func(s *Server) {
cs, _ := basicConnection(t, func(s *Server) {
AddTool(s, &Tool{Name: "slow"}, slowRequest)
})
defer cs.Close()
Expand Down Expand Up @@ -940,7 +957,7 @@ func TestKeepAliveFailure(t *testing.T) {
func TestAddTool_DuplicateNoPanicAndNoDuplicate(t *testing.T) {
// Adding the same tool pointer twice should not panic and should not
// produce duplicates in the server's tool list.
_, cs := basicConnection(t, func(s *Server) {
cs, _ := basicConnection(t, func(s *Server) {
// Use two distinct Tool instances with the same name but different
// descriptions to ensure the second replaces the first
// This case was written specifically to reproduce a bug where duplicate tools where causing jsonschema errors
Expand Down Expand Up @@ -972,4 +989,98 @@ func TestAddTool_DuplicateNoPanicAndNoDuplicate(t *testing.T) {
}
}

func TestSynchronousNotifications(t *testing.T) {
var toolsChanged atomic.Bool
clientOpts := &ClientOptions{
ToolListChangedHandler: func(ctx context.Context, req *ClientRequest[*ToolListChangedParams]) {
toolsChanged.Store(true)
},
CreateMessageHandler: func(ctx context.Context, req *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) {
if !toolsChanged.Load() {
return nil, fmt.Errorf("didn't get a tools changed notification")
}
// TODO(rfindley): investigate the error returned from this test if
// CreateMessageResult is new(CreateMessageResult): it's a mysterious
// unmarshalling error that we should improve.
return &CreateMessageResult{Content: &TextContent{}}, nil
},
}
client := NewClient(testImpl, clientOpts)

var rootsChanged atomic.Bool
serverOpts := &ServerOptions{
RootsListChangedHandler: func(_ context.Context, req *ServerRequest[*RootsListChangedParams]) {
rootsChanged.Store(true)
},
}
server := NewServer(testImpl, serverOpts)
cs, ss := basicClientServerConnection(t, client, server, func(s *Server) {
AddTool(s, &Tool{Name: "tool"}, func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) {
if !rootsChanged.Load() {
return nil, fmt.Errorf("didn't get root change notification")
}
return new(CallToolResult), nil
})
})

t.Run("from client", func(t *testing.T) {
client.AddRoots(&Root{Name: "myroot", URI: "file://foo"})
res, err := cs.CallTool(context.Background(), &CallToolParams{Name: "tool"})
if err != nil {
t.Fatalf("CallTool failed: %v", err)
}
if res.IsError {
t.Errorf("tool error: %v", res.Content[0].(*TextContent).Text)
}
})

t.Run("from server", func(t *testing.T) {
server.RemoveTools("tool")
if _, err := ss.CreateMessage(context.Background(), new(CreateMessageParams)); err != nil {
t.Errorf("CreateMessage failed: %v", err)
}
})
}

func TestNoDistributedDeadlock(t *testing.T) {
// This test verifies that calls are asynchronous, and so it's not possible
// to have a distributed deadlock.
//
// The setup creates potential deadlock for both the client and server: the
// client sends a call to tool1, which itself calls createMessage, which in
// turn calls tool2, which calls ping.
//
// If the server were not asynchronous, the call to tool2 would hang. If the
// client were not asynchronous, the call to ping would hang.
//
// Such a scenario is unlikely in practice, but is still theoretically
// possible, and in any case making tool calls asynchronous by default
// delegates synchronization to the user.
clientOpts := &ClientOptions{
CreateMessageHandler: func(ctx context.Context, req *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) {
req.Session.CallTool(ctx, &CallToolParams{Name: "tool2"})
return &CreateMessageResult{Content: &TextContent{}}, nil
},
}
client := NewClient(testImpl, clientOpts)
cs, _ := basicClientServerConnection(t, client, nil, func(s *Server) {
AddTool(s, &Tool{Name: "tool1"}, func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) {
req.Session.CreateMessage(ctx, new(CreateMessageParams))
return new(CallToolResult), nil
})
AddTool(s, &Tool{Name: "tool2"}, func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) {
req.Session.Ping(ctx, nil)
return new(CallToolResult), nil
})
})
defer cs.Close()

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if _, err := cs.CallTool(ctx, &CallToolParams{Name: "tool1"}); err != nil {
// should not deadlock
t.Fatalf("CallTool failed: %v", err)
}
}

var testImpl = &Implementation{Name: "test", Version: "v1.0.0"}
Loading