Skip to content

Commit 5aab47f

Browse files
committed
mcp: make the streamable client transport less strict
In practice, the streamable client transport was having trouble connecting to various backends, because they are nonconformant in various ways. While it would be nice if all servers conformed to the spec, in practice there are certain spec violations that are recoverable, and we can and should recover them. Specifically: - tolerate 404 instead of 405 for the hanging GET (#393) - tolerate (=ignore) spurious response body for notifications and responses, since we know none are expected Additionally: - fix a bug that we weren't parsing Content-Type correctly. - update examples/client/listfeatures to accept a streamable endpoint Fixes #521
1 parent 5d64d61 commit 5aab47f

File tree

3 files changed

+170
-15
lines changed

3 files changed

+170
-15
lines changed

examples/client/listfeatures/main.go

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,30 +27,51 @@ import (
2727
"github.com/modelcontextprotocol/go-sdk/mcp"
2828
)
2929

30+
var (
31+
endpoint = flag.String("http", "", "if set, connect to this streamable endpoint rather than running a stdio server")
32+
)
33+
3034
func main() {
3135
flag.Parse()
3236
args := flag.Args()
33-
if len(args) == 0 {
37+
if len(args) == 0 && *endpoint == "" {
3438
fmt.Fprintln(os.Stderr, "Usage: listfeatures <command> [<args>]")
39+
fmt.Fprintln(os.Stderr, "Usage: listfeatures --http=\"https://example.com/server/mcp\"")
3540
fmt.Fprintln(os.Stderr, "List all features for a stdio MCP server")
3641
fmt.Fprintln(os.Stderr)
3742
fmt.Fprintln(os.Stderr, "Example:\n\tlistfeatures npx @modelcontextprotocol/server-everything")
3843
os.Exit(2)
3944
}
4045

41-
ctx := context.Background()
42-
cmd := exec.Command(args[0], args[1:]...)
46+
var (
47+
ctx = context.Background()
48+
transport mcp.Transport
49+
)
50+
if *endpoint != "" {
51+
transport = &mcp.StreamableClientTransport{
52+
Endpoint: *endpoint,
53+
}
54+
} else {
55+
cmd := exec.Command(args[0], args[1:]...)
56+
transport = &mcp.CommandTransport{Command: cmd}
57+
}
4358
client := mcp.NewClient(&mcp.Implementation{Name: "mcp-client", Version: "v1.0.0"}, nil)
44-
cs, err := client.Connect(ctx, &mcp.CommandTransport{Command: cmd}, nil)
59+
cs, err := client.Connect(ctx, transport, nil)
4560
if err != nil {
4661
log.Fatal(err)
4762
}
4863
defer cs.Close()
4964

50-
printSection("tools", cs.Tools(ctx, nil), func(t *mcp.Tool) string { return t.Name })
51-
printSection("resources", cs.Resources(ctx, nil), func(r *mcp.Resource) string { return r.Name })
52-
printSection("resource templates", cs.ResourceTemplates(ctx, nil), func(r *mcp.ResourceTemplate) string { return r.Name })
53-
printSection("prompts", cs.Prompts(ctx, nil), func(p *mcp.Prompt) string { return p.Name })
65+
if cs.InitializeResult().Capabilities.Tools != nil {
66+
printSection("tools", cs.Tools(ctx, nil), func(t *mcp.Tool) string { return t.Name })
67+
}
68+
if cs.InitializeResult().Capabilities.Resources != nil {
69+
printSection("resources", cs.Resources(ctx, nil), func(r *mcp.Resource) string { return r.Name })
70+
printSection("resource templates", cs.ResourceTemplates(ctx, nil), func(r *mcp.ResourceTemplate) string { return r.Name })
71+
}
72+
if cs.InitializeResult().Capabilities.Prompts != nil {
73+
printSection("prompts", cs.Prompts(ctx, nil), func(p *mcp.Prompt) string { return p.Name })
74+
}
5475
}
5576

5677
func printSection[T any](name string, features iter.Seq2[T, error], featName func(T) string) {

mcp/streamable.go

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"fmt"
1313
"io"
1414
"iter"
15+
"log/slog"
1516
"math"
1617
"math/rand/v2"
1718
"net/http"
@@ -981,6 +982,14 @@ type StreamableClientTransport struct {
981982
// MaxRetries is the maximum number of times to attempt a reconnect before giving up.
982983
// It defaults to 5. To disable retries, use a negative number.
983984
MaxRetries int
985+
986+
// TODO(rfindley): propose exporting these.
987+
// If strict is set, the transport is in 'strict mode', where any violation
988+
// of the MCP spec causes a failure.
989+
strict bool
990+
// If logger is set, it is used to log aspects of the transport, such as spec
991+
// violations that were ignored.
992+
logger *slog.Logger
984993
}
985994

986995
// These settings are not (yet) exposed to the user in
@@ -1025,6 +1034,8 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
10251034
incoming: make(chan jsonrpc.Message, 10),
10261035
done: make(chan struct{}),
10271036
maxRetries: maxRetries,
1037+
strict: t.strict,
1038+
logger: t.logger,
10281039
ctx: connCtx,
10291040
cancel: cancel,
10301041
failed: make(chan struct{}),
@@ -1039,6 +1050,8 @@ type streamableClientConn struct {
10391050
cancel context.CancelFunc
10401051
incoming chan jsonrpc.Message
10411052
maxRetries int
1053+
strict bool // from [StreamableClientTransport.strict]
1054+
logger *slog.Logger // from [StreamableClientTransport.logger]
10421055

10431056
// Guard calls to Close, as it may be called multiple times.
10441057
closeOnce sync.Once
@@ -1152,9 +1165,11 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
11521165
}
11531166

11541167
var requestSummary string
1168+
var isCall bool
11551169
switch msg := msg.(type) {
11561170
case *jsonrpc.Request:
11571171
requestSummary = fmt.Sprintf("sending %q", msg.Method)
1172+
isCall = msg.IsCall()
11581173
case *jsonrpc.Response:
11591174
requestSummary = fmt.Sprintf("sending jsonrpc response #%d", msg.ID)
11601175
default:
@@ -1209,11 +1224,24 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
12091224
}
12101225
}
12111226
if resp.StatusCode == http.StatusNoContent || resp.StatusCode == http.StatusAccepted {
1227+
// [§2.1.4]: "If the input is a JSON-RPC response or notification:
1228+
// If the server accepts the input, the server MUST return HTTP status code 202 Accepted with no body."
1229+
//
1230+
// [§2.1.4]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#listening-for-messages-from-the-server
1231+
resp.Body.Close()
1232+
return nil
1233+
} else if !isCall && !c.strict {
1234+
// Some servers return 200, even with an empty json body.
1235+
// Ignore this response in non-strict mode.
1236+
if c.logger != nil {
1237+
c.logger.Warn(fmt.Sprintf("unexpected status code %d from non-call", resp.StatusCode))
1238+
}
12121239
resp.Body.Close()
12131240
return nil
12141241
}
12151242

1216-
switch ct := resp.Header.Get("Content-Type"); ct {
1243+
contentType := strings.TrimSpace(strings.SplitN(resp.Header.Get("Content-Type"), ";", 2)[0])
1244+
switch contentType {
12171245
case "application/json":
12181246
go c.handleJSON(requestSummary, resp)
12191247

@@ -1223,7 +1251,7 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
12231251

12241252
default:
12251253
resp.Body.Close()
1226-
return fmt.Errorf("%s: unsupported content type %q", requestSummary, ct)
1254+
return fmt.Errorf("%s: unsupported content type %q", requestSummary, contentType)
12271255
}
12281256
return nil
12291257
}
@@ -1294,18 +1322,36 @@ func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *htt
12941322
newResp, err := c.reconnect(lastEventID)
12951323
if err != nil {
12961324
// All reconnection attempts failed: fail the connection.
1297-
c.fail(fmt.Errorf("%s: failed to reconnect: %v", requestSummary, err))
1325+
c.fail(fmt.Errorf("%s: failed to reconnect (session ID: %v): %v", requestSummary, c.sessionID, err))
12981326
return
12991327
}
13001328
resp = newResp
13011329
if resp.StatusCode == http.StatusMethodNotAllowed && persistent {
1330+
// [§2.2.3]: "The server MUST either return Content-Type:
1331+
// text/event-stream in response to this HTTP GET, or else return HTTP
1332+
// 405 Method Not Allowed, indicating that the server does not offer an
1333+
// SSE stream at this endpoint."
1334+
//
1335+
// [§2.2.3]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#listening-for-messages-from-the-server
1336+
13021337
// The server doesn't support the hanging GET.
13031338
resp.Body.Close()
13041339
return
13051340
}
1341+
if resp.StatusCode == http.StatusNotFound && persistent && !c.strict {
1342+
// modelcontextprotocol/gosdk#393: some servers return NotFound instead
1343+
// of MethodNotAllowed for the persistent GET.
1344+
//
1345+
// Treat this like MethodNotAllowed in non-strict mode.
1346+
if c.logger != nil {
1347+
c.logger.Warn("got 404 instead of 405 for hanging GET")
1348+
}
1349+
resp.Body.Close()
1350+
return
1351+
}
13061352
// (see equivalent handling in [streamableClientConn.Write]).
13071353
if resp.StatusCode == http.StatusNotFound {
1308-
c.fail(fmt.Errorf("%s: failed to reconnect: %w", requestSummary, errSessionMissing))
1354+
c.fail(fmt.Errorf("%s: failed to reconnect (session ID: %v): %w", requestSummary, c.sessionID, errSessionMissing))
13091355
return
13101356
}
13111357
if resp.StatusCode < 200 || resp.StatusCode >= 300 {

mcp/streamable_client_test.go

Lines changed: 91 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ func TestStreamableClientGETHandling(t *testing.T) {
256256
responses: fakeResponses{
257257
{"POST", "", methodInitialize}: {
258258
header: header{
259-
"Content-Type": "application/json",
259+
"Content-Type": "application/json; charset=utf-8", // should ignore the charset
260260
sessionIDHeader: "123",
261261
},
262262
body: jsonBody(t, initResp),
@@ -293,8 +293,8 @@ func TestStreamableClientGETHandling(t *testing.T) {
293293
t.Fatalf("client.Connect() failed: %v", err)
294294
}
295295

296-
// wait for all required requests to be handled, with exponential
297-
// backoff.
296+
// Since we need the client to observe the result of the hanging GET,
297+
// wait for all requests to be handled.
298298
start := time.Now()
299299
delay := 1 * time.Millisecond
300300
for range 10 {
@@ -323,3 +323,91 @@ func TestStreamableClientGETHandling(t *testing.T) {
323323
})
324324
}
325325
}
326+
327+
func TestStreamableClientStrictness(t *testing.T) {
328+
ctx := context.Background()
329+
330+
tests := []struct {
331+
label string
332+
strict bool
333+
initializedStatus int
334+
getStatus int
335+
wantConnectError bool
336+
wantListError bool
337+
}{
338+
{"conformant server", true, http.StatusAccepted, http.StatusMethodNotAllowed, false, false},
339+
{"strict initialized", true, http.StatusOK, http.StatusMethodNotAllowed, true, false},
340+
{"unstrict initialized", false, http.StatusOK, http.StatusMethodNotAllowed, false, false},
341+
{"strict GET", true, http.StatusAccepted, http.StatusNotFound, false, true},
342+
{"unstrict GET", false, http.StatusOK, http.StatusNotFound, false, false},
343+
}
344+
for _, test := range tests {
345+
t.Run(test.label, func(t *testing.T) {
346+
fake := &fakeStreamableServer{
347+
t: t,
348+
responses: fakeResponses{
349+
{"POST", "", methodInitialize}: {
350+
header: header{
351+
"Content-Type": "application/json",
352+
sessionIDHeader: "123",
353+
},
354+
body: jsonBody(t, initResp),
355+
},
356+
{"POST", "123", notificationInitialized}: {
357+
status: test.initializedStatus,
358+
wantProtocolVersion: latestProtocolVersion,
359+
},
360+
{"GET", "123", ""}: {
361+
header: header{
362+
"Content-Type": "text/event-stream",
363+
},
364+
status: test.getStatus,
365+
wantProtocolVersion: latestProtocolVersion,
366+
},
367+
{"POST", "123", methodListTools}: {
368+
header: header{
369+
"Content-Type": "application/json",
370+
sessionIDHeader: "123",
371+
},
372+
body: jsonBody(t, resp(2, &ListToolsResult{Tools: []*Tool{}}, nil)),
373+
optional: true,
374+
},
375+
{"DELETE", "123", ""}: {optional: true},
376+
},
377+
}
378+
httpServer := httptest.NewServer(fake)
379+
defer httpServer.Close()
380+
381+
transport := &StreamableClientTransport{Endpoint: httpServer.URL, strict: test.strict}
382+
client := NewClient(testImpl, nil)
383+
session, err := client.Connect(ctx, transport, nil)
384+
if (err != nil) != test.wantConnectError {
385+
t.Errorf("client.Connect() returned error %v; want error: %t", err, test.wantConnectError)
386+
}
387+
if err != nil {
388+
return
389+
}
390+
// Since we need the client to observe the result of the hanging GET,
391+
// wait for all requests to be handled.
392+
start := time.Now()
393+
delay := 1 * time.Millisecond
394+
for range 10 {
395+
if len(fake.missingRequests()) == 0 {
396+
break
397+
}
398+
time.Sleep(delay)
399+
delay *= 2
400+
}
401+
if missing := fake.missingRequests(); len(missing) > 0 {
402+
t.Errorf("did not receive expected requests after %s: %v", time.Since(start), missing)
403+
}
404+
_, err = session.ListTools(ctx, nil)
405+
if (err != nil) != test.wantListError {
406+
t.Errorf("ListTools returned error %v; want error: %t", err, test.wantListError)
407+
}
408+
if err := session.Close(); err != nil {
409+
t.Errorf("closing session: %v", err)
410+
}
411+
})
412+
}
413+
}

0 commit comments

Comments
 (0)