Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 29 additions & 8 deletions examples/client/listfeatures/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,30 +27,51 @@ import (
"github.com/modelcontextprotocol/go-sdk/mcp"
)

var (
endpoint = flag.String("http", "", "if set, connect to this streamable endpoint rather than running a stdio server")
)

func main() {
flag.Parse()
args := flag.Args()
if len(args) == 0 {
if len(args) == 0 && *endpoint == "" {
fmt.Fprintln(os.Stderr, "Usage: listfeatures <command> [<args>]")
fmt.Fprintln(os.Stderr, "Usage: listfeatures --http=\"https://example.com/server/mcp\"")
fmt.Fprintln(os.Stderr, "List all features for a stdio MCP server")
fmt.Fprintln(os.Stderr)
fmt.Fprintln(os.Stderr, "Example:\n\tlistfeatures npx @modelcontextprotocol/server-everything")
os.Exit(2)
}

ctx := context.Background()
cmd := exec.Command(args[0], args[1:]...)
var (
ctx = context.Background()
transport mcp.Transport
)
if *endpoint != "" {
transport = &mcp.StreamableClientTransport{
Endpoint: *endpoint,
}
} else {
cmd := exec.Command(args[0], args[1:]...)
transport = &mcp.CommandTransport{Command: cmd}
}
client := mcp.NewClient(&mcp.Implementation{Name: "mcp-client", Version: "v1.0.0"}, nil)
cs, err := client.Connect(ctx, &mcp.CommandTransport{Command: cmd}, nil)
cs, err := client.Connect(ctx, transport, nil)
if err != nil {
log.Fatal(err)
}
defer cs.Close()

printSection("tools", cs.Tools(ctx, nil), func(t *mcp.Tool) string { return t.Name })
printSection("resources", cs.Resources(ctx, nil), func(r *mcp.Resource) string { return r.Name })
printSection("resource templates", cs.ResourceTemplates(ctx, nil), func(r *mcp.ResourceTemplate) string { return r.Name })
printSection("prompts", cs.Prompts(ctx, nil), func(p *mcp.Prompt) string { return p.Name })
if cs.InitializeResult().Capabilities.Tools != nil {
printSection("tools", cs.Tools(ctx, nil), func(t *mcp.Tool) string { return t.Name })
}
if cs.InitializeResult().Capabilities.Resources != nil {
printSection("resources", cs.Resources(ctx, nil), func(r *mcp.Resource) string { return r.Name })
printSection("resource templates", cs.ResourceTemplates(ctx, nil), func(r *mcp.ResourceTemplate) string { return r.Name })
}
if cs.InitializeResult().Capabilities.Prompts != nil {
printSection("prompts", cs.Prompts(ctx, nil), func(p *mcp.Prompt) string { return p.Name })
}
}

func printSection[T any](name string, features iter.Seq2[T, error], featName func(T) string) {
Expand Down
54 changes: 50 additions & 4 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"fmt"
"io"
"iter"
"log/slog"
"math"
"math/rand/v2"
"net/http"
Expand Down Expand Up @@ -981,6 +982,14 @@ type StreamableClientTransport struct {
// MaxRetries is the maximum number of times to attempt a reconnect before giving up.
// It defaults to 5. To disable retries, use a negative number.
MaxRetries int

// TODO(rfindley): propose exporting these.
// If strict is set, the transport is in 'strict mode', where any violation
// of the MCP spec causes a failure.
strict bool
// If logger is set, it is used to log aspects of the transport, such as spec
// violations that were ignored.
logger *slog.Logger
}

// These settings are not (yet) exposed to the user in
Expand Down Expand Up @@ -1025,6 +1034,8 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
incoming: make(chan jsonrpc.Message, 10),
done: make(chan struct{}),
maxRetries: maxRetries,
strict: t.strict,
logger: t.logger,
ctx: connCtx,
cancel: cancel,
failed: make(chan struct{}),
Expand All @@ -1039,6 +1050,8 @@ type streamableClientConn struct {
cancel context.CancelFunc
incoming chan jsonrpc.Message
maxRetries int
strict bool // from [StreamableClientTransport.strict]
logger *slog.Logger // from [StreamableClientTransport.logger]

// Guard calls to Close, as it may be called multiple times.
closeOnce sync.Once
Expand Down Expand Up @@ -1152,9 +1165,11 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
}

var requestSummary string
var isCall bool
switch msg := msg.(type) {
case *jsonrpc.Request:
requestSummary = fmt.Sprintf("sending %q", msg.Method)
isCall = msg.IsCall()
case *jsonrpc.Response:
requestSummary = fmt.Sprintf("sending jsonrpc response #%d", msg.ID)
default:
Expand Down Expand Up @@ -1209,11 +1224,24 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
}
}
if resp.StatusCode == http.StatusNoContent || resp.StatusCode == http.StatusAccepted {
// [§2.1.4]: "If the input is a JSON-RPC response or notification:
// If the server accepts the input, the server MUST return HTTP status code 202 Accepted with no body."
//
// [§2.1.4]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#listening-for-messages-from-the-server
resp.Body.Close()
return nil
} else if !isCall && !c.strict {
// Some servers return 200, even with an empty json body.
// Ignore this response in non-strict mode.
if c.logger != nil {
c.logger.Warn(fmt.Sprintf("unexpected status code %d from non-call", resp.StatusCode))
}
resp.Body.Close()
return nil
}

switch ct := resp.Header.Get("Content-Type"); ct {
contentType := strings.TrimSpace(strings.SplitN(resp.Header.Get("Content-Type"), ";", 2)[0])
switch contentType {
case "application/json":
go c.handleJSON(requestSummary, resp)

Expand All @@ -1223,7 +1251,7 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e

default:
resp.Body.Close()
return fmt.Errorf("%s: unsupported content type %q", requestSummary, ct)
return fmt.Errorf("%s: unsupported content type %q", requestSummary, contentType)
}
return nil
}
Expand Down Expand Up @@ -1294,18 +1322,36 @@ func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *htt
newResp, err := c.reconnect(lastEventID)
if err != nil {
// All reconnection attempts failed: fail the connection.
c.fail(fmt.Errorf("%s: failed to reconnect: %v", requestSummary, err))
c.fail(fmt.Errorf("%s: failed to reconnect (session ID: %v): %v", requestSummary, c.sessionID, err))
return
}
resp = newResp
if resp.StatusCode == http.StatusMethodNotAllowed && persistent {
// [§2.2.3]: "The server MUST either return Content-Type:
// text/event-stream in response to this HTTP GET, or else return HTTP
// 405 Method Not Allowed, indicating that the server does not offer an
// SSE stream at this endpoint."
//
// [§2.2.3]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#listening-for-messages-from-the-server

// The server doesn't support the hanging GET.
resp.Body.Close()
return
}
if resp.StatusCode == http.StatusNotFound && persistent && !c.strict {
// modelcontextprotocol/gosdk#393: some servers return NotFound instead
// of MethodNotAllowed for the persistent GET.
//
// Treat this like MethodNotAllowed in non-strict mode.
if c.logger != nil {
c.logger.Warn("got 404 instead of 405 for hanging GET")
}
resp.Body.Close()
return
}
// (see equivalent handling in [streamableClientConn.Write]).
if resp.StatusCode == http.StatusNotFound {
c.fail(fmt.Errorf("%s: failed to reconnect: %w", requestSummary, errSessionMissing))
c.fail(fmt.Errorf("%s: failed to reconnect (session ID: %v): %w", requestSummary, c.sessionID, errSessionMissing))
return
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
Expand Down
94 changes: 91 additions & 3 deletions mcp/streamable_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ func TestStreamableClientGETHandling(t *testing.T) {
responses: fakeResponses{
{"POST", "", methodInitialize}: {
header: header{
"Content-Type": "application/json",
"Content-Type": "application/json; charset=utf-8", // should ignore the charset
sessionIDHeader: "123",
},
body: jsonBody(t, initResp),
Expand Down Expand Up @@ -293,8 +293,8 @@ func TestStreamableClientGETHandling(t *testing.T) {
t.Fatalf("client.Connect() failed: %v", err)
}

// wait for all required requests to be handled, with exponential
// backoff.
// Since we need the client to observe the result of the hanging GET,
// wait for all requests to be handled.
start := time.Now()
delay := 1 * time.Millisecond
for range 10 {
Expand Down Expand Up @@ -323,3 +323,91 @@ func TestStreamableClientGETHandling(t *testing.T) {
})
}
}

func TestStreamableClientStrictness(t *testing.T) {
ctx := context.Background()

tests := []struct {
label string
strict bool
initializedStatus int
getStatus int
wantConnectError bool
wantListError bool
}{
{"conformant server", true, http.StatusAccepted, http.StatusMethodNotAllowed, false, false},
{"strict initialized", true, http.StatusOK, http.StatusMethodNotAllowed, true, false},
{"unstrict initialized", false, http.StatusOK, http.StatusMethodNotAllowed, false, false},
{"strict GET", true, http.StatusAccepted, http.StatusNotFound, false, true},
{"unstrict GET", false, http.StatusOK, http.StatusNotFound, false, false},
}
for _, test := range tests {
t.Run(test.label, func(t *testing.T) {
fake := &fakeStreamableServer{
t: t,
responses: fakeResponses{
{"POST", "", methodInitialize}: {
header: header{
"Content-Type": "application/json",
sessionIDHeader: "123",
},
body: jsonBody(t, initResp),
},
{"POST", "123", notificationInitialized}: {
status: test.initializedStatus,
wantProtocolVersion: latestProtocolVersion,
},
{"GET", "123", ""}: {
header: header{
"Content-Type": "text/event-stream",
},
status: test.getStatus,
wantProtocolVersion: latestProtocolVersion,
},
{"POST", "123", methodListTools}: {
header: header{
"Content-Type": "application/json",
sessionIDHeader: "123",
},
body: jsonBody(t, resp(2, &ListToolsResult{Tools: []*Tool{}}, nil)),
optional: true,
},
{"DELETE", "123", ""}: {optional: true},
},
}
httpServer := httptest.NewServer(fake)
defer httpServer.Close()

transport := &StreamableClientTransport{Endpoint: httpServer.URL, strict: test.strict}
client := NewClient(testImpl, nil)
session, err := client.Connect(ctx, transport, nil)
if (err != nil) != test.wantConnectError {
t.Errorf("client.Connect() returned error %v; want error: %t", err, test.wantConnectError)
}
if err != nil {
return
}
// Since we need the client to observe the result of the hanging GET,
// wait for all requests to be handled.
start := time.Now()
delay := 1 * time.Millisecond
for range 10 {
if len(fake.missingRequests()) == 0 {
break
}
time.Sleep(delay)
delay *= 2
}
if missing := fake.missingRequests(); len(missing) > 0 {
t.Errorf("did not receive expected requests after %s: %v", time.Since(start), missing)
}
_, err = session.ListTools(ctx, nil)
if (err != nil) != test.wantListError {
t.Errorf("ListTools returned error %v; want error: %t", err, test.wantListError)
}
if err := session.Close(); err != nil {
t.Errorf("closing session: %v", err)
}
})
}
}