Skip to content

Commit 2c40bdc

Browse files
committed
mcp: systematically improve streamable client errors
The streamable client connection can break for a variety of reasons, asynchronously to the client's request. Decorate these failures with additional context to clarify why they occurred. Add a test for the failure message of #393. Fixes #393
1 parent a4313f9 commit 2c40bdc

File tree

3 files changed

+114
-21
lines changed

3 files changed

+114
-21
lines changed

mcp/streamable.go

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1130,7 +1130,7 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) {
11301130
// § 2.5: A server using the Streamable HTTP transport MAY assign a session
11311131
// ID at initialization time, by including it in an Mcp-Session-Id header
11321132
// on the HTTP response containing the InitializeResult.
1133-
go c.handleSSE(nil, true, nil)
1133+
go c.handleSSE("hanging GET", nil, true, nil)
11341134
}
11351135

11361136
// fail handles an asynchronous error while reading.
@@ -1224,24 +1224,27 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
12241224
return nil
12251225
}
12261226

1227+
var requestSummary string
1228+
switch msg := msg.(type) {
1229+
case *jsonrpc.Request:
1230+
requestSummary = fmt.Sprintf("sending %q", msg.Method)
1231+
case *jsonrpc.Response:
1232+
requestSummary = fmt.Sprintf("sending jsonrpc response #%d", msg.ID)
1233+
default:
1234+
panic("unreachable")
1235+
}
1236+
12271237
switch ct := resp.Header.Get("Content-Type"); ct {
12281238
case "application/json":
1229-
go c.handleJSON(resp)
1239+
go c.handleJSON(requestSummary, resp)
12301240

12311241
case "text/event-stream":
12321242
jsonReq, _ := msg.(*jsonrpc.Request)
1233-
go c.handleSSE(resp, false, jsonReq)
1243+
go c.handleSSE(requestSummary, resp, false, jsonReq)
12341244

12351245
default:
12361246
resp.Body.Close()
1237-
switch msg := msg.(type) {
1238-
case *jsonrpc.Request:
1239-
return fmt.Errorf("unsupported content type %q when sending %q (status: %d)", ct, msg.Method, resp.StatusCode)
1240-
case *jsonrpc.Response:
1241-
return fmt.Errorf("unsupported content type %q when sending jsonrpc response #%d (status: %d)", ct, msg.ID, resp.StatusCode)
1242-
default:
1243-
panic("unreachable")
1244-
}
1247+
return fmt.Errorf("%s: unsupported content type %q", requestSummary, ct)
12451248
}
12461249
return nil
12471250
}
@@ -1265,16 +1268,16 @@ func (c *streamableClientConn) setMCPHeaders(req *http.Request) {
12651268
}
12661269
}
12671270

1268-
func (c *streamableClientConn) handleJSON(resp *http.Response) {
1271+
func (c *streamableClientConn) handleJSON(requestSummary string, resp *http.Response) {
12691272
body, err := io.ReadAll(resp.Body)
12701273
resp.Body.Close()
12711274
if err != nil {
1272-
c.fail(err)
1275+
c.fail(fmt.Errorf("%s: failed to read body: %v", requestSummary, err))
12731276
return
12741277
}
12751278
msg, err := jsonrpc.DecodeMessage(body)
12761279
if err != nil {
1277-
c.fail(fmt.Errorf("failed to decode response: %v", err))
1280+
c.fail(fmt.Errorf("%s: failed to decode response: %v", requestSummary, err))
12781281
return
12791282
}
12801283
select {
@@ -1289,12 +1292,12 @@ func (c *streamableClientConn) handleJSON(resp *http.Response) {
12891292
//
12901293
// If forReq is set, it is the request that initiated the stream, and the
12911294
// stream is complete when we receive its response.
1292-
func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent bool, forReq *jsonrpc2.Request) {
1295+
func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *http.Response, persistent bool, forReq *jsonrpc2.Request) {
12931296
resp := initialResp
12941297
var lastEventID string
12951298
for {
12961299
if resp != nil {
1297-
eventID, clientClosed := c.processStream(resp, forReq)
1300+
eventID, clientClosed := c.processStream(requestSummary, resp, forReq)
12981301
lastEventID = eventID
12991302

13001303
// If the connection was closed by the client, we're done.
@@ -1312,7 +1315,7 @@ func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent
13121315
newResp, err := c.reconnect(lastEventID)
13131316
if err != nil {
13141317
// All reconnection attempts failed: fail the connection.
1315-
c.fail(err)
1318+
c.fail(fmt.Errorf("%s: failed to reconnect: %v", requestSummary, err))
13161319
return
13171320
}
13181321
resp = newResp
@@ -1323,7 +1326,7 @@ func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent
13231326
}
13241327
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
13251328
resp.Body.Close()
1326-
c.fail(fmt.Errorf("failed to reconnect: %v", http.StatusText(resp.StatusCode)))
1329+
c.fail(fmt.Errorf("%s: failed to reconnect: %v", requestSummary, http.StatusText(resp.StatusCode)))
13271330
return
13281331
}
13291332
// Reconnection was successful. Continue the loop with the new response.
@@ -1334,7 +1337,7 @@ func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent
13341337
// incoming channel. It returns the ID of the last processed event and a flag
13351338
// indicating if the connection was closed by the client. If resp is nil, it
13361339
// returns "", false.
1337-
func (c *streamableClientConn) processStream(resp *http.Response, forReq *jsonrpc.Request) (lastEventID string, clientClosed bool) {
1340+
func (c *streamableClientConn) processStream(requestSummary string, resp *http.Response, forReq *jsonrpc.Request) (lastEventID string, clientClosed bool) {
13381341
defer resp.Body.Close()
13391342
for evt, err := range scanEvents(resp.Body) {
13401343
if err != nil {
@@ -1347,7 +1350,7 @@ func (c *streamableClientConn) processStream(resp *http.Response, forReq *jsonrp
13471350

13481351
msg, err := jsonrpc.DecodeMessage(evt.Data)
13491352
if err != nil {
1350-
c.fail(fmt.Errorf("failed to decode event: %v", err))
1353+
c.fail(fmt.Errorf("%s: failed to decode event: %v", requestSummary, err))
13511354
return "", true
13521355
}
13531356

mcp/streamable_client_test.go

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,14 @@ package mcp
66

77
import (
88
"context"
9+
"fmt"
910
"io"
1011
"net/http"
1112
"net/http/httptest"
13+
"strings"
1214
"sync"
1315
"testing"
16+
"time"
1417

1518
"github.com/google/go-cmp/cmp"
1619
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
@@ -183,3 +186,90 @@ func TestStreamableClientTransportLifecycle(t *testing.T) {
183186
t.Errorf("mismatch (-want, +got):\n%s", diff)
184187
}
185188
}
189+
190+
func TestStreamableClientGETHandling(t *testing.T) {
191+
ctx := context.Background()
192+
193+
tests := []struct {
194+
status int
195+
wantErrorContaining string
196+
}{
197+
{http.StatusOK, ""},
198+
{http.StatusMethodNotAllowed, ""},
199+
{http.StatusBadRequest, "hanging GET"},
200+
}
201+
202+
for _, test := range tests {
203+
t.Run(fmt.Sprintf("status=%d", test.status), func(t *testing.T) {
204+
fake := &fakeStreamableServer{
205+
t: t,
206+
responses: fakeResponses{
207+
{"POST", "", methodInitialize}: {
208+
header: header{
209+
"Content-Type": "application/json",
210+
sessionIDHeader: "123",
211+
},
212+
body: jsonBody(t, initResp),
213+
},
214+
{"POST", "123", notificationInitialized}: {
215+
status: http.StatusAccepted,
216+
wantProtocolVersion: latestProtocolVersion,
217+
},
218+
{"GET", "123", ""}: {
219+
header: header{
220+
"Content-Type": "text/event-stream",
221+
},
222+
status: test.status,
223+
wantProtocolVersion: latestProtocolVersion,
224+
},
225+
{"POST", "123", methodListTools}: {
226+
header: header{
227+
"Content-Type": "application/json",
228+
sessionIDHeader: "123",
229+
},
230+
body: jsonBody(t, resp(2, &ListToolsResult{Tools: []*Tool{}}, nil)),
231+
optional: true,
232+
},
233+
{"DELETE", "123", ""}: {optional: true},
234+
},
235+
}
236+
httpServer := httptest.NewServer(fake)
237+
defer httpServer.Close()
238+
239+
transport := &StreamableClientTransport{Endpoint: httpServer.URL}
240+
client := NewClient(testImpl, nil)
241+
session, err := client.Connect(ctx, transport, nil)
242+
if err != nil {
243+
t.Fatalf("client.Connect() failed: %v", err)
244+
}
245+
246+
// wait for all required requests to be handled, with exponential
247+
// backoff.
248+
start := time.Now()
249+
delay := 1 * time.Millisecond
250+
for range 10 {
251+
if len(fake.missingRequests()) == 0 {
252+
break
253+
}
254+
time.Sleep(delay)
255+
delay *= 2
256+
}
257+
if missing := fake.missingRequests(); len(missing) > 0 {
258+
t.Errorf("did not receive expected requests after %s: %v", time.Since(start), missing)
259+
}
260+
261+
_, err = session.ListTools(ctx, nil)
262+
if (err != nil) != (test.wantErrorContaining != "") {
263+
t.Errorf("After initialization, got error %v, want %v", err, test.wantErrorContaining)
264+
} else if err != nil {
265+
if !strings.Contains(err.Error(), test.wantErrorContaining) {
266+
t.Errorf("After initialization, got error %s, want containing %q", err, test.wantErrorContaining)
267+
}
268+
}
269+
270+
if err := session.Close(); err != nil {
271+
t.Errorf("closing session: %v", err)
272+
}
273+
})
274+
}
275+
}

mcp/transport.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ func call(ctx context.Context, conn *jsonrpc2.Connection, method string, params
194194
err := call.Await(ctx, result)
195195
switch {
196196
case errors.Is(err, jsonrpc2.ErrClientClosing), errors.Is(err, jsonrpc2.ErrServerClosing):
197-
return fmt.Errorf("calling %q: %w", method, ErrConnectionClosed)
197+
return fmt.Errorf("%w: calling %q: %v", ErrConnectionClosed, method, err)
198198
case ctx.Err() != nil:
199199
// Notify the peer of cancellation.
200200
err := conn.Notify(xcontext.Detach(ctx), notificationCancelled, &CancelledParams{

0 commit comments

Comments
 (0)