Skip to content

Commit 87c8126

Browse files
authored
mcp: gate prime and close events on protocol version 2025-11-25 (#696)
The prime and close SSE events (SEP-1699) were added in protocol version 2025-11-25. Only send these events when the client negotiates that version or later. - In servePOST, extract protocol version from InitializeParams for initialize requests, otherwise use the Mcp-Protocol-Version header - In serveGET/acquireStream, read protocol version from header - Only set closeLocked callback when prime/close is supported - Add unexported protocolVersion field to ClientSessionOptions for testing Fixes #686
1 parent d6d6edd commit 87c8126

File tree

3 files changed

+129
-30
lines changed

3 files changed

+129
-30
lines changed

mcp/client.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,11 @@ func (e unsupportedProtocolVersionError) Error() string {
120120
}
121121

122122
// ClientSessionOptions is reserved for future use.
123-
type ClientSessionOptions struct{}
123+
type ClientSessionOptions struct {
124+
// protocolVersion overrides the protocol version sent in the initialize
125+
// request, for testing. If empty, latestProtocolVersion is used.
126+
protocolVersion string
127+
}
124128

125129
func (c *Client) capabilities() *ClientCapabilities {
126130
caps := &ClientCapabilities{}
@@ -151,14 +155,18 @@ func (c *Client) capabilities() *ClientCapabilities {
151155
// when it is no longer needed. However, if the connection is closed by the
152156
// server, calls or notifications will return an error wrapping
153157
// [ErrConnectionClosed].
154-
func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptions) (cs *ClientSession, err error) {
158+
func (c *Client) Connect(ctx context.Context, t Transport, opts *ClientSessionOptions) (cs *ClientSession, err error) {
155159
cs, err = connect(ctx, t, c, (*clientSessionState)(nil), nil)
156160
if err != nil {
157161
return nil, err
158162
}
159163

164+
protocolVersion := latestProtocolVersion
165+
if opts != nil && opts.protocolVersion != "" {
166+
protocolVersion = opts.protocolVersion
167+
}
160168
params := &InitializeParams{
161-
ProtocolVersion: latestProtocolVersion,
169+
ProtocolVersion: protocolVersion,
162170
ClientInfo: c.impl,
163171
Capabilities: c.capabilities(),
164172
}

mcp/streamable.go

Lines changed: 51 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -730,7 +730,15 @@ func (c *streamableServerConn) serveGET(w http.ResponseWriter, req *http.Request
730730
ctx, cancel := context.WithCancel(req.Context())
731731
defer cancel()
732732

733-
stream, done := c.acquireStream(ctx, w, streamID, &lastIdx)
733+
// Read the protocol version from the header. For GET requests, this should
734+
// always be present since GET only happens after initialization.
735+
protocolVersion := req.Header.Get(protocolVersionHeader)
736+
if protocolVersion == "" {
737+
protocolVersion = protocolVersion20250326
738+
}
739+
supportsPrimeClose := protocolVersion >= protocolVersion20251125
740+
741+
stream, done := c.acquireStream(ctx, w, streamID, &lastIdx, supportsPrimeClose)
734742
if stream == nil {
735743
return
736744
}
@@ -792,7 +800,10 @@ func (c *streamableServerConn) writeCloseEvent(w http.ResponseWriter, reconnectA
792800
// Importantly, this function must hold the stream mutex until done replaying
793801
// all messages, so that no delivery or storage of new messages occurs while
794802
// the stream is still replaying.
795-
func (c *streamableServerConn) acquireStream(ctx context.Context, w http.ResponseWriter, streamID string, lastIdx *int) (*stream, chan struct{}) {
803+
//
804+
// supportsPrimeClose indicates whether the client supports the prime and close
805+
// events (protocol version 2025-11-25 or later).
806+
func (c *streamableServerConn) acquireStream(ctx context.Context, w http.ResponseWriter, streamID string, lastIdx *int, supportsPrimeClose bool) (*stream, chan struct{}) {
796807
// if tempStream is set, the stream is done and we're just replaying messages.
797808
//
798809
// We record a temporary stream to claim exclusive replay rights.
@@ -898,16 +909,18 @@ func (c *streamableServerConn) acquireStream(ctx context.Context, w http.Respons
898909
}
899910
return c.writeEvent(w, s.id, Event{Name: "message", Data: data}, lastIdx)
900911
}
901-
s.closeLocked = func(reconnectAfter time.Duration) {
902-
select {
903-
case <-done:
904-
return
905-
default:
906-
}
907-
if reconnectAfter > 0 {
908-
c.writeCloseEvent(w, reconnectAfter)
912+
if supportsPrimeClose {
913+
s.closeLocked = func(reconnectAfter time.Duration) {
914+
select {
915+
case <-done:
916+
return
917+
default:
918+
}
919+
if reconnectAfter > 0 {
920+
c.writeCloseEvent(w, reconnectAfter)
921+
}
922+
close(done)
909923
}
910-
close(done)
911924
}
912925
return s, done
913926
}
@@ -959,6 +972,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
959972
calls := make(map[jsonrpc.ID]struct{})
960973
tokenInfo := auth.TokenInfoFromContext(req.Context())
961974
isInitialize := false
975+
var initializeProtocolVersion string
962976
for _, msg := range incoming {
963977
if jreq, ok := msg.(*jsonrpc.Request); ok {
964978
// Preemptively check that this is a valid request, so that we can fail
@@ -970,6 +984,11 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
970984
}
971985
if jreq.Method == methodInitialize {
972986
isInitialize = true
987+
// Extract the protocol version from InitializeParams.
988+
var params InitializeParams
989+
if err := json.Unmarshal(jreq.Params, &params); err == nil {
990+
initializeProtocolVersion = params.ProtocolVersion
991+
}
973992
}
974993
jreq.Extra = &RequestExtra{
975994
TokenInfo: tokenInfo,
@@ -994,6 +1013,15 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
9941013
}
9951014
}
9961015

1016+
// The prime and close events were added in protocol version 2025-11-25 (SEP-1699).
1017+
// Use the version from InitializeParams if this is an initialize request,
1018+
// otherwise use the protocol version header.
1019+
effectiveVersion := protocolVersion
1020+
if isInitialize && initializeProtocolVersion != "" {
1021+
effectiveVersion = initializeProtocolVersion
1022+
}
1023+
supportsPrimeClose := effectiveVersion >= protocolVersion20251125
1024+
9971025
// If we don't have any calls, we can just publish the incoming messages and return.
9981026
// No need to track a logical stream.
9991027
if len(calls) == 0 {
@@ -1069,7 +1097,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
10691097
} else {
10701098
// Write events in the order we receive them.
10711099
lastIndex := -1
1072-
if c.eventStore != nil {
1100+
if c.eventStore != nil && supportsPrimeClose {
10731101
// Write a priming event.
10741102
// We must also write it to the event store in order for indexes to
10751103
// align.
@@ -1092,16 +1120,18 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
10921120
}
10931121
return c.writeEvent(w, stream.id, Event{Name: "message", Data: data}, &lastIndex)
10941122
}
1095-
stream.closeLocked = func(reconnectAfter time.Duration) {
1096-
select {
1097-
case <-done:
1098-
return
1099-
default:
1100-
}
1101-
if reconnectAfter > 0 {
1102-
c.writeCloseEvent(w, reconnectAfter)
1123+
if supportsPrimeClose {
1124+
stream.closeLocked = func(reconnectAfter time.Duration) {
1125+
select {
1126+
case <-done:
1127+
return
1128+
default:
1129+
}
1130+
if reconnectAfter > 0 {
1131+
c.writeCloseEvent(w, reconnectAfter)
1132+
}
1133+
close(done)
11031134
}
1104-
close(done)
11051135
}
11061136
}
11071137

mcp/streamable_test.go

Lines changed: 67 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,7 @@ func TestStreamableServerDisconnect(t *testing.T) {
584584
})
585585
clientSession, err := client.Connect(ctx, &StreamableClientTransport{
586586
Endpoint: httpServer.URL,
587-
}, nil)
587+
}, &ClientSessionOptions{protocolVersion: protocolVersion20251125})
588588
if err != nil {
589589
t.Fatalf("client.Connect() failed: %v", err)
590590
}
@@ -752,7 +752,7 @@ func TestStreamableServerTransport(t *testing.T) {
752752
// requests.
753753

754754
// Predefined steps, to avoid repetition below.
755-
initReq := req(1, methodInitialize, &InitializeParams{})
755+
initReq := req(1, methodInitialize, &InitializeParams{ProtocolVersion: protocolVersion20250618})
756756
initResp := resp(1, &InitializeResult{
757757
Capabilities: &ServerCapabilities{
758758
Logging: &LoggingCapabilities{},
@@ -775,6 +775,30 @@ func TestStreamableServerTransport(t *testing.T) {
775775
wantStatusCode: http.StatusAccepted,
776776
}
777777

778+
// Protocol version 2025-11-25 variants, for testing prime/close events (SEP-1699).
779+
initReq20251125 := req(1, methodInitialize, &InitializeParams{ProtocolVersion: protocolVersion20251125})
780+
initResp20251125 := resp(1, &InitializeResult{
781+
Capabilities: &ServerCapabilities{
782+
Logging: &LoggingCapabilities{},
783+
Tools: &ToolCapabilities{ListChanged: true},
784+
},
785+
ProtocolVersion: protocolVersion20251125,
786+
ServerInfo: &Implementation{Name: "testServer", Version: "v1.0.0"},
787+
}, nil)
788+
initialize20251125 := streamableRequest{
789+
method: "POST",
790+
messages: []jsonrpc.Message{initReq20251125},
791+
wantStatusCode: http.StatusOK,
792+
wantMessages: []jsonrpc.Message{initResp20251125},
793+
wantSessionID: true,
794+
}
795+
initialized20251125 := streamableRequest{
796+
method: "POST",
797+
headers: http.Header{protocolVersionHeader: {protocolVersion20251125}},
798+
messages: []jsonrpc.Message{initializedMsg},
799+
wantStatusCode: http.StatusAccepted,
800+
}
801+
778802
tests := []struct {
779803
name string
780804
replay bool // if set, use a MemoryEventStore to enable replay
@@ -1026,13 +1050,49 @@ func TestStreamableServerTransport(t *testing.T) {
10261050
wantSessions: 0, // session deleted
10271051
},
10281052
{
1029-
name: "priming message",
1053+
name: "no priming message on old protocol",
10301054
replay: true,
10311055
requests: []streamableRequest{
10321056
initialize,
10331057
initialized,
1058+
{
1059+
method: "POST",
1060+
messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})},
1061+
wantStatusCode: http.StatusOK,
1062+
wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{Content: []Content{}}, nil)},
1063+
wantBodyNotContaining: "prime",
1064+
},
1065+
},
1066+
wantSessions: 1,
1067+
},
1068+
{
1069+
name: "no close message on old protocol",
1070+
replay: true,
1071+
tool: func(t *testing.T, _ context.Context, req *CallToolRequest) {
1072+
req.Extra.CloseStream(time.Millisecond)
1073+
},
1074+
requests: []streamableRequest{
1075+
initialize,
1076+
initialized,
1077+
{
1078+
method: "POST",
1079+
messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})},
1080+
wantStatusCode: http.StatusOK,
1081+
wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{Content: []Content{}}, nil)},
1082+
wantBodyNotContaining: "close",
1083+
},
1084+
},
1085+
wantSessions: 1,
1086+
},
1087+
{
1088+
name: "priming message on 2025-11-25",
1089+
replay: true,
1090+
requests: []streamableRequest{
1091+
initialize20251125,
1092+
initialized20251125,
10341093
{
10351094
method: "POST",
1095+
headers: http.Header{protocolVersionHeader: {protocolVersion20251125}},
10361096
messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})},
10371097
wantStatusCode: http.StatusOK,
10381098
wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{Content: []Content{}}, nil)},
@@ -1042,16 +1102,17 @@ func TestStreamableServerTransport(t *testing.T) {
10421102
wantSessions: 1,
10431103
},
10441104
{
1045-
name: "close message",
1105+
name: "close message on 2025-11-25",
10461106
replay: true,
10471107
tool: func(t *testing.T, _ context.Context, req *CallToolRequest) {
10481108
req.Extra.CloseStream(time.Millisecond)
10491109
},
10501110
requests: []streamableRequest{
1051-
initialize,
1052-
initialized,
1111+
initialize20251125,
1112+
initialized20251125,
10531113
{
10541114
method: "POST",
1115+
headers: http.Header{protocolVersionHeader: {protocolVersion20251125}},
10551116
messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})},
10561117
wantStatusCode: http.StatusOK,
10571118
wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{Content: []Content{}}, nil)},

0 commit comments

Comments
 (0)