diff --git a/mcp/client.go b/mcp/client.go index 3a935040..512be2cb 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -6,6 +6,7 @@ package mcp import ( "context" + "fmt" "iter" "slices" "sync" @@ -86,6 +87,15 @@ func (c *Client) disconnect(cs *ClientSession) { }) } +// TODO: Consider exporting this type and its field. +type unsupportedProtocolVersionError struct { + version string +} + +func (e unsupportedProtocolVersionError) Error() string { + return fmt.Sprintf("unsupported protocol version: %q", e.version) +} + // Connect begins an MCP session by connecting to a server over the given // transport, and initializing the session. // @@ -106,19 +116,22 @@ func (c *Client) Connect(ctx context.Context, t Transport) (cs *ClientSession, e } params := &InitializeParams{ + ProtocolVersion: latestProtocolVersion, ClientInfo: &implementation{Name: c.name, Version: c.version}, Capabilities: caps, - ProtocolVersion: "2025-03-26", } - // TODO(rfindley): handle protocol negotiation gracefully. If the server - // responds with 2024-11-05, surface that failure to the caller of connect, - // so that they can choose a different transport. res, err := handleSend[*InitializeResult](ctx, cs, methodInitialize, params) if err != nil { _ = cs.Close() return nil, err } + if !slices.Contains(supportedProtocolVersions, res.ProtocolVersion) { + return nil, unsupportedProtocolVersionError{res.ProtocolVersion} + } cs.initializeResult = res + if hc, ok := cs.mcpConn.(httpConnection); ok { + hc.setProtocolVersion(res.ProtocolVersion) + } if err := handleNotify(ctx, cs, notificationInitialized, &InitializedParams{}); err != nil { _ = cs.Close() return nil, err diff --git a/mcp/server.go b/mcp/server.go index cd8f808b..31b4226d 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -634,10 +634,11 @@ func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParam ss.mu.Unlock() }() - version := "2025-03-26" // preferred version - switch v := params.ProtocolVersion; v { - case "2024-11-05", "2025-03-26": - version = v + // If we support the client's version, reply with it. Otherwise, reply with our + // latest version. + version := params.ProtocolVersion + if !slices.Contains(supportedProtocolVersions, params.ProtocolVersion) { + version = latestProtocolVersion } return &InitializeResult{ diff --git a/mcp/shared.go b/mcp/shared.go index db871ca8..8a38777e 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -22,6 +22,16 @@ import ( "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" ) +// latestProtocolVersion is the latest protocol version that this version of the SDK supports. +// It is the version that the client sends in the initialization request. +const latestProtocolVersion = "2025-06-18" + +var supportedProtocolVersions = []string{ + latestProtocolVersion, + "2025-03-26", + "2024-11-05", +} + // A MethodHandler handles MCP messages. // For methods, exactly one of the return values must be nil. // For notifications, both must be nil. diff --git a/mcp/streamable.go b/mcp/streamable.go index db3add85..52208948 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -18,6 +18,11 @@ import ( "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" ) +const ( + protocolVersionHeader = "Mcp-Protocol-Version" + sessionIDHeader = "Mcp-Session-Id" +) + // A StreamableHTTPHandler is an http.Handler that serves streamable MCP // sessions, as defined by the [MCP spec]. // @@ -88,7 +93,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque } var session *StreamableServerTransport - if id := req.Header.Get("Mcp-Session-Id"); id != "" { + if id := req.Header.Get(sessionIDHeader); id != "" { h.sessionsMu.Lock() session, _ = h.sessions[id] h.sessionsMu.Unlock() @@ -386,7 +391,7 @@ func (t *StreamableServerTransport) streamResponse(w http.ResponseWriter, req *h t.mu.Unlock() } - w.Header().Set("Mcp-Session-Id", t.id) + w.Header().Set(sessionIDHeader, t.id) w.Header().Set("Content-Type", "text/event-stream") // Accept checked in [StreamableHTTPHandler] w.Header().Set("Cache-Control", "no-cache, no-transform") w.Header().Set("Connection", "keep-alive") @@ -636,12 +641,19 @@ type streamableClientConn struct { closeOnce sync.Once closeErr error - mu sync.Mutex - _sessionID string + mu sync.Mutex + protocolVersion string + _sessionID string // bodies map[*http.Response]io.Closer err error } +func (c *streamableClientConn) setProtocolVersion(s string) { + c.mu.Lock() + defer c.mu.Unlock() + c.protocolVersion = s +} + func (c *streamableClientConn) SessionID() string { c.mu.Lock() defer c.mu.Unlock() @@ -707,8 +719,11 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string if err != nil { return "", err } + if s.protocolVersion != "" { + req.Header.Set(protocolVersionHeader, s.protocolVersion) + } if sessionID != "" { - req.Header.Set("Mcp-Session-Id", sessionID) + req.Header.Set(sessionIDHeader, sessionID) } req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json, text/event-stream") @@ -724,7 +739,7 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string return "", fmt.Errorf("broken session: %v", resp.Status) } - sessionID = resp.Header.Get("Mcp-Session-Id") + sessionID = resp.Header.Get(sessionIDHeader) if resp.Header.Get("Content-Type") == "text/event-stream" { go s.handleSSE(resp) } else { @@ -763,7 +778,11 @@ func (s *streamableClientConn) Close() error { if err != nil { s.closeErr = err } else { - req.Header.Set("Mcp-Session-Id", s._sessionID) + // TODO(jba): confirm that we don't need a lock here, or add locking. + if s.protocolVersion != "" { + req.Header.Set(protocolVersionHeader, s.protocolVersion) + } + req.Header.Set(sessionIDHeader, s._sessionID) if _, err := s.client.Do(req); err != nil { s.closeErr = err } diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 8925b3da..3329caea 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -36,7 +36,9 @@ func TestStreamableTransports(t *testing.T) { // 2. Start an httptest.Server with the StreamableHTTPHandler, wrapped in a // cookie-checking middleware. handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + var header http.Header httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + header = r.Header cookie, err := r.Cookie("test-cookie") if err != nil { t.Errorf("missing cookie: %v", err) @@ -72,6 +74,9 @@ func TestStreamableTransports(t *testing.T) { if sid == "" { t.Error("empty session ID") } + if g, w := session.mcpConn.(*streamableClientConn).protocolVersion, latestProtocolVersion; g != w { + t.Fatalf("got protocol version %q, want %q", g, w) + } // 4. The client calls the "greet" tool. params := &CallToolParams{ Name: "greet", @@ -84,6 +89,9 @@ func TestStreamableTransports(t *testing.T) { if g := session.ID(); g != sid { t.Errorf("session ID: got %q, want %q", g, sid) } + if g, w := header.Get(protocolVersionHeader), latestProtocolVersion; g != w { + t.Errorf("got protocol version header %q, want %q", g, w) + } // 5. Verify that the correct response is received. want := &CallToolResult{ @@ -154,7 +162,7 @@ func TestStreamableServerTransport(t *testing.T) { Resources: &resourceCapabilities{ListChanged: true}, Tools: &toolCapabilities{ListChanged: true}, }, - ProtocolVersion: "2025-03-26", + ProtocolVersion: latestProtocolVersion, ServerInfo: &implementation{Name: "testServer", Version: "v1.0.0"}, }, nil) initializedMsg := req(0, "initialized", &InitializedParams{}) diff --git a/mcp/transport.go b/mcp/transport.go index 85bfaf65..f0b81650 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -53,6 +53,12 @@ type Connection interface { SessionID() string } +// An httpConnection is a [Connection] that runs over HTTP. +type httpConnection interface { + Connection + setProtocolVersion(string) +} + // A StdioTransport is a [Transport] that communicates over stdin/stdout using // newline-delimited JSON. type StdioTransport struct {