Skip to content

Commit bb6dade

Browse files
committed
mcp: fix cancellation for HTTP transport
In #202, I added the checkRequest helper to validate incoming requests, and invoked it in the stremable transports to preemptively reject invalid HTTP requests, so that a jsonrpc error could be translated to an HTTP error. However, this introduced a bug: since cancellation was handled in the jsonrpc2 layer, we never had to validate it in the mcp layer, and therefore never added methodInfo. As a result, it was reported as an invalid request in the http layer. Add a test, and a fix. The simplest fix was to create stubs that are placeholders for cancellation. This was discovered in the course of investigating #285.
1 parent a834f3c commit bb6dade

File tree

6 files changed

+72
-26
lines changed

6 files changed

+72
-26
lines changed

mcp/client.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@ var clientMethodInfos = map[string]methodInfo{
305305
methodPing: newClientMethodInfo(clientSessionMethod((*ClientSession).ping), missingParamsOK),
306306
methodListRoots: newClientMethodInfo(clientMethod((*Client).listRoots), missingParamsOK),
307307
methodCreateMessage: newClientMethodInfo(clientMethod((*Client).createMessage), 0),
308+
notificationCancelled: newClientMethodInfo(clientSessionMethod((*ClientSession).cancel), notification|missingParamsOK),
308309
notificationToolListChanged: newClientMethodInfo(clientMethod((*Client).callToolChangedHandler), notification|missingParamsOK),
309310
notificationPromptListChanged: newClientMethodInfo(clientMethod((*Client).callPromptChangedHandler), notification|missingParamsOK),
310311
notificationResourceListChanged: newClientMethodInfo(clientMethod((*Client).callResourceChangedHandler), notification|missingParamsOK),
@@ -344,6 +345,15 @@ func (*ClientSession) ping(context.Context, *PingParams) (*emptyResult, error) {
344345
return &emptyResult{}, nil
345346
}
346347

348+
// cancel is a placeholder: cancellation is handled the jsonrpc2 package.
349+
//
350+
// It should never be invoked in practice because cancellation is preempted,
351+
// but having its signature here facilitates the construction of methodInfo
352+
// that can be used to validate incoming cancellation notifications.
353+
func (*ClientSession) cancel(context.Context, *CancelledParams) (Result, error) {
354+
return nil, nil
355+
}
356+
347357
func newClientRequest[P Params](cs *ClientSession, params P) *ClientRequest[P] {
348358
return &ClientRequest[P]{Session: cs, Params: params}
349359
}

mcp/mcp_test.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -646,8 +646,7 @@ func TestCancellation(t *testing.T) {
646646
start = make(chan struct{})
647647
cancelled = make(chan struct{}, 1) // don't block the request
648648
)
649-
650-
slowRequest := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[map[string]any]]) (*CallToolResult, error) {
649+
slowRequest := func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) {
651650
start <- struct{}{}
652651
select {
653652
case <-ctx.Done():
@@ -658,7 +657,7 @@ func TestCancellation(t *testing.T) {
658657
return nil, nil
659658
}
660659
_, cs := basicConnection(t, func(s *Server) {
661-
s.AddTool(&Tool{Name: "slow", InputSchema: &jsonschema.Schema{}}, slowRequest)
660+
AddTool(s, &Tool{Name: "slow"}, slowRequest)
662661
})
663662
defer cs.Close()
664663

mcp/server.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -760,6 +760,7 @@ var serverMethodInfos = map[string]methodInfo{
760760
methodSetLevel: newServerMethodInfo(serverSessionMethod((*ServerSession).setLevel), 0),
761761
methodSubscribe: newServerMethodInfo(serverMethod((*Server).subscribe), 0),
762762
methodUnsubscribe: newServerMethodInfo(serverMethod((*Server).unsubscribe), 0),
763+
notificationCancelled: newServerMethodInfo(serverSessionMethod((*ServerSession).cancel), notification|missingParamsOK),
763764
notificationInitialized: newServerMethodInfo(serverSessionMethod((*ServerSession).initialized), notification|missingParamsOK),
764765
notificationRootsListChanged: newServerMethodInfo(serverMethod((*Server).callRootsListChangedHandler), notification|missingParamsOK),
765766
notificationProgress: newServerMethodInfo(serverSessionMethod((*ServerSession).callProgressNotificationHandler), notification),
@@ -838,6 +839,15 @@ func (ss *ServerSession) ping(context.Context, *PingParams) (*emptyResult, error
838839
return &emptyResult{}, nil
839840
}
840841

842+
// cancel is a placeholder: cancellation is handled the jsonrpc2 package.
843+
//
844+
// It should never be invoked in practice because cancellation is preempted,
845+
// but having its signature here facilitates the construction of methodInfo
846+
// that can be used to validate incoming cancellation notifications.
847+
func (ss *ServerSession) cancel(context.Context, *CancelledParams) (Result, error) {
848+
return nil, nil
849+
}
850+
841851
func (ss *ServerSession) setLevel(_ context.Context, params *SetLevelParams) (*emptyResult, error) {
842852
ss.updateState(func(state *ServerSessionState) {
843853
state.LogLevel = params.Level

mcp/streamable.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -118,12 +118,12 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
118118
return
119119
}
120120

121-
var session *StreamableServerTransport
121+
var transport *StreamableServerTransport
122122
if id := req.Header.Get(sessionIDHeader); id != "" {
123123
h.mu.Lock()
124-
session, _ = h.transports[id]
124+
transport, _ = h.transports[id]
125125
h.mu.Unlock()
126-
if session == nil {
126+
if transport == nil {
127127
http.Error(w, "session not found", http.StatusNotFound)
128128
return
129129
}
@@ -132,22 +132,22 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
132132
// TODO(rfindley): simplify the locking so that each request has only one
133133
// critical section.
134134
if req.Method == http.MethodDelete {
135-
if session == nil {
135+
if transport == nil {
136136
// => Mcp-Session-Id was not set; else we'd have returned NotFound above.
137137
http.Error(w, "DELETE requires an Mcp-Session-Id header", http.StatusBadRequest)
138138
return
139139
}
140140
h.mu.Lock()
141-
delete(h.transports, session.SessionID)
141+
delete(h.transports, transport.SessionID)
142142
h.mu.Unlock()
143-
session.connection.Close()
143+
transport.connection.Close()
144144
w.WriteHeader(http.StatusNoContent)
145145
return
146146
}
147147

148148
switch req.Method {
149149
case http.MethodPost, http.MethodGet:
150-
if req.Method == http.MethodGet && session == nil {
150+
if req.Method == http.MethodGet && transport == nil {
151151
http.Error(w, "GET requires an active session", http.StatusMethodNotAllowed)
152152
return
153153
}
@@ -157,7 +157,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
157157
return
158158
}
159159

160-
if session == nil {
160+
if transport == nil {
161161
server := h.getServer(req)
162162
if server == nil {
163163
// The getServer argument to NewStreamableHTTPHandler returned nil.
@@ -194,10 +194,10 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
194194
h.transports[s.SessionID] = s
195195
h.mu.Unlock()
196196
}
197-
session = s
197+
transport = s
198198
}
199199

200-
session.ServeHTTP(w, req)
200+
transport.ServeHTTP(w, req)
201201
}
202202

203203
// StreamableServerTransportOptions configures the stramable server transport.

mcp/streamable_test.go

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,26 @@ func TestStreamableTransports(t *testing.T) {
3737

3838
for _, useJSON := range []bool{false, true} {
3939
t.Run(fmt.Sprintf("JSONResponse=%v", useJSON), func(t *testing.T) {
40-
// 1. Create a server with a simple "greet" tool.
40+
// Create a server with some simple tools.
4141
server := NewServer(testImpl, nil)
4242
AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi)
43+
// The "hang" tool checks that context cancellation is propagated.
44+
// It hangs until the context is cancelled.
45+
var (
46+
start = make(chan struct{})
47+
cancelled = make(chan struct{}, 1) // don't block the request
48+
)
49+
hang := func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) {
50+
start <- struct{}{}
51+
select {
52+
case <-ctx.Done():
53+
cancelled <- struct{}{}
54+
case <-time.After(5 * time.Second):
55+
return nil, nil
56+
}
57+
return nil, nil
58+
}
59+
AddTool(server, &Tool{Name: "hang"}, hang)
4360
AddTool(server, &Tool{Name: "sample"}, func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) {
4461
// Test that we can make sampling requests during tool handling.
4562
//
@@ -60,7 +77,7 @@ func TestStreamableTransports(t *testing.T) {
6077
return &CallToolResultFor[any]{}, nil
6178
})
6279

63-
// 2. Start an httptest.Server with the StreamableHTTPHandler, wrapped in a
80+
// Start an httptest.Server with the StreamableHTTPHandler, wrapped in a
6481
// cookie-checking middleware.
6582
handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, &StreamableHTTPOptions{
6683
jsonResponse: useJSON,
@@ -84,7 +101,7 @@ func TestStreamableTransports(t *testing.T) {
84101
}))
85102
defer httpServer.Close()
86103

87-
// 3. Create a client and connect it to the server using our StreamableClientTransport.
104+
// Create a client and connect it to the server using our StreamableClientTransport.
88105
// Check that all requests honor a custom client.
89106
jar, err := cookiejar.New(nil)
90107
if err != nil {
@@ -117,10 +134,13 @@ func TestStreamableTransports(t *testing.T) {
117134
if g, w := session.mcpConn.(*streamableClientConn).initializedResult.ProtocolVersion, latestProtocolVersion; g != w {
118135
t.Fatalf("got protocol version %q, want %q", g, w)
119136
}
120-
// 4. The client calls the "greet" tool.
137+
138+
// Verify the behavior of various tools.
139+
140+
// The "greet" tool should just work.
121141
params := &CallToolParams{
122142
Name: "greet",
123-
Arguments: map[string]any{"name": "streamy"},
143+
Arguments: map[string]any{"name": "foo"},
124144
}
125145
got, err := session.CallTool(ctx, params)
126146
if err != nil {
@@ -132,19 +152,26 @@ func TestStreamableTransports(t *testing.T) {
132152
if g, w := lastHeader.Get(protocolVersionHeader), latestProtocolVersion; g != w {
133153
t.Errorf("got protocol version header %q, want %q", g, w)
134154
}
135-
136-
// 5. Verify that the correct response is received.
137155
want := &CallToolResult{
138-
Content: []Content{
139-
&TextContent{Text: "hi streamy"},
140-
},
156+
Content: []Content{&TextContent{Text: "hi foo"}},
141157
}
142158
if diff := cmp.Diff(want, got); diff != "" {
143159
t.Errorf("CallTool() returned unexpected content (-want +got):\n%s", diff)
144160
}
145161

146-
// 6. Run the "sampling" tool and verify that the streamable server can
147-
// call tools.
162+
// The "hang" tool should be cancellable.
163+
ctx2, cancel := context.WithCancel(context.Background())
164+
go session.CallTool(ctx2, &CallToolParams{Name: "hang"})
165+
<-start
166+
cancel()
167+
select {
168+
case <-cancelled:
169+
case <-time.After(5 * time.Second):
170+
t.Fatal("timeout waiting for cancellation")
171+
}
172+
173+
// The "sampling" tool should be able to issue sampling requests during
174+
// tool operation.
148175
result, err := session.CallTool(ctx, &CallToolParams{
149176
Name: "sample",
150177
Arguments: map[string]any{},

mcp/transport.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ type canceller struct {
171171

172172
// Preempt implements [jsonrpc2.Preempter].
173173
func (c *canceller) Preempt(ctx context.Context, req *jsonrpc.Request) (result any, err error) {
174-
if req.Method == "notifications/cancelled" {
174+
if req.Method == notificationCancelled {
175175
var params CancelledParams
176176
if err := json.Unmarshal(req.Params, &params); err != nil {
177177
return nil, err

0 commit comments

Comments
 (0)