Skip to content

Commit 3db848a

Browse files
committed
mcp: implement a concurrency model for calls
Implement the concurrency model described in #26: notifications are synchronous, but calls are asynchronous (except for 'initialize'). To achieve this, implement jsonrpc2.Async(ctx) to signal asynchronous handling. This is simpler to use than returning ErrAsyncResponse and calling Respond, and since this is an internal detail we don't need to worry too much about whether it's idiomatic. Add tests that verify both features, for both client and server. Also: - replace req.ID.IsValid with req.IsCall - remove the methodHandler type as we can just use MethodHandler Fixes #26
1 parent 112ca4e commit 3db848a

File tree

11 files changed

+208
-86
lines changed

11 files changed

+208
-86
lines changed

internal/jsonrpc2/conn.go

Lines changed: 48 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,46 @@ func (c *Connection) Call(ctx context.Context, method string, params any) *Async
374374
return ac
375375
}
376376

377+
// Async, signals that the current jsonrpc2 request may be handled
378+
// asynchronously to subsequent requests, when ctx is the request context.
379+
//
380+
// Async must be called at most once on each request's context (and its
381+
// descendants).
382+
func Async(ctx context.Context) {
383+
if r, ok := ctx.Value(asyncKey).(*releaser); ok {
384+
r.release(false)
385+
}
386+
}
387+
388+
type asyncKeyType struct{}
389+
390+
var asyncKey = asyncKeyType{}
391+
392+
// A releaser implements concurrency safe 'releasing' of async requests. (A
393+
// request is released when it is allowed to run concurrent with other
394+
// requests, via a call to [Async].)
395+
type releaser struct {
396+
mu sync.Mutex
397+
ch chan struct{}
398+
released bool
399+
}
400+
401+
// release closes the associated channel. If soft is set, multiple calls to
402+
// release are allowed.
403+
func (r *releaser) release(soft bool) {
404+
r.mu.Lock()
405+
defer r.mu.Unlock()
406+
407+
if r.released {
408+
if !soft {
409+
panic("jsonrpc2.Async called multiple times")
410+
}
411+
} else {
412+
close(r.ch)
413+
r.released = true
414+
}
415+
}
416+
377417
type AsyncCall struct {
378418
id ID
379419
ready chan struct{} // closed after response has been set
@@ -425,28 +465,6 @@ func (ac *AsyncCall) Await(ctx context.Context, result any) error {
425465
return json.Unmarshal(ac.response.Result, result)
426466
}
427467

428-
// Respond delivers a response to an incoming Call.
429-
//
430-
// Respond must be called exactly once for any message for which a handler
431-
// returns ErrAsyncResponse. It must not be called for any other message.
432-
func (c *Connection) Respond(id ID, result any, err error) error {
433-
var req *incomingRequest
434-
c.updateInFlight(func(s *inFlightState) {
435-
req = s.incomingByID[id]
436-
})
437-
if req == nil {
438-
return c.internalErrorf("Request not found for ID %v", id)
439-
}
440-
441-
if err == ErrAsyncResponse {
442-
// Respond is supposed to supply the asynchronous response, so it would be
443-
// confusing to call Respond with an error that promises to call Respond
444-
// again.
445-
err = c.internalErrorf("Respond called with ErrAsyncResponse for %q", req.Method)
446-
}
447-
return c.processResult("Respond", req, result, err)
448-
}
449-
450468
// Cancel cancels the Context passed to the Handle call for the inbound message
451469
// with the given ID.
452470
//
@@ -576,11 +594,6 @@ func (c *Connection) acceptRequest(ctx context.Context, msg *Request, preempter
576594
if preempter != nil {
577595
result, err := preempter.Preempt(req.ctx, req.Request)
578596

579-
if req.IsCall() && errors.Is(err, ErrAsyncResponse) {
580-
// This request will remain in flight until Respond is called for it.
581-
return
582-
}
583-
584597
if !errors.Is(err, ErrNotHandled) {
585598
c.processResult("Preempt", req, result, err)
586599
return
@@ -655,19 +668,20 @@ func (c *Connection) handleAsync() {
655668
continue
656669
}
657670

658-
result, err := c.handler.Handle(req.ctx, req.Request)
659-
c.processResult(c.handler, req, result, err)
671+
releaser := &releaser{ch: make(chan struct{})}
672+
ctx := context.WithValue(req.ctx, asyncKey, releaser)
673+
go func() {
674+
defer releaser.release(true)
675+
result, err := c.handler.Handle(ctx, req.Request)
676+
c.processResult(c.handler, req, result, err)
677+
}()
678+
<-releaser.ch
660679
}
661680
}
662681

663682
// processResult processes the result of a request and, if appropriate, sends a response.
664683
func (c *Connection) processResult(from any, req *incomingRequest, result any, err error) error {
665684
switch err {
666-
case ErrAsyncResponse:
667-
if !req.IsCall() {
668-
return c.internalErrorf("%#v returned ErrAsyncResponse for a %q Request without an ID", from, req.Method)
669-
}
670-
return nil // This request is still in flight, so don't record the result yet.
671685
case ErrNotHandled, ErrMethodNotFound:
672686
// Add detail describing the unhandled method.
673687
err = fmt.Errorf("%w: %q", ErrMethodNotFound, req.Method)

internal/jsonrpc2/jsonrpc2.go

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,6 @@ var (
2222
// If a Handler returns ErrNotHandled, the server replies with
2323
// ErrMethodNotFound.
2424
ErrNotHandled = errors.New("JSON RPC not handled")
25-
26-
// ErrAsyncResponse is returned from a handler to indicate it will generate a
27-
// response asynchronously.
28-
//
29-
// ErrAsyncResponse must not be returned for notifications,
30-
// which do not receive responses.
31-
ErrAsyncResponse = errors.New("JSON RPC asynchronous response")
3225
)
3326

3427
// Preempter handles messages on a connection before they are queued to the main

internal/jsonrpc2/jsonrpc2_test.go

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -371,16 +371,14 @@ func (h *handler) Handle(ctx context.Context, req *jsonrpc2.Request) (any, error
371371
if err := json.Unmarshal(req.Params, &name); err != nil {
372372
return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
373373
}
374+
jsonrpc2.Async(ctx)
374375
waitFor := h.waiter(name)
375-
go func() {
376-
select {
377-
case <-waitFor:
378-
h.conn.Respond(req.ID, true, nil)
379-
case <-ctx.Done():
380-
h.conn.Respond(req.ID, nil, ctx.Err())
381-
}
382-
}()
383-
return nil, jsonrpc2.ErrAsyncResponse
376+
select {
377+
case <-waitFor:
378+
return true, nil
379+
case <-ctx.Done():
380+
return nil, ctx.Err()
381+
}
384382
default:
385383
return nil, jsonrpc2.ErrNotHandled
386384
}

mcp/client.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -328,16 +328,19 @@ func (cs *ClientSession) receivingMethodInfos() map[string]methodInfo {
328328
}
329329

330330
func (cs *ClientSession) handle(ctx context.Context, req *jsonrpc.Request) (any, error) {
331+
if req.IsCall() {
332+
jsonrpc2.Async(ctx)
333+
}
331334
return handleReceive(ctx, cs, req)
332335
}
333336

334-
func (cs *ClientSession) sendingMethodHandler() methodHandler {
337+
func (cs *ClientSession) sendingMethodHandler() MethodHandler {
335338
cs.client.mu.Lock()
336339
defer cs.client.mu.Unlock()
337340
return cs.client.sendingMethodHandler_
338341
}
339342

340-
func (cs *ClientSession) receivingMethodHandler() methodHandler {
343+
func (cs *ClientSession) receivingMethodHandler() MethodHandler {
341344
cs.client.mu.Lock()
342345
defer cs.client.mu.Unlock()
343346
return cs.client.receivingMethodHandler_

mcp/conformance_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ func runServerTest(t *testing.T, test *conformanceTest) {
183183
return nil, err, false
184184
}
185185
serverMessages = append(serverMessages, msg)
186-
if req, ok := msg.(*jsonrpc.Request); ok && req.ID.IsValid() {
186+
if req, ok := msg.(*jsonrpc.Request); ok && req.IsCall() {
187187
// Pair up the next outgoing response with this request.
188188
// We assume requests arrive in the same order every time.
189189
if len(outResponses) == 0 {
@@ -201,8 +201,8 @@ func runServerTest(t *testing.T, test *conformanceTest) {
201201
// Synthetic peer interacts with real peer.
202202
for _, req := range outRequests {
203203
writeMsg(req)
204-
if req.ID.IsValid() {
205-
// A request (as opposed to a notification). Wait for the response.
204+
if req.IsCall() {
205+
// A call (as opposed to a notification). Wait for the response.
206206
res, err, ok := nextResponse()
207207
if err != nil {
208208
t.Fatalf("reading server messages failed: %v", err)

mcp/content.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ func contentsFromWire(wires []*wireContent, allow map[string]bool) ([]Content, e
253253

254254
func contentFromWire(wire *wireContent, allow map[string]bool) (Content, error) {
255255
if wire == nil {
256-
return nil, fmt.Errorf("content wire is nil")
256+
return nil, fmt.Errorf("nil content")
257257
}
258258
if allow != nil && !allow[wire.Type] {
259259
return nil, fmt.Errorf("invalid content type %q", wire.Type)

mcp/mcp_test.go

Lines changed: 122 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"slices"
1818
"strings"
1919
"sync"
20+
"sync/atomic"
2021
"testing"
2122
"time"
2223

@@ -549,31 +550,47 @@ func errorCode(err error) int64 {
549550
//
550551
// The caller should cancel either the client connection or server connection
551552
// when the connections are no longer needed.
552-
func basicConnection(t *testing.T, config func(*Server)) (*ServerSession, *ClientSession) {
553+
func basicConnection(t *testing.T, config func(*Server)) (*ClientSession, *ServerSession) {
554+
return basicClientServerConnection(t, nil, nil, config)
555+
}
556+
557+
// basicClientServerConnection creates a basic connection between client and
558+
// server. If either client or server is nil, empty implementations are used.
559+
//
560+
// The provided function may be used to configure features on the resulting
561+
// server, prior to connection.
562+
//
563+
// The caller should cancel either the client connection or server connection
564+
// when the connections are no longer needed.
565+
func basicClientServerConnection(t *testing.T, client *Client, server *Server, config func(*Server)) (*ClientSession, *ServerSession) {
553566
t.Helper()
554567

555568
ctx := context.Background()
556569
ct, st := NewInMemoryTransports()
557570

558-
s := NewServer(testImpl, nil)
571+
if server == nil {
572+
server = NewServer(testImpl, nil)
573+
}
559574
if config != nil {
560-
config(s)
575+
config(server)
561576
}
562-
ss, err := s.Connect(ctx, st, nil)
577+
ss, err := server.Connect(ctx, st, nil)
563578
if err != nil {
564579
t.Fatal(err)
565580
}
566581

567-
c := NewClient(testImpl, nil)
568-
cs, err := c.Connect(ctx, ct, nil)
582+
if client == nil {
583+
client = NewClient(testImpl, nil)
584+
}
585+
cs, err := client.Connect(ctx, ct, nil)
569586
if err != nil {
570587
t.Fatal(err)
571588
}
572-
return ss, cs
589+
return cs, ss
573590
}
574591

575592
func TestServerClosing(t *testing.T) {
576-
cc, cs := basicConnection(t, func(s *Server) {
593+
cs, ss := basicConnection(t, func(s *Server) {
577594
AddTool(s, greetTool(), sayHi)
578595
})
579596
defer cs.Close()
@@ -593,7 +610,7 @@ func TestServerClosing(t *testing.T) {
593610
}); err != nil {
594611
t.Fatalf("after connecting: %v", err)
595612
}
596-
cc.Close()
613+
ss.Close()
597614
wg.Wait()
598615
if _, err := cs.CallTool(ctx, &CallToolParams{
599616
Name: "greet",
@@ -656,7 +673,7 @@ func TestCancellation(t *testing.T) {
656673
}
657674
return nil, nil
658675
}
659-
_, cs := basicConnection(t, func(s *Server) {
676+
cs, _ := basicConnection(t, func(s *Server) {
660677
AddTool(s, &Tool{Name: "slow"}, slowRequest)
661678
})
662679
defer cs.Close()
@@ -940,7 +957,7 @@ func TestKeepAliveFailure(t *testing.T) {
940957
func TestAddTool_DuplicateNoPanicAndNoDuplicate(t *testing.T) {
941958
// Adding the same tool pointer twice should not panic and should not
942959
// produce duplicates in the server's tool list.
943-
_, cs := basicConnection(t, func(s *Server) {
960+
cs, _ := basicConnection(t, func(s *Server) {
944961
// Use two distinct Tool instances with the same name but different
945962
// descriptions to ensure the second replaces the first
946963
// This case was written specifically to reproduce a bug where duplicate tools where causing jsonschema errors
@@ -972,4 +989,98 @@ func TestAddTool_DuplicateNoPanicAndNoDuplicate(t *testing.T) {
972989
}
973990
}
974991

992+
func TestSynchronousNotifications(t *testing.T) {
993+
var toolsChanged atomic.Bool
994+
clientOpts := &ClientOptions{
995+
ToolListChangedHandler: func(ctx context.Context, req *ClientRequest[*ToolListChangedParams]) {
996+
toolsChanged.Store(true)
997+
},
998+
CreateMessageHandler: func(ctx context.Context, req *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) {
999+
if !toolsChanged.Load() {
1000+
return nil, fmt.Errorf("didn't get a tools changed notification")
1001+
}
1002+
// TODO(rfindley): investigate the error returned from this test if
1003+
// CreateMessageResult is new(CreateMessageResult): it's a mysterious
1004+
// unmarshalling error that we should improve.
1005+
return &CreateMessageResult{Content: &TextContent{}}, nil
1006+
},
1007+
}
1008+
client := NewClient(testImpl, clientOpts)
1009+
1010+
var rootsChanged atomic.Bool
1011+
serverOpts := &ServerOptions{
1012+
RootsListChangedHandler: func(_ context.Context, req *ServerRequest[*RootsListChangedParams]) {
1013+
rootsChanged.Store(true)
1014+
},
1015+
}
1016+
server := NewServer(testImpl, serverOpts)
1017+
cs, ss := basicClientServerConnection(t, client, server, func(s *Server) {
1018+
AddTool(s, &Tool{Name: "tool"}, func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) {
1019+
if !rootsChanged.Load() {
1020+
return nil, fmt.Errorf("didn't get root change notification")
1021+
}
1022+
return new(CallToolResult), nil
1023+
})
1024+
})
1025+
1026+
t.Run("from client", func(t *testing.T) {
1027+
client.AddRoots(&Root{Name: "myroot", URI: "file://foo"})
1028+
res, err := cs.CallTool(context.Background(), &CallToolParams{Name: "tool"})
1029+
if err != nil {
1030+
t.Fatalf("CallTool failed: %v", err)
1031+
}
1032+
if res.IsError {
1033+
t.Errorf("tool error: %v", res.Content[0].(*TextContent).Text)
1034+
}
1035+
})
1036+
1037+
t.Run("from server", func(t *testing.T) {
1038+
server.RemoveTools("tool")
1039+
if _, err := ss.CreateMessage(context.Background(), new(CreateMessageParams)); err != nil {
1040+
t.Errorf("CreateMessage failed: %v", err)
1041+
}
1042+
})
1043+
}
1044+
1045+
func TestNoDistributedDeadlock(t *testing.T) {
1046+
// This test verifies that calls are asynchronous, and so it's not possible
1047+
// to have a distributed deadlock.
1048+
//
1049+
// The setup creates potential deadlock for both the client and server: the
1050+
// client sends a call to tool1, which itself calls createMessage, which in
1051+
// turn calls tool2, which calls ping.
1052+
//
1053+
// If the server were not asynchronous, the call to tool2 would hang. If the
1054+
// client were not asynchronous, the call to ping would hang.
1055+
//
1056+
// Such a scenario is unlikely in practice, but is still theoretically
1057+
// possible, and in any case making tool calls asynchronous by default
1058+
// delegates synchronization to the user.
1059+
clientOpts := &ClientOptions{
1060+
CreateMessageHandler: func(ctx context.Context, req *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) {
1061+
req.Session.CallTool(ctx, &CallToolParams{Name: "tool2"})
1062+
return &CreateMessageResult{Content: &TextContent{}}, nil
1063+
},
1064+
}
1065+
client := NewClient(testImpl, clientOpts)
1066+
cs, _ := basicClientServerConnection(t, client, nil, func(s *Server) {
1067+
AddTool(s, &Tool{Name: "tool1"}, func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) {
1068+
req.Session.CreateMessage(ctx, new(CreateMessageParams))
1069+
return new(CallToolResult), nil
1070+
})
1071+
AddTool(s, &Tool{Name: "tool2"}, func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) {
1072+
req.Session.Ping(ctx, nil)
1073+
return new(CallToolResult), nil
1074+
})
1075+
})
1076+
defer cs.Close()
1077+
1078+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
1079+
defer cancel()
1080+
if _, err := cs.CallTool(ctx, &CallToolParams{Name: "tool1"}); err != nil {
1081+
// should not deadlock
1082+
t.Fatalf("CallTool failed: %v", err)
1083+
}
1084+
}
1085+
9751086
var testImpl = &Implementation{Name: "test", Version: "v1.0.0"}

0 commit comments

Comments
 (0)