diff --git a/examples/client/listfeatures/main.go b/examples/client/listfeatures/main.go index 9d473f0b..044b5e99 100644 --- a/examples/client/listfeatures/main.go +++ b/examples/client/listfeatures/main.go @@ -27,30 +27,51 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) +var ( + endpoint = flag.String("http", "", "if set, connect to this streamable endpoint rather than running a stdio server") +) + func main() { flag.Parse() args := flag.Args() - if len(args) == 0 { + if len(args) == 0 && *endpoint == "" { fmt.Fprintln(os.Stderr, "Usage: listfeatures []") + fmt.Fprintln(os.Stderr, "Usage: listfeatures --http=\"https://example.com/server/mcp\"") fmt.Fprintln(os.Stderr, "List all features for a stdio MCP server") fmt.Fprintln(os.Stderr) fmt.Fprintln(os.Stderr, "Example:\n\tlistfeatures npx @modelcontextprotocol/server-everything") os.Exit(2) } - ctx := context.Background() - cmd := exec.Command(args[0], args[1:]...) + var ( + ctx = context.Background() + transport mcp.Transport + ) + if *endpoint != "" { + transport = &mcp.StreamableClientTransport{ + Endpoint: *endpoint, + } + } else { + cmd := exec.Command(args[0], args[1:]...) + transport = &mcp.CommandTransport{Command: cmd} + } client := mcp.NewClient(&mcp.Implementation{Name: "mcp-client", Version: "v1.0.0"}, nil) - cs, err := client.Connect(ctx, &mcp.CommandTransport{Command: cmd}, nil) + cs, err := client.Connect(ctx, transport, nil) if err != nil { log.Fatal(err) } defer cs.Close() - printSection("tools", cs.Tools(ctx, nil), func(t *mcp.Tool) string { return t.Name }) - printSection("resources", cs.Resources(ctx, nil), func(r *mcp.Resource) string { return r.Name }) - printSection("resource templates", cs.ResourceTemplates(ctx, nil), func(r *mcp.ResourceTemplate) string { return r.Name }) - printSection("prompts", cs.Prompts(ctx, nil), func(p *mcp.Prompt) string { return p.Name }) + if cs.InitializeResult().Capabilities.Tools != nil { + printSection("tools", cs.Tools(ctx, nil), func(t *mcp.Tool) string { return t.Name }) + } + if cs.InitializeResult().Capabilities.Resources != nil { + printSection("resources", cs.Resources(ctx, nil), func(r *mcp.Resource) string { return r.Name }) + printSection("resource templates", cs.ResourceTemplates(ctx, nil), func(r *mcp.ResourceTemplate) string { return r.Name }) + } + if cs.InitializeResult().Capabilities.Prompts != nil { + printSection("prompts", cs.Prompts(ctx, nil), func(p *mcp.Prompt) string { return p.Name }) + } } func printSection[T any](name string, features iter.Seq2[T, error], featName func(T) string) { diff --git a/mcp/streamable.go b/mcp/streamable.go index 8072a637..a96386d9 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -12,6 +12,7 @@ import ( "fmt" "io" "iter" + "log/slog" "math" "math/rand/v2" "net/http" @@ -981,6 +982,14 @@ type StreamableClientTransport struct { // MaxRetries is the maximum number of times to attempt a reconnect before giving up. // It defaults to 5. To disable retries, use a negative number. MaxRetries int + + // TODO(rfindley): propose exporting these. + // If strict is set, the transport is in 'strict mode', where any violation + // of the MCP spec causes a failure. + strict bool + // If logger is set, it is used to log aspects of the transport, such as spec + // violations that were ignored. + logger *slog.Logger } // These settings are not (yet) exposed to the user in @@ -1025,6 +1034,8 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er incoming: make(chan jsonrpc.Message, 10), done: make(chan struct{}), maxRetries: maxRetries, + strict: t.strict, + logger: t.logger, ctx: connCtx, cancel: cancel, failed: make(chan struct{}), @@ -1039,6 +1050,8 @@ type streamableClientConn struct { cancel context.CancelFunc incoming chan jsonrpc.Message maxRetries int + strict bool // from [StreamableClientTransport.strict] + logger *slog.Logger // from [StreamableClientTransport.logger] // Guard calls to Close, as it may be called multiple times. closeOnce sync.Once @@ -1152,9 +1165,11 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e } var requestSummary string + var isCall bool switch msg := msg.(type) { case *jsonrpc.Request: requestSummary = fmt.Sprintf("sending %q", msg.Method) + isCall = msg.IsCall() case *jsonrpc.Response: requestSummary = fmt.Sprintf("sending jsonrpc response #%d", msg.ID) default: @@ -1209,11 +1224,24 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e } } if resp.StatusCode == http.StatusNoContent || resp.StatusCode == http.StatusAccepted { + // [§2.1.4]: "If the input is a JSON-RPC response or notification: + // If the server accepts the input, the server MUST return HTTP status code 202 Accepted with no body." + // + // [§2.1.4]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#listening-for-messages-from-the-server + resp.Body.Close() + return nil + } else if !isCall && !c.strict { + // Some servers return 200, even with an empty json body. + // Ignore this response in non-strict mode. + if c.logger != nil { + c.logger.Warn(fmt.Sprintf("unexpected status code %d from non-call", resp.StatusCode)) + } resp.Body.Close() return nil } - switch ct := resp.Header.Get("Content-Type"); ct { + contentType := strings.TrimSpace(strings.SplitN(resp.Header.Get("Content-Type"), ";", 2)[0]) + switch contentType { case "application/json": go c.handleJSON(requestSummary, resp) @@ -1223,7 +1251,7 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e default: resp.Body.Close() - return fmt.Errorf("%s: unsupported content type %q", requestSummary, ct) + return fmt.Errorf("%s: unsupported content type %q", requestSummary, contentType) } return nil } @@ -1294,18 +1322,36 @@ func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *htt newResp, err := c.reconnect(lastEventID) if err != nil { // All reconnection attempts failed: fail the connection. - c.fail(fmt.Errorf("%s: failed to reconnect: %v", requestSummary, err)) + c.fail(fmt.Errorf("%s: failed to reconnect (session ID: %v): %v", requestSummary, c.sessionID, err)) return } resp = newResp if resp.StatusCode == http.StatusMethodNotAllowed && persistent { + // [§2.2.3]: "The server MUST either return Content-Type: + // text/event-stream in response to this HTTP GET, or else return HTTP + // 405 Method Not Allowed, indicating that the server does not offer an + // SSE stream at this endpoint." + // + // [§2.2.3]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#listening-for-messages-from-the-server + // The server doesn't support the hanging GET. resp.Body.Close() return } + if resp.StatusCode == http.StatusNotFound && persistent && !c.strict { + // modelcontextprotocol/gosdk#393: some servers return NotFound instead + // of MethodNotAllowed for the persistent GET. + // + // Treat this like MethodNotAllowed in non-strict mode. + if c.logger != nil { + c.logger.Warn("got 404 instead of 405 for hanging GET") + } + resp.Body.Close() + return + } // (see equivalent handling in [streamableClientConn.Write]). if resp.StatusCode == http.StatusNotFound { - c.fail(fmt.Errorf("%s: failed to reconnect: %w", requestSummary, errSessionMissing)) + c.fail(fmt.Errorf("%s: failed to reconnect (session ID: %v): %w", requestSummary, c.sessionID, errSessionMissing)) return } if resp.StatusCode < 200 || resp.StatusCode >= 300 { diff --git a/mcp/streamable_client_test.go b/mcp/streamable_client_test.go index 001d3a64..45c4a2e6 100644 --- a/mcp/streamable_client_test.go +++ b/mcp/streamable_client_test.go @@ -256,7 +256,7 @@ func TestStreamableClientGETHandling(t *testing.T) { responses: fakeResponses{ {"POST", "", methodInitialize}: { header: header{ - "Content-Type": "application/json", + "Content-Type": "application/json; charset=utf-8", // should ignore the charset sessionIDHeader: "123", }, body: jsonBody(t, initResp), @@ -293,8 +293,8 @@ func TestStreamableClientGETHandling(t *testing.T) { t.Fatalf("client.Connect() failed: %v", err) } - // wait for all required requests to be handled, with exponential - // backoff. + // Since we need the client to observe the result of the hanging GET, + // wait for all requests to be handled. start := time.Now() delay := 1 * time.Millisecond for range 10 { @@ -323,3 +323,91 @@ func TestStreamableClientGETHandling(t *testing.T) { }) } } + +func TestStreamableClientStrictness(t *testing.T) { + ctx := context.Background() + + tests := []struct { + label string + strict bool + initializedStatus int + getStatus int + wantConnectError bool + wantListError bool + }{ + {"conformant server", true, http.StatusAccepted, http.StatusMethodNotAllowed, false, false}, + {"strict initialized", true, http.StatusOK, http.StatusMethodNotAllowed, true, false}, + {"unstrict initialized", false, http.StatusOK, http.StatusMethodNotAllowed, false, false}, + {"strict GET", true, http.StatusAccepted, http.StatusNotFound, false, true}, + {"unstrict GET", false, http.StatusOK, http.StatusNotFound, false, false}, + } + for _, test := range tests { + t.Run(test.label, func(t *testing.T) { + fake := &fakeStreamableServer{ + t: t, + responses: fakeResponses{ + {"POST", "", methodInitialize}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "123", + }, + body: jsonBody(t, initResp), + }, + {"POST", "123", notificationInitialized}: { + status: test.initializedStatus, + wantProtocolVersion: latestProtocolVersion, + }, + {"GET", "123", ""}: { + header: header{ + "Content-Type": "text/event-stream", + }, + status: test.getStatus, + wantProtocolVersion: latestProtocolVersion, + }, + {"POST", "123", methodListTools}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "123", + }, + body: jsonBody(t, resp(2, &ListToolsResult{Tools: []*Tool{}}, nil)), + optional: true, + }, + {"DELETE", "123", ""}: {optional: true}, + }, + } + httpServer := httptest.NewServer(fake) + defer httpServer.Close() + + transport := &StreamableClientTransport{Endpoint: httpServer.URL, strict: test.strict} + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport, nil) + if (err != nil) != test.wantConnectError { + t.Errorf("client.Connect() returned error %v; want error: %t", err, test.wantConnectError) + } + if err != nil { + return + } + // Since we need the client to observe the result of the hanging GET, + // wait for all requests to be handled. + start := time.Now() + delay := 1 * time.Millisecond + for range 10 { + if len(fake.missingRequests()) == 0 { + break + } + time.Sleep(delay) + delay *= 2 + } + if missing := fake.missingRequests(); len(missing) > 0 { + t.Errorf("did not receive expected requests after %s: %v", time.Since(start), missing) + } + _, err = session.ListTools(ctx, nil) + if (err != nil) != test.wantListError { + t.Errorf("ListTools returned error %v; want error: %t", err, test.wantListError) + } + if err := session.Close(); err != nil { + t.Errorf("closing session: %v", err) + } + }) + } +}