diff --git a/mcp/streamable.go b/mcp/streamable.go index 25efe31a..1eef9a74 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -424,7 +424,7 @@ func (t *StreamableServerTransport) Connect(ctx context.Context) (Connection, er // It is always text/event-stream, since it must carry arbitrarily many // messages. var err error - t.connection.streams[""], err = t.connection.newStream(ctx, "", false) + t.connection.streams[""], err = t.connection.newStream(ctx, "", false, false) if err != nil { return nil, err } @@ -485,6 +485,10 @@ type stream struct { // an empty string is used for messages that don't correlate with an incoming request. id StreamID + // If isInitialize is set, the stream is in response to an initialize request, + // and therefore should include the session ID header. + isInitialize bool + // jsonResponse records whether this stream should respond with application/json // instead of text/event-stream. // @@ -513,12 +517,13 @@ type stream struct { requests map[jsonrpc.ID]struct{} } -func (c *streamableServerConn) newStream(ctx context.Context, id StreamID, jsonResponse bool) (*stream, error) { +func (c *streamableServerConn) newStream(ctx context.Context, id StreamID, isInitialize, jsonResponse bool) (*stream, error) { if err := c.eventStore.Open(ctx, c.sessionID, id); err != nil { return nil, err } return &stream{ id: id, + isInitialize: isInitialize, jsonResponse: jsonResponse, requests: make(map[jsonrpc.ID]struct{}), }, nil @@ -647,6 +652,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques } requests := make(map[jsonrpc.ID]struct{}) tokenInfo := auth.TokenInfoFromContext(req.Context()) + isInitialize := false for _, msg := range incoming { if jreq, ok := msg.(*jsonrpc.Request); ok { // Preemptively check that this is a valid request, so that we can fail @@ -656,6 +662,9 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques http.Error(w, err.Error(), http.StatusBadRequest) return } + if jreq.Method == methodInitialize { + isInitialize = true + } jreq.Extra = &RequestExtra{ TokenInfo: tokenInfo, Header: req.Header, @@ -672,7 +681,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques // notifications or server->client requests made in the course of handling. // Update accounting for this incoming payload. if len(requests) > 0 { - stream, err = c.newStream(req.Context(), StreamID(randText()), c.jsonResponse) + stream, err = c.newStream(req.Context(), StreamID(randText()), isInitialize, c.jsonResponse) if err != nil { http.Error(w, fmt.Sprintf("storing stream: %v", err), http.StatusInternalServerError) return @@ -708,7 +717,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques func (c *streamableServerConn) respondJSON(stream *stream, w http.ResponseWriter, req *http.Request) { w.Header().Set("Cache-Control", "no-cache, no-transform") w.Header().Set("Content-Type", "application/json") - if c.sessionID != "" { + if c.sessionID != "" && stream.isInitialize { w.Header().Set(sessionIDHeader, c.sessionID) } @@ -747,7 +756,7 @@ func (c *streamableServerConn) respondSSE(stream *stream, w http.ResponseWriter, w.Header().Set("Cache-Control", "no-cache, no-transform") w.Header().Set("Content-Type", "text/event-stream") // Accept checked in [StreamableHTTPHandler] w.Header().Set("Connection", "keep-alive") - if c.sessionID != "" { + if c.sessionID != "" && stream.isInitialize { w.Header().Set(sessionIDHeader, c.sessionID) } if persistent { diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 2963a04d..3f897ba0 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -133,7 +133,7 @@ func TestStreamableTransports(t *testing.T) { defer session.Close() sid := session.ID() if sid == "" { - t.Error("empty session ID") + t.Fatalf("empty session ID") } if g, w := session.mcpConn.(*streamableClientConn).initializedResult.ProtocolVersion, latestProtocolVersion; g != w { t.Fatalf("got protocol version %q, want %q", g, w) @@ -475,6 +475,8 @@ func resp(id int64, result any, err error) *jsonrpc.Response { } } +var () + func TestStreamableServerTransport(t *testing.T) { // This test checks detailed behavior of the streamable server transport, by // faking the behavior of a streamable client using a sequence of HTTP @@ -502,7 +504,6 @@ func TestStreamableServerTransport(t *testing.T) { method: "POST", messages: []jsonrpc.Message{initializedMsg}, wantStatusCode: http.StatusAccepted, - wantSessionID: false, // TODO: should this be true? } tests := []struct { @@ -520,7 +521,6 @@ func TestStreamableServerTransport(t *testing.T) { messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusOK, wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{}, nil)}, - wantSessionID: true, }, }, }, @@ -535,14 +535,12 @@ func TestStreamableServerTransport(t *testing.T) { headers: http.Header{"Accept": {"text/plain", "application/*"}}, messages: []jsonrpc.Message{req(3, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusBadRequest, // missing text/event-stream - wantSessionID: false, }, { method: "POST", headers: http.Header{"Accept": {"text/event-stream"}}, messages: []jsonrpc.Message{req(3, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusBadRequest, // missing application/json - wantSessionID: false, }, { method: "POST", @@ -550,7 +548,6 @@ func TestStreamableServerTransport(t *testing.T) { messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusOK, wantMessages: []jsonrpc.Message{resp(4, &CallToolResult{}, nil)}, - wantSessionID: true, }, { method: "POST", @@ -558,7 +555,6 @@ func TestStreamableServerTransport(t *testing.T) { messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusOK, wantMessages: []jsonrpc.Message{resp(4, &CallToolResult{}, nil)}, - wantSessionID: true, }, }, }, @@ -598,7 +594,6 @@ func TestStreamableServerTransport(t *testing.T) { req(0, "notifications/progress", &ProgressNotificationParams{}), resp(2, &CallToolResult{}, nil), }, - wantSessionID: true, }, }, }, @@ -620,7 +615,6 @@ func TestStreamableServerTransport(t *testing.T) { resp(1, &ListRootsResult{}, nil), }, wantStatusCode: http.StatusAccepted, - wantSessionID: false, }, { method: "POST", @@ -632,7 +626,6 @@ func TestStreamableServerTransport(t *testing.T) { req(1, "roots/list", &ListRootsParams{}), resp(2, &CallToolResult{}, nil), }, - wantSessionID: true, }, }, }, @@ -663,7 +656,6 @@ func TestStreamableServerTransport(t *testing.T) { resp(1, &ListRootsResult{}, nil), }, wantStatusCode: http.StatusAccepted, - wantSessionID: false, }, { method: "GET", @@ -674,7 +666,6 @@ func TestStreamableServerTransport(t *testing.T) { req(0, "notifications/progress", &ProgressNotificationParams{}), req(1, "roots/list", &ListRootsParams{}), }, - wantSessionID: true, }, { method: "POST", @@ -685,7 +676,6 @@ func TestStreamableServerTransport(t *testing.T) { wantMessages: []jsonrpc.Message{ resp(2, &CallToolResult{}, nil), }, - wantSessionID: true, }, { method: "DELETE", @@ -724,7 +714,6 @@ func TestStreamableServerTransport(t *testing.T) { wantMessages: []jsonrpc.Message{resp(2, nil, &jsonrpc2.WireError{ Message: `method "tools/call" is invalid during session initialization`, })}, - wantSessionID: true, // TODO: this is probably wrong; we don't have a valid session }, }, }, @@ -951,7 +940,7 @@ func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string, return "", 0, nil, fmt.Errorf("creating request: %w", err) } if sessionID != "" { - req.Header.Set("Mcp-Session-Id", sessionID) + req.Header.Set(sessionIDHeader, sessionID) } req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json, text/event-stream") @@ -963,7 +952,7 @@ func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string, } defer resp.Body.Close() - newSessionID := resp.Header.Get("Mcp-Session-Id") + newSessionID := resp.Header.Get(sessionIDHeader) contentType := resp.Header.Get("Content-Type") var respBody []byte @@ -1079,6 +1068,15 @@ func TestEventID(t *testing.T) { } func TestStreamableStateless(t *testing.T) { + initReq := req(1, methodInitialize, &InitializeParams{}) + initResp := resp(1, &InitializeResult{ + Capabilities: &ServerCapabilities{ + Logging: &LoggingCapabilities{}, + Tools: &ToolCapabilities{ListChanged: true}, + }, + ProtocolVersion: latestProtocolVersion, + ServerInfo: &Implementation{Name: "test", Version: "v1.0.0"}, + }, nil) // This version of sayHi expects // that request from our client). sayHi := func(ctx context.Context, req *CallToolRequest, args hiParams) (*CallToolResult, any, error) { @@ -1092,17 +1090,22 @@ func TestStreamableStateless(t *testing.T) { AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi) requests := []streamableRequest{ + { + method: "POST", + messages: []jsonrpc.Message{initReq}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{initResp}, + wantSessionID: false, // sessionless + }, { method: "POST", wantStatusCode: http.StatusOK, messages: []jsonrpc.Message{req(1, "tools/list", struct{}{})}, wantBodyContaining: "greet", - wantSessionID: false, }, { method: "GET", wantStatusCode: http.StatusMethodNotAllowed, - wantSessionID: false, }, { method: "POST", @@ -1116,7 +1119,6 @@ func TestStreamableStateless(t *testing.T) { StructuredContent: json.RawMessage("null"), }, nil), }, - wantSessionID: false, }, { method: "POST", @@ -1130,7 +1132,6 @@ func TestStreamableStateless(t *testing.T) { StructuredContent: json.RawMessage("null"), }, nil), }, - wantSessionID: false, }, } @@ -1166,13 +1167,7 @@ func TestStreamableStateless(t *testing.T) { // // This can be used by tools to look up application state preserved across // subsequent requests. - for i, req := range requests { - // Now, we want a session for all (valid) requests. - if req.wantStatusCode != http.StatusMethodNotAllowed { - req.wantSessionID = true - } - requests[i] = req - } + requests[0].wantSessionID = true // now expect a session ID for initialize statelessHandler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{ Stateless: true, })