Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions mcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package mcp

import (
"context"
"fmt"
"iter"
"slices"
"sync"
Expand Down Expand Up @@ -86,6 +87,15 @@ func (c *Client) disconnect(cs *ClientSession) {
})
}

// TODO: Consider exporting this type and its field.
type unsupportedProtocolVersionError struct {
version string
}

func (e unsupportedProtocolVersionError) Error() string {
return fmt.Sprintf("unsupported protocol version: %q", e.version)
}

// Connect begins an MCP session by connecting to a server over the given
// transport, and initializing the session.
//
Expand All @@ -106,19 +116,22 @@ func (c *Client) Connect(ctx context.Context, t Transport) (cs *ClientSession, e
}

params := &InitializeParams{
ProtocolVersion: latestProtocolVersion,
ClientInfo: &implementation{Name: c.name, Version: c.version},
Capabilities: caps,
ProtocolVersion: "2025-03-26",
}
// TODO(rfindley): handle protocol negotiation gracefully. If the server
// responds with 2024-11-05, surface that failure to the caller of connect,
// so that they can choose a different transport.
res, err := handleSend[*InitializeResult](ctx, cs, methodInitialize, params)
if err != nil {
_ = cs.Close()
return nil, err
}
if !slices.Contains(supportedProtocolVersions, res.ProtocolVersion) {
return nil, unsupportedProtocolVersionError{res.ProtocolVersion}
}
cs.initializeResult = res
if hc, ok := cs.mcpConn.(HTTPConnection); ok {
hc.SetProtocolVersion(res.ProtocolVersion)
}
if err := handleNotify(ctx, cs, notificationInitialized, &InitializedParams{}); err != nil {
_ = cs.Close()
return nil, err
Expand Down
9 changes: 5 additions & 4 deletions mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -634,10 +634,11 @@ func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParam
ss.mu.Unlock()
}()

version := "2025-03-26" // preferred version
switch v := params.ProtocolVersion; v {
case "2024-11-05", "2025-03-26":
version = v
// If we support the client's version, reply with it. Otherwise, reply with our
// latest version.
version := params.ProtocolVersion
if !slices.Contains(supportedProtocolVersions, params.ProtocolVersion) {
version = latestProtocolVersion
}

return &InitializeResult{
Expand Down
6 changes: 6 additions & 0 deletions mcp/shared.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ import (
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
)

// latestProtocolVersion is the latest protocol version that this version of the SDK supports.
// It is the version that the client sends in the initialization request.
const latestProtocolVersion = "2025-06-18"

var supportedProtocolVersions = []string{latestProtocolVersion, "2025-03-26"}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we also support the 2024-11 version?
If we don't say that, what will that mean for SSE usage?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added.


// A MethodHandler handles MCP messages.
// For methods, exactly one of the return values must be nil.
// For notifications, both must be nil.
Expand Down
33 changes: 26 additions & 7 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ import (
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
)

const (
protocolVersionHeader = "Mcp-Protocol-Version"
sessionIDHeader = "Mcp-Session-Id"
)

// A StreamableHTTPHandler is an http.Handler that serves streamable MCP
// sessions, as defined by the [MCP spec].
//
Expand Down Expand Up @@ -88,7 +93,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
}

var session *StreamableServerTransport
if id := req.Header.Get("Mcp-Session-Id"); id != "" {
if id := req.Header.Get(sessionIDHeader); id != "" {
h.sessionsMu.Lock()
session, _ = h.sessions[id]
h.sessionsMu.Unlock()
Expand Down Expand Up @@ -386,7 +391,7 @@ func (t *StreamableServerTransport) streamResponse(w http.ResponseWriter, req *h
t.mu.Unlock()
}

w.Header().Set("Mcp-Session-Id", t.id)
w.Header().Set(sessionIDHeader, t.id)
w.Header().Set("Content-Type", "text/event-stream") // Accept checked in [StreamableHTTPHandler]
w.Header().Set("Cache-Control", "no-cache, no-transform")
w.Header().Set("Connection", "keep-alive")
Expand Down Expand Up @@ -636,12 +641,19 @@ type streamableClientConn struct {
closeOnce sync.Once
closeErr error

mu sync.Mutex
_sessionID string
mu sync.Mutex
protocolVersion string
_sessionID string
// bodies map[*http.Response]io.Closer
err error
}

func (c *streamableClientConn) SetProtocolVersion(s string) {
c.mu.Lock()
defer c.mu.Unlock()
c.protocolVersion = s
}

func (c *streamableClientConn) SessionID() string {
c.mu.Lock()
defer c.mu.Unlock()
Expand Down Expand Up @@ -707,8 +719,11 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
if err != nil {
return "", err
}
if s.protocolVersion != "" {
req.Header.Set(protocolVersionHeader, s.protocolVersion)
}
if sessionID != "" {
req.Header.Set("Mcp-Session-Id", sessionID)
req.Header.Set(sessionIDHeader, sessionID)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json, text/event-stream")
Expand All @@ -724,7 +739,7 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
return "", fmt.Errorf("broken session: %v", resp.Status)
}

sessionID = resp.Header.Get("Mcp-Session-Id")
sessionID = resp.Header.Get(sessionIDHeader)
if resp.Header.Get("Content-Type") == "text/event-stream" {
go s.handleSSE(resp)
} else {
Expand Down Expand Up @@ -763,7 +778,11 @@ func (s *streamableClientConn) Close() error {
if err != nil {
s.closeErr = err
} else {
req.Header.Set("Mcp-Session-Id", s._sessionID)
// TODO(jba): confirm that we don't need a lock here, or add locking.
if s.protocolVersion != "" {
req.Header.Set(protocolVersionHeader, s.protocolVersion)
}
req.Header.Set(sessionIDHeader, s._sessionID)
if _, err := s.client.Do(req); err != nil {
s.closeErr = err
}
Expand Down
10 changes: 9 additions & 1 deletion mcp/streamable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ func TestStreamableTransports(t *testing.T) {
// 2. Start an httptest.Server with the StreamableHTTPHandler, wrapped in a
// cookie-checking middleware.
handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil)
var header http.Header
httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
header = r.Header
cookie, err := r.Cookie("test-cookie")
if err != nil {
t.Errorf("missing cookie: %v", err)
Expand Down Expand Up @@ -72,6 +74,9 @@ func TestStreamableTransports(t *testing.T) {
if sid == "" {
t.Error("empty session ID")
}
if g, w := session.mcpConn.(*streamableClientConn).protocolVersion, latestProtocolVersion; g != w {
t.Fatalf("got protocol version %q, want %q", g, w)
}
// 4. The client calls the "greet" tool.
params := &CallToolParams{
Name: "greet",
Expand All @@ -84,6 +89,9 @@ func TestStreamableTransports(t *testing.T) {
if g := session.ID(); g != sid {
t.Errorf("session ID: got %q, want %q", g, sid)
}
if g, w := header.Get(protocolVersionHeader), latestProtocolVersion; g != w {
t.Errorf("got protocol version header %q, want %q", g, w)
}

// 5. Verify that the correct response is received.
want := &CallToolResult{
Expand Down Expand Up @@ -154,7 +162,7 @@ func TestStreamableServerTransport(t *testing.T) {
Resources: &resourceCapabilities{ListChanged: true},
Tools: &toolCapabilities{ListChanged: true},
},
ProtocolVersion: "2025-03-26",
ProtocolVersion: latestProtocolVersion,
ServerInfo: &implementation{Name: "testServer", Version: "v1.0.0"},
}, nil)
initializedMsg := req(0, "initialized", &InitializedParams{})
Expand Down
6 changes: 6 additions & 0 deletions mcp/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ type Connection interface {
SessionID() string
}

// An HTTPConnection is a [Connection] that runs over HTTP.
type HTTPConnection interface {
Connection
SetProtocolVersion(string)
}

// A StdioTransport is a [Transport] that communicates over stdin/stdout using
// newline-delimited JSON.
type StdioTransport struct {
Expand Down
Loading