Skip to content

Commit 728e0e3

Browse files
committed
mcp: improve error messages from Wait for streamable clients
Previously, the error message received from ClientSession.Wait would only report the closeErr, which would often be nil even if the client transport was broken. Wait should return the reason the session terminated, if abnormal. I'm not sure of the exact semantics of this, but surely returning nil is less useful than returning a meaningful non-nil error. We can refine our handling of errors once we have more feedback. Also add a test for client termination on HTTP server shutdown, described in #265. This should work as long as (1) the session is stateful (with a hanging GET), or (2) the session is stateless but the client has a keepalive ping. Also: don't send DELETE if the session was terminated with 404; +test. Fixes #265
1 parent b1c75f0 commit 728e0e3

File tree

4 files changed

+202
-21
lines changed

4 files changed

+202
-21
lines changed

internal/jsonrpc2/conn.go

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -483,10 +483,32 @@ func (c *Connection) Cancel(id ID) {
483483

484484
// Wait blocks until the connection is fully closed, but does not close it.
485485
func (c *Connection) Wait() error {
486+
return c.wait(true)
487+
}
488+
489+
// wait blocks until the connection closes, and reports the primary cause of its
490+
// termination, if abnormal.
491+
//
492+
// The fromWait argument allows this logic to be shared with Close, where we
493+
// only want to expose the closeErr.
494+
//
495+
// (Previously, Wait also only returned the closeErr, which was misleading if
496+
// the connection was broken for another reason).
497+
func (c *Connection) wait(fromWait bool) error {
486498
var err error
487499
<-c.done
488500
c.updateInFlight(func(s *inFlightState) {
489-
err = s.closeErr
501+
if fromWait {
502+
if !errors.Is(s.readErr, io.EOF) {
503+
err = s.readErr
504+
}
505+
if err == nil && !errors.Is(s.writeErr, io.EOF) {
506+
err = s.writeErr
507+
}
508+
}
509+
if err == nil {
510+
err = s.closeErr
511+
}
490512
})
491513
return err
492514
}
@@ -502,8 +524,7 @@ func (c *Connection) Close() error {
502524
// Stop handling new requests, and interrupt the reader (by closing the
503525
// connection) as soon as the active requests finish.
504526
c.updateInFlight(func(s *inFlightState) { s.connClosing = true })
505-
506-
return c.Wait()
527+
return c.wait(false)
507528
}
508529

509530
// readIncoming collects inbound messages from the reader and delivers them, either responding

mcp/streamable.go

Lines changed: 53 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1066,6 +1066,17 @@ type streamableClientConn struct {
10661066
sessionID string
10671067
}
10681068

1069+
// errSessionMissing indicates that the session is known to not be present on
1070+
// the server (see [streamableClientConn.fail]).
1071+
//
1072+
// TODO(rfindley): should we expose this error value (and its corresponding
1073+
// API) to the user?
1074+
//
1075+
// The spec says that if the server returns 404, clients should reestablish
1076+
// a session. For now, we delegate that to the user, but do they need a way to
1077+
// differentiate a 'NotFound' error from other errors?
1078+
var errSessionMissing = errors.New("session not found")
1079+
10691080
var _ clientConnection = (*streamableClientConn)(nil)
10701081

10711082
func (c *streamableClientConn) sessionUpdated(state clientSessionState) {
@@ -1093,6 +1104,10 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) {
10931104
//
10941105
// If err is non-nil, it is terminal, and subsequent (or pending) Reads will
10951106
// fail.
1107+
//
1108+
// If err wraps errSessionMissing, the failure indicates that the session is no
1109+
// longer present on the server, and no final DELETE will be performed when
1110+
// closing the connection.
10961111
func (c *streamableClientConn) fail(err error) {
10971112
if err != nil {
10981113
c.failOnce.Do(func() {
@@ -1140,9 +1155,19 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
11401155
return err
11411156
}
11421157

1158+
var requestSummary string
1159+
switch msg := msg.(type) {
1160+
case *jsonrpc.Request:
1161+
requestSummary = fmt.Sprintf("sending %q", msg.Method)
1162+
case *jsonrpc.Response:
1163+
requestSummary = fmt.Sprintf("sending jsonrpc response #%d", msg.ID)
1164+
default:
1165+
panic("unreachable")
1166+
}
1167+
11431168
data, err := jsonrpc.EncodeMessage(msg)
11441169
if err != nil {
1145-
return err
1170+
return fmt.Errorf("%s: %v", requestSummary, err)
11461171
}
11471172

11481173
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(data))
@@ -1155,9 +1180,21 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
11551180

11561181
resp, err := c.client.Do(req)
11571182
if err != nil {
1183+
return fmt.Errorf("%s: %v", requestSummary, err)
1184+
}
1185+
1186+
// Section 2.5.3: "The server MAY terminate the session at any time, after
1187+
// which it MUST respond to requests containing that session ID with HTTP
1188+
// 404 Not Found."
1189+
if resp.StatusCode == http.StatusNotFound {
1190+
// Fail the session immediately, rather than relying on jsonrpc2 to fail
1191+
// (and close) it, because we want the call to Close to know that this
1192+
// session is missing (and therefore not send the DELETE).
1193+
err := fmt.Errorf("%s: failed to send: %w", requestSummary, errSessionMissing)
1194+
c.fail(err)
1195+
resp.Body.Close()
11581196
return err
11591197
}
1160-
11611198
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
11621199
resp.Body.Close()
11631200
return fmt.Errorf("broken session: %v", resp.Status)
@@ -1180,16 +1217,6 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
11801217
return nil
11811218
}
11821219

1183-
var requestSummary string
1184-
switch msg := msg.(type) {
1185-
case *jsonrpc.Request:
1186-
requestSummary = fmt.Sprintf("sending %q", msg.Method)
1187-
case *jsonrpc.Response:
1188-
requestSummary = fmt.Sprintf("sending jsonrpc response #%d", msg.ID)
1189-
default:
1190-
panic("unreachable")
1191-
}
1192-
11931220
switch ct := resp.Header.Get("Content-Type"); ct {
11941221
case "application/json":
11951222
go c.handleJSON(requestSummary, resp)
@@ -1280,6 +1307,11 @@ func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *htt
12801307
resp.Body.Close()
12811308
return
12821309
}
1310+
// (see equivalent handling in [streamableClientConn.Write]).
1311+
if resp.StatusCode == http.StatusNotFound {
1312+
c.fail(fmt.Errorf("%s: failed to reconnect: %w", requestSummary, errSessionMissing))
1313+
return
1314+
}
12831315
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
12841316
resp.Body.Close()
12851317
c.fail(fmt.Errorf("%s: failed to reconnect: %v", requestSummary, http.StatusText(resp.StatusCode)))
@@ -1370,13 +1402,17 @@ func (c *streamableClientConn) Close() error {
13701402
c.cancel()
13711403
close(c.done)
13721404

1373-
req, err := http.NewRequest(http.MethodDelete, c.url, nil)
1374-
if err != nil {
1375-
c.closeErr = err
1405+
if errors.Is(c.failure(), errSessionMissing) {
1406+
// If the session is missing, no need to delete it.
13761407
} else {
1377-
c.setMCPHeaders(req)
1378-
if _, err := c.client.Do(req); err != nil {
1408+
req, err := http.NewRequest(http.MethodDelete, c.url, nil)
1409+
if err != nil {
13791410
c.closeErr = err
1411+
} else {
1412+
c.setMCPHeaders(req)
1413+
if _, err := c.client.Do(req); err != nil {
1414+
c.closeErr = err
1415+
}
13801416
}
13811417
}
13821418
})

mcp/streamable_client_test.go

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ type streamableRequestKey struct {
2929
type header map[string]string
3030

3131
type streamableResponse struct {
32-
header header
32+
header header // response headers
3333
status int // or http.StatusOK
3434
body string // or ""
3535
optional bool // if set, request need not be sent
@@ -187,6 +187,56 @@ func TestStreamableClientTransportLifecycle(t *testing.T) {
187187
}
188188
}
189189

190+
func TestStreamableClientRedundantDelete(t *testing.T) {
191+
ctx := context.Background()
192+
193+
// This test verifies that the client does not send a redundant DELETE once
194+
// the server has reported the session as missing:
195+
// - the ListTools POST below is answered with 404 Not Found
196+
// - no DELETE response is registered with the fake server
197+
fake := &fakeStreamableServer{
198+
t: t,
199+
responses: fakeResponses{
200+
{"POST", "", methodInitialize}: {
201+
header: header{
202+
"Content-Type": "application/json",
203+
sessionIDHeader: "123",
204+
},
205+
body: jsonBody(t, initResp),
206+
},
207+
{"POST", "123", notificationInitialized}: {
208+
status: http.StatusAccepted,
209+
wantProtocolVersion: latestProtocolVersion,
210+
},
211+
{"GET", "123", ""}: {
212+
status: http.StatusMethodNotAllowed,
213+
optional: true,
214+
},
215+
{"POST", "123", methodListTools}: {
216+
status: http.StatusNotFound,
217+
},
218+
},
219+
}
220+
221+
httpServer := httptest.NewServer(fake)
222+
defer httpServer.Close()
223+
224+
transport := &StreamableClientTransport{Endpoint: httpServer.URL}
225+
client := NewClient(testImpl, nil)
226+
session, err := client.Connect(ctx, transport, nil)
227+
if err != nil {
228+
t.Fatalf("client.Connect() failed: %v", err)
229+
}
230+
_, err = session.ListTools(ctx, nil)
231+
if err == nil {
232+
t.Errorf("Listing tools: got nil error, want non-nil")
233+
}
234+
_ = session.Wait() // must not hang
235+
if missing := fake.missingRequests(); len(missing) > 0 {
236+
t.Errorf("did not receive expected requests: %v", missing)
237+
}
238+
}
239+
190240
func TestStreamableClientGETHandling(t *testing.T) {
191241
ctx := context.Background()
192242

mcp/streamable_test.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,80 @@ func TestStreamableTransports(t *testing.T) {
190190
}
191191
}
192192

193+
func TestStreamableServerShutdown(t *testing.T) {
194+
ctx := context.Background()
195+
196+
// This test checks that closing the streamable HTTP server actually results
197+
// in client session termination, provided one of the following holds:
198+
// 1. The server is stateful, and therefore the hanging GET fails the connection.
199+
// 2. The server is stateless, and the client uses a KeepAlive.
200+
tests := []struct {
201+
name string
202+
stateless, keepalive bool
203+
}{
204+
{"stateful", false, false},
205+
{"stateless with keepalive", true, true},
206+
}
207+
208+
for _, test := range tests {
209+
t.Run(test.name, func(t *testing.T) {
210+
server := NewServer(testImpl, nil)
211+
// Add a tool, just so we can check things are working.
212+
AddTool(server, &Tool{Name: "greet"}, sayHi)
213+
214+
handler := NewStreamableHTTPHandler(
215+
func(req *http.Request) *Server { return server },
216+
&StreamableHTTPOptions{Stateless: test.stateless})
217+
218+
// When we shut down the server, we need to explicitly close ongoing
219+
// connections. Otherwise, the hanging GET may never terminate.
220+
httpServer := httptest.NewUnstartedServer(handler)
221+
httpServer.Config.RegisterOnShutdown(func() {
222+
for session := range server.Sessions() {
223+
session.Close()
224+
}
225+
})
226+
httpServer.Start()
227+
defer httpServer.Close()
228+
229+
// Connect and run a tool.
230+
var opts ClientOptions
231+
if test.keepalive {
232+
opts.KeepAlive = 50 * time.Millisecond
233+
}
234+
client := NewClient(testImpl, &opts)
235+
clientSession, err := client.Connect(ctx, &StreamableClientTransport{
236+
Endpoint: httpServer.URL,
237+
MaxRetries: -1, // avoid slow tests during exponential retries
238+
}, nil)
239+
if err != nil {
240+
t.Fatal(err)
241+
}
242+
243+
params := &CallToolParams{
244+
Name: "greet",
245+
Arguments: map[string]any{"name": "foo"},
246+
}
247+
// Verify that we can call a tool.
248+
if _, err := clientSession.CallTool(ctx, params); err != nil {
249+
t.Fatalf("CallTool() failed: %v", err)
250+
}
251+
252+
// Shut down the server. Sessions should terminate.
253+
go func() {
254+
if err := httpServer.Config.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) {
255+
t.Errorf("closing http server: %v", err)
256+
}
257+
}()
258+
259+
// Wait may return an error (after all, the connection failed), but it
260+
// should not hang.
261+
t.Log("Client waiting")
262+
_ = clientSession.Wait()
263+
})
264+
}
265+
}
266+
193267
// TestClientReplay verifies that the client can recover from a mid-stream
194268
// network failure and receive replayed messages (if replay is configured). It
195269
// uses a proxy that is killed and restarted to simulate a recoverable network

0 commit comments

Comments
 (0)