Skip to content

Commit a75c670

Browse files
committed
mcp: improve error messages from Wait for streamable clients
Previously, the error message received from ClientSession.Wait would only report the closeErr, which would often be nil even if the client transport was broken. Wait should return the reason the session terminated, if abnormal. I'm not sure of the exact semantics of this, but surely returning nil is less useful than returning a meaningful non-nil error. We can refine our handling of errors once we have more feedback. Also add a test for client termination on HTTP server shutdown, described in #265. This should work as long as (1) the session is stateful (with a hanging GET), or (2) the session is stateless but the client has a keepalive ping. Also: don't send DELETE if the session was terminated with 404; +test. Fixes #265
1 parent 4f197bc commit a75c670

File tree

4 files changed

+202
-21
lines changed

4 files changed

+202
-21
lines changed

internal/jsonrpc2/conn.go

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -483,10 +483,32 @@ func (c *Connection) Cancel(id ID) {
483483

484484
// Wait blocks until the connection is fully closed, but does not close it.
485485
func (c *Connection) Wait() error {
486+
return c.wait(true)
487+
}
488+
489+
// wait for the connection to close, and aggregates the most cause of its
490+
// termination, if abnormal.
491+
//
492+
// The fromWait argument allows this logic to be shared with Close, where we
493+
// only want to expose the closeErr.
494+
//
495+
// (Previously, Wait also only returned the closeErr, which was misleading if
496+
// the connection was broken for another reason).
497+
func (c *Connection) wait(fromWait bool) error {
486498
var err error
487499
<-c.done
488500
c.updateInFlight(func(s *inFlightState) {
489-
err = s.closeErr
501+
if fromWait {
502+
if !errors.Is(s.readErr, io.EOF) {
503+
err = s.readErr
504+
}
505+
if err == nil && !errors.Is(s.writeErr, io.EOF) {
506+
err = s.writeErr
507+
}
508+
}
509+
if err == nil {
510+
err = s.closeErr
511+
}
490512
})
491513
return err
492514
}
@@ -502,8 +524,7 @@ func (c *Connection) Close() error {
502524
// Stop handling new requests, and interrupt the reader (by closing the
503525
// connection) as soon as the active requests finish.
504526
c.updateInFlight(func(s *inFlightState) { s.connClosing = true })
505-
506-
return c.Wait()
527+
return c.wait(false)
507528
}
508529

509530
// readIncoming collects inbound messages from the reader and delivers them, either responding

mcp/streamable.go

Lines changed: 53 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1119,6 +1119,17 @@ type streamableClientConn struct {
11191119
sessionID string
11201120
}
11211121

1122+
// errSessionMissing distinguishes if the session is known to not be present on
1123+
// the server (see [streamableClientConn.fail]).
1124+
//
1125+
// TODO(rfindley): should we expose this error value (and its corresponding
1126+
// API) to the user?
1127+
//
1128+
// The spec says that if the server returns 404, clients should reestablish
1129+
// a session. For now, we delegate that to the user, but do they need a way to
1130+
// differentiate a 'NotFound' error from other errors?
1131+
var errSessionMissing = errors.New("session not found")
1132+
11221133
var _ clientConnection = (*streamableClientConn)(nil)
11231134

11241135
func (c *streamableClientConn) sessionUpdated(state clientSessionState) {
@@ -1146,6 +1157,10 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) {
11461157
//
11471158
// If err is non-nil, it is terminal, and subsequent (or pending) Reads will
11481159
// fail.
1160+
//
1161+
// If err wraps errSessionMissing, the failure indicates that the session is no
1162+
// longer present on the server, and no final DELETE will be performed when
1163+
// closing the connection.
11491164
func (c *streamableClientConn) fail(err error) {
11501165
if err != nil {
11511166
c.failOnce.Do(func() {
@@ -1193,9 +1208,19 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
11931208
return err
11941209
}
11951210

1211+
var requestSummary string
1212+
switch msg := msg.(type) {
1213+
case *jsonrpc.Request:
1214+
requestSummary = fmt.Sprintf("sending %q", msg.Method)
1215+
case *jsonrpc.Response:
1216+
requestSummary = fmt.Sprintf("sending jsonrpc response #%d", msg.ID)
1217+
default:
1218+
panic("unreachable")
1219+
}
1220+
11961221
data, err := jsonrpc.EncodeMessage(msg)
11971222
if err != nil {
1198-
return err
1223+
return fmt.Errorf("%s: %v", requestSummary, err)
11991224
}
12001225

12011226
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(data))
@@ -1208,9 +1233,21 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
12081233

12091234
resp, err := c.client.Do(req)
12101235
if err != nil {
1236+
return fmt.Errorf("%s: %v", requestSummary, err)
1237+
}
1238+
1239+
// Section 2.5.3: "The server MAY terminate the session at any time, after
1240+
// which it MUST respond to requests containing that session ID with HTTP
1241+
// 404 Not Found."
1242+
if resp.StatusCode == http.StatusNotFound {
1243+
// Fail the session immediately, rather than relying on jsonrpc2 to fail
1244+
// (and close) it, because we want the call to Close to know that this
1245+
// session is missing (and therefore not send the DELETE).
1246+
err := fmt.Errorf("%s: failed to send: %w", requestSummary, errSessionMissing)
1247+
c.fail(err)
1248+
resp.Body.Close()
12111249
return err
12121250
}
1213-
12141251
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
12151252
resp.Body.Close()
12161253
return fmt.Errorf("broken session: %v", resp.Status)
@@ -1233,16 +1270,6 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
12331270
return nil
12341271
}
12351272

1236-
var requestSummary string
1237-
switch msg := msg.(type) {
1238-
case *jsonrpc.Request:
1239-
requestSummary = fmt.Sprintf("sending %q", msg.Method)
1240-
case *jsonrpc.Response:
1241-
requestSummary = fmt.Sprintf("sending jsonrpc response #%d", msg.ID)
1242-
default:
1243-
panic("unreachable")
1244-
}
1245-
12461273
switch ct := resp.Header.Get("Content-Type"); ct {
12471274
case "application/json":
12481275
go c.handleJSON(requestSummary, resp)
@@ -1333,6 +1360,11 @@ func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *htt
13331360
resp.Body.Close()
13341361
return
13351362
}
1363+
// (see equivalent handling in [streamableClientConn.Write]).
1364+
if resp.StatusCode == http.StatusNotFound {
1365+
c.fail(fmt.Errorf("%s: failed to reconnect: %w", requestSummary, errSessionMissing))
1366+
return
1367+
}
13361368
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
13371369
resp.Body.Close()
13381370
c.fail(fmt.Errorf("%s: failed to reconnect: %v", requestSummary, http.StatusText(resp.StatusCode)))
@@ -1423,13 +1455,17 @@ func (c *streamableClientConn) Close() error {
14231455
c.cancel()
14241456
close(c.done)
14251457

1426-
req, err := http.NewRequest(http.MethodDelete, c.url, nil)
1427-
if err != nil {
1428-
c.closeErr = err
1458+
if errors.Is(c.failure(), errSessionMissing) {
1459+
// If the session is missing, no need to delete it.
14291460
} else {
1430-
c.setMCPHeaders(req)
1431-
if _, err := c.client.Do(req); err != nil {
1461+
req, err := http.NewRequest(http.MethodDelete, c.url, nil)
1462+
if err != nil {
14321463
c.closeErr = err
1464+
} else {
1465+
c.setMCPHeaders(req)
1466+
if _, err := c.client.Do(req); err != nil {
1467+
c.closeErr = err
1468+
}
14331469
}
14341470
}
14351471
})

mcp/streamable_client_test.go

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ type streamableRequestKey struct {
2929
type header map[string]string
3030

3131
type streamableResponse struct {
32-
header header
32+
header header // response headers
3333
status int // or http.StatusOK
3434
body string // or ""
3535
optional bool // if set, request need not be sent
@@ -187,6 +187,56 @@ func TestStreamableClientTransportLifecycle(t *testing.T) {
187187
}
188188
}
189189

190+
func TestStreamableClientRedundantDelete(t *testing.T) {
191+
ctx := context.Background()
192+
193+
// The lifecycle test verifies various behavior of the streamable client
194+
// initialization:
195+
// - check that it can handle application/json responses
196+
// - check that it sends the negotiated protocol version
197+
fake := &fakeStreamableServer{
198+
t: t,
199+
responses: fakeResponses{
200+
{"POST", "", methodInitialize}: {
201+
header: header{
202+
"Content-Type": "application/json",
203+
sessionIDHeader: "123",
204+
},
205+
body: jsonBody(t, initResp),
206+
},
207+
{"POST", "123", notificationInitialized}: {
208+
status: http.StatusAccepted,
209+
wantProtocolVersion: latestProtocolVersion,
210+
},
211+
{"GET", "123", ""}: {
212+
status: http.StatusMethodNotAllowed,
213+
optional: true,
214+
},
215+
{"POST", "123", methodListTools}: {
216+
status: http.StatusNotFound,
217+
},
218+
},
219+
}
220+
221+
httpServer := httptest.NewServer(fake)
222+
defer httpServer.Close()
223+
224+
transport := &StreamableClientTransport{Endpoint: httpServer.URL}
225+
client := NewClient(testImpl, nil)
226+
session, err := client.Connect(ctx, transport, nil)
227+
if err != nil {
228+
t.Fatalf("client.Connect() failed: %v", err)
229+
}
230+
_, err = session.ListTools(ctx, nil)
231+
if err == nil {
232+
t.Errorf("Listing tools: got nil error, want non-nil")
233+
}
234+
_ = session.Wait() // must not hang
235+
if missing := fake.missingRequests(); len(missing) > 0 {
236+
t.Errorf("did not receive expected requests: %v", missing)
237+
}
238+
}
239+
190240
func TestStreamableClientGETHandling(t *testing.T) {
191241
ctx := context.Background()
192242

mcp/streamable_test.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,80 @@ func TestStreamableTransports(t *testing.T) {
190190
}
191191
}
192192

193+
func TestStreamableServerShutdown(t *testing.T) {
194+
ctx := context.Background()
195+
196+
// This test checks that closing the streamable HTTP server actually results
197+
// in client session termination, provided one of following holds:
198+
// 1. The server is stateful, and therefore the hanging GET fails the connection.
199+
// 2. The server is stateless, and the client uses a KeepAlive.
200+
tests := []struct {
201+
name string
202+
stateless, keepalive bool
203+
}{
204+
{"stateful", false, false},
205+
{"stateless with keepalive", true, true},
206+
}
207+
208+
for _, test := range tests {
209+
t.Run(test.name, func(t *testing.T) {
210+
server := NewServer(testImpl, nil)
211+
// Add a tool, just so we can check things are working.
212+
AddTool(server, &Tool{Name: "greet"}, sayHi)
213+
214+
handler := NewStreamableHTTPHandler(
215+
func(req *http.Request) *Server { return server },
216+
&StreamableHTTPOptions{Stateless: test.stateless})
217+
218+
// When we shut down the server, we need to explicitly close ongoing
219+
// connections. Otherwise, the hanging GET may never terminate.
220+
httpServer := httptest.NewUnstartedServer(handler)
221+
httpServer.Config.RegisterOnShutdown(func() {
222+
for session := range server.Sessions() {
223+
session.Close()
224+
}
225+
})
226+
httpServer.Start()
227+
defer httpServer.Close()
228+
229+
// Connect and run a tool.
230+
var opts ClientOptions
231+
if test.keepalive {
232+
opts.KeepAlive = 50 * time.Millisecond
233+
}
234+
client := NewClient(testImpl, &opts)
235+
clientSession, err := client.Connect(ctx, &StreamableClientTransport{
236+
Endpoint: httpServer.URL,
237+
MaxRetries: -1, // avoid slow tests during exponential retries
238+
}, nil)
239+
if err != nil {
240+
t.Fatal(err)
241+
}
242+
243+
params := &CallToolParams{
244+
Name: "greet",
245+
Arguments: map[string]any{"name": "foo"},
246+
}
247+
// Verify that we can call a tool.
248+
if _, err := clientSession.CallTool(ctx, params); err != nil {
249+
t.Fatalf("CallTool() failed: %v", err)
250+
}
251+
252+
// Shut down the server. Sessions should terminate.
253+
go func() {
254+
if err := httpServer.Config.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) {
255+
t.Errorf("closing http server: %v", err)
256+
}
257+
}()
258+
259+
// Wait may return an error (after all, the connection failed), but it
260+
// should not hang.
261+
t.Log("Client waiting")
262+
_ = clientSession.Wait()
263+
})
264+
}
265+
}
266+
193267
// TestClientReplay verifies that the client can recover from a mid-stream
194268
// network failure and receive replayed messages (if replay is configured). It
195269
// uses a proxy that is killed and restarted to simulate a recoverable network

0 commit comments

Comments
 (0)