Skip to content

Commit 1b87a0b

Browse files
committed
mcp: be strict about returning the Mcp-Session-Id header
Rather than returning the Mcp-Session-Id header for all responses, only return it from initialize, per the spec. Fixes #412
1 parent 07b9cee commit 1b87a0b

File tree

2 files changed

+37
-33
lines changed

2 files changed

+37
-33
lines changed

mcp/streamable.go

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ func (t *StreamableServerTransport) Connect(ctx context.Context) (Connection, er
424424
// It is always text/event-stream, since it must carry arbitrarily many
425425
// messages.
426426
var err error
427-
t.connection.streams[""], err = t.connection.newStream(ctx, "", false)
427+
t.connection.streams[""], err = t.connection.newStream(ctx, "", false, false)
428428
if err != nil {
429429
return nil, err
430430
}
@@ -485,6 +485,10 @@ type stream struct {
485485
// an empty string is used for messages that don't correlate with an incoming request.
486486
id StreamID
487487

488+
// If isInitialize is set, the stream is in response to an initialize request,
489+
// and therefore should include the session ID header.
490+
isInitialize bool
491+
488492
// jsonResponse records whether this stream should respond with application/json
489493
// instead of text/event-stream.
490494
//
@@ -513,12 +517,13 @@ type stream struct {
513517
requests map[jsonrpc.ID]struct{}
514518
}
515519

516-
func (c *streamableServerConn) newStream(ctx context.Context, id StreamID, jsonResponse bool) (*stream, error) {
520+
func (c *streamableServerConn) newStream(ctx context.Context, id StreamID, isInitialize, jsonResponse bool) (*stream, error) {
517521
if err := c.eventStore.Open(ctx, c.sessionID, id); err != nil {
518522
return nil, err
519523
}
520524
return &stream{
521525
id: id,
526+
isInitialize: isInitialize,
522527
jsonResponse: jsonResponse,
523528
requests: make(map[jsonrpc.ID]struct{}),
524529
}, nil
@@ -647,6 +652,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
647652
}
648653
requests := make(map[jsonrpc.ID]struct{})
649654
tokenInfo := auth.TokenInfoFromContext(req.Context())
655+
isInitialize := false
650656
for _, msg := range incoming {
651657
if jreq, ok := msg.(*jsonrpc.Request); ok {
652658
// Preemptively check that this is a valid request, so that we can fail
@@ -656,6 +662,9 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
656662
http.Error(w, err.Error(), http.StatusBadRequest)
657663
return
658664
}
665+
if jreq.Method == methodInitialize {
666+
isInitialize = true
667+
}
659668
jreq.Extra = &RequestExtra{
660669
TokenInfo: tokenInfo,
661670
Header: req.Header,
@@ -672,7 +681,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
672681
// notifications or server->client requests made in the course of handling.
673682
// Update accounting for this incoming payload.
674683
if len(requests) > 0 {
675-
stream, err = c.newStream(req.Context(), StreamID(randText()), c.jsonResponse)
684+
stream, err = c.newStream(req.Context(), StreamID(randText()), isInitialize, c.jsonResponse)
676685
if err != nil {
677686
http.Error(w, fmt.Sprintf("storing stream: %v", err), http.StatusInternalServerError)
678687
return
@@ -708,7 +717,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
708717
func (c *streamableServerConn) respondJSON(stream *stream, w http.ResponseWriter, req *http.Request) {
709718
w.Header().Set("Cache-Control", "no-cache, no-transform")
710719
w.Header().Set("Content-Type", "application/json")
711-
if c.sessionID != "" {
720+
if c.sessionID != "" && stream.isInitialize {
712721
w.Header().Set(sessionIDHeader, c.sessionID)
713722
}
714723

@@ -747,7 +756,7 @@ func (c *streamableServerConn) respondSSE(stream *stream, w http.ResponseWriter,
747756
w.Header().Set("Cache-Control", "no-cache, no-transform")
748757
w.Header().Set("Content-Type", "text/event-stream") // Accept checked in [StreamableHTTPHandler]
749758
w.Header().Set("Connection", "keep-alive")
750-
if c.sessionID != "" {
759+
if c.sessionID != "" && stream.isInitialize {
751760
w.Header().Set(sessionIDHeader, c.sessionID)
752761
}
753762
if persistent {

mcp/streamable_test.go

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ func TestStreamableTransports(t *testing.T) {
133133
defer session.Close()
134134
sid := session.ID()
135135
if sid == "" {
136-
t.Error("empty session ID")
136+
t.Fatalf("empty session ID")
137137
}
138138
if g, w := session.mcpConn.(*streamableClientConn).initializedResult.ProtocolVersion, latestProtocolVersion; g != w {
139139
t.Fatalf("got protocol version %q, want %q", g, w)
@@ -475,6 +475,8 @@ func resp(id int64, result any, err error) *jsonrpc.Response {
475475
}
476476
}
477477

478+
var ()
479+
478480
func TestStreamableServerTransport(t *testing.T) {
479481
// This test checks detailed behavior of the streamable server transport, by
480482
// faking the behavior of a streamable client using a sequence of HTTP
@@ -502,7 +504,6 @@ func TestStreamableServerTransport(t *testing.T) {
502504
method: "POST",
503505
messages: []jsonrpc.Message{initializedMsg},
504506
wantStatusCode: http.StatusAccepted,
505-
wantSessionID: false, // TODO: should this be true?
506507
}
507508

508509
tests := []struct {
@@ -520,7 +521,6 @@ func TestStreamableServerTransport(t *testing.T) {
520521
messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})},
521522
wantStatusCode: http.StatusOK,
522523
wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{}, nil)},
523-
wantSessionID: true,
524524
},
525525
},
526526
},
@@ -535,30 +535,26 @@ func TestStreamableServerTransport(t *testing.T) {
535535
headers: http.Header{"Accept": {"text/plain", "application/*"}},
536536
messages: []jsonrpc.Message{req(3, "tools/call", &CallToolParams{Name: "tool"})},
537537
wantStatusCode: http.StatusBadRequest, // missing text/event-stream
538-
wantSessionID: false,
539538
},
540539
{
541540
method: "POST",
542541
headers: http.Header{"Accept": {"text/event-stream"}},
543542
messages: []jsonrpc.Message{req(3, "tools/call", &CallToolParams{Name: "tool"})},
544543
wantStatusCode: http.StatusBadRequest, // missing application/json
545-
wantSessionID: false,
546544
},
547545
{
548546
method: "POST",
549547
headers: http.Header{"Accept": {"text/plain", "*/*"}},
550548
messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{Name: "tool"})},
551549
wantStatusCode: http.StatusOK,
552550
wantMessages: []jsonrpc.Message{resp(4, &CallToolResult{}, nil)},
553-
wantSessionID: true,
554551
},
555552
{
556553
method: "POST",
557554
headers: http.Header{"Accept": {"text/*, application/*"}},
558555
messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{Name: "tool"})},
559556
wantStatusCode: http.StatusOK,
560557
wantMessages: []jsonrpc.Message{resp(4, &CallToolResult{}, nil)},
561-
wantSessionID: true,
562558
},
563559
},
564560
},
@@ -598,7 +594,6 @@ func TestStreamableServerTransport(t *testing.T) {
598594
req(0, "notifications/progress", &ProgressNotificationParams{}),
599595
resp(2, &CallToolResult{}, nil),
600596
},
601-
wantSessionID: true,
602597
},
603598
},
604599
},
@@ -620,7 +615,6 @@ func TestStreamableServerTransport(t *testing.T) {
620615
resp(1, &ListRootsResult{}, nil),
621616
},
622617
wantStatusCode: http.StatusAccepted,
623-
wantSessionID: false,
624618
},
625619
{
626620
method: "POST",
@@ -632,7 +626,6 @@ func TestStreamableServerTransport(t *testing.T) {
632626
req(1, "roots/list", &ListRootsParams{}),
633627
resp(2, &CallToolResult{}, nil),
634628
},
635-
wantSessionID: true,
636629
},
637630
},
638631
},
@@ -663,7 +656,6 @@ func TestStreamableServerTransport(t *testing.T) {
663656
resp(1, &ListRootsResult{}, nil),
664657
},
665658
wantStatusCode: http.StatusAccepted,
666-
wantSessionID: false,
667659
},
668660
{
669661
method: "GET",
@@ -674,7 +666,6 @@ func TestStreamableServerTransport(t *testing.T) {
674666
req(0, "notifications/progress", &ProgressNotificationParams{}),
675667
req(1, "roots/list", &ListRootsParams{}),
676668
},
677-
wantSessionID: true,
678669
},
679670
{
680671
method: "POST",
@@ -685,7 +676,6 @@ func TestStreamableServerTransport(t *testing.T) {
685676
wantMessages: []jsonrpc.Message{
686677
resp(2, &CallToolResult{}, nil),
687678
},
688-
wantSessionID: true,
689679
},
690680
{
691681
method: "DELETE",
@@ -724,7 +714,6 @@ func TestStreamableServerTransport(t *testing.T) {
724714
wantMessages: []jsonrpc.Message{resp(2, nil, &jsonrpc2.WireError{
725715
Message: `method "tools/call" is invalid during session initialization`,
726716
})},
727-
wantSessionID: true, // TODO: this is probably wrong; we don't have a valid session
728717
},
729718
},
730719
},
@@ -951,7 +940,7 @@ func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string,
951940
return "", 0, nil, fmt.Errorf("creating request: %w", err)
952941
}
953942
if sessionID != "" {
954-
req.Header.Set("Mcp-Session-Id", sessionID)
943+
req.Header.Set(sessionIDHeader, sessionID)
955944
}
956945
req.Header.Set("Content-Type", "application/json")
957946
req.Header.Set("Accept", "application/json, text/event-stream")
@@ -963,7 +952,7 @@ func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string,
963952
}
964953
defer resp.Body.Close()
965954

966-
newSessionID := resp.Header.Get("Mcp-Session-Id")
955+
newSessionID := resp.Header.Get(sessionIDHeader)
967956

968957
contentType := resp.Header.Get("Content-Type")
969958
var respBody []byte
@@ -1073,7 +1062,7 @@ func TestStreamableClientTransport(t *testing.T) {
10731062
if err != nil {
10741063
t.Errorf("encoding failed: %v", err)
10751064
}
1076-
w.Header().Set("Mcp-Session-Id", "123")
1065+
w.Header().Set(sessionIDHeader, "123")
10771066
w.Write(data)
10781067
} else {
10791068
if v := r.Header.Get(protocolVersionHeader); v != latestProtocolVersion {
@@ -1150,6 +1139,15 @@ func TestEventID(t *testing.T) {
11501139
}
11511140

11521141
func TestStreamableStateless(t *testing.T) {
1142+
initReq := req(1, methodInitialize, &InitializeParams{})
1143+
initResp := resp(1, &InitializeResult{
1144+
Capabilities: &ServerCapabilities{
1145+
Logging: &LoggingCapabilities{},
1146+
Tools: &ToolCapabilities{ListChanged: true},
1147+
},
1148+
ProtocolVersion: latestProtocolVersion,
1149+
ServerInfo: &Implementation{Name: "test", Version: "v1.0.0"},
1150+
}, nil)
11531151
// This version of sayHi expects
11541152
// that request from our client).
11551153
sayHi := func(ctx context.Context, req *CallToolRequest, args hiParams) (*CallToolResult, any, error) {
@@ -1163,17 +1161,22 @@ func TestStreamableStateless(t *testing.T) {
11631161
AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi)
11641162

11651163
requests := []streamableRequest{
1164+
{
1165+
method: "POST",
1166+
messages: []jsonrpc.Message{initReq},
1167+
wantStatusCode: http.StatusOK,
1168+
wantMessages: []jsonrpc.Message{initResp},
1169+
wantSessionID: false, // sessionless
1170+
},
11661171
{
11671172
method: "POST",
11681173
wantStatusCode: http.StatusOK,
11691174
messages: []jsonrpc.Message{req(1, "tools/list", struct{}{})},
11701175
wantBodyContaining: "greet",
1171-
wantSessionID: false,
11721176
},
11731177
{
11741178
method: "GET",
11751179
wantStatusCode: http.StatusMethodNotAllowed,
1176-
wantSessionID: false,
11771180
},
11781181
{
11791182
method: "POST",
@@ -1187,7 +1190,6 @@ func TestStreamableStateless(t *testing.T) {
11871190
StructuredContent: json.RawMessage("null"),
11881191
}, nil),
11891192
},
1190-
wantSessionID: false,
11911193
},
11921194
{
11931195
method: "POST",
@@ -1201,7 +1203,6 @@ func TestStreamableStateless(t *testing.T) {
12011203
StructuredContent: json.RawMessage("null"),
12021204
}, nil),
12031205
},
1204-
wantSessionID: false,
12051206
},
12061207
}
12071208

@@ -1237,13 +1238,7 @@ func TestStreamableStateless(t *testing.T) {
12371238
//
12381239
// This can be used by tools to look up application state preserved across
12391240
// subsequent requests.
1240-
for i, req := range requests {
1241-
// Now, we want a session for all (valid) requests.
1242-
if req.wantStatusCode != http.StatusMethodNotAllowed {
1243-
req.wantSessionID = true
1244-
}
1245-
requests[i] = req
1246-
}
1241+
requests[0].wantSessionID = true // now expect a session ID for initialize
12471242
statelessHandler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{
12481243
Stateless: true,
12491244
})

0 commit comments

Comments
 (0)