diff --git a/mcp/server.go b/mcp/server.go index 29be8ff1..4a7bc89a 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -1145,7 +1145,13 @@ func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, return handleReceive(ctx, ss, req) } -func (ss *ServerSession) InitializeParams() *InitializeParams { return ss.state.InitializeParams } +// InitializeParams returns the InitializeParams provided during the client's +// initial connection. +func (ss *ServerSession) InitializeParams() *InitializeParams { + ss.mu.Lock() + defer ss.mu.Unlock() + return ss.state.InitializeParams +} func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParams) (*InitializeResult, error) { if params == nil { diff --git a/mcp/streamable.go b/mcp/streamable.go index d8ce45e7..4e3c1156 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -324,10 +324,13 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque logger: h.opts.Logger, } + // Sessions without a session ID are also stateless: there's no way to + // address them. + stateless := h.opts.Stateless || sessionID == "" // To support stateless mode, we initialize the session with a default // state, so that it doesn't reject subsequent requests. var connectOpts *ServerSessionOptions - if h.opts.Stateless { + if stateless { // Peek at the body to see if it is initialize or initialized. // We want those to be handled as usual. var hasInitialize, hasInitialized bool @@ -405,7 +408,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque transport: transport, } - if h.opts.Stateless { + if stateless { // Stateless mode: close the session when the request exits. defer session.Close() // close the fake session after handling the request } else { @@ -424,6 +427,13 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque h.mu.Lock() h.sessions[transport.SessionID] = sessInfo h.mu.Unlock() + defer func() { + // If initialization failed, clean up the session (#578). + if session.InitializeParams() == nil { + // Initialization failed. + session.Close() + } + }() } } diff --git a/mcp/streamable_bench_test.go b/mcp/streamable_bench_test.go index bf82ee51..6157e980 100644 --- a/mcp/streamable_bench_test.go +++ b/mcp/streamable_bench_test.go @@ -6,9 +6,15 @@ package mcp_test import ( "context" + "flag" + "log" "net/http" "net/http/httptest" + "os" "reflect" + "runtime" + "runtime/pprof" + "strings" "testing" "github.com/google/jsonschema-go/jsonschema" @@ -65,3 +71,62 @@ func BenchmarkStreamableServing(b *testing.B) { } } } + +var streamableHeap = flag.String("streamable_heap", "", "if set, write streamable heap profiles with this prefix") + +func BenchmarkStreamableServing_BadSessions(b *testing.B) { + server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + + handler := mcp.NewStreamableHTTPHandler(func(r *http.Request) *mcp.Server { + return server + }, &mcp.StreamableHTTPOptions{JSONResponse: true}) + + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + if *streamableHeap != "" { + writeHeap := func(file string) { + // GC a couple times to ensure accurate heap. + runtime.GC() + runtime.GC() + f, err := os.Create(file) + if err != nil { + log.Fatal("could not create memory profile: ", err) + } + defer func() { + if err := f.Close(); err != nil { + b.Errorf("writing heap file %q: %v", file, err) + } + }() + if err := pprof.Lookup("heap").WriteTo(f, 0); err != nil { + b.Errorf("could not write heap profile: %v", err) + } + } + writeHeap(*streamableHeap + ".before") + defer writeHeap(*streamableHeap + ".after") + } + + b.ResetTimer() + for range b.N { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, httpServer.URL, strings.NewReader("{}")) + if err != nil { + b.Fatal(err) + } + req.Header.Add("Accept", "application/json") + req.Header.Add("Accept", "text/event-stream") + resp, err := http.DefaultClient.Do(req) + if err != nil { + b.Fatal(err) + } + if got, want := resp.StatusCode, http.StatusBadRequest; got != want { + b.Fatalf("POST got status %d, want %d", got, want) + } + if got := resp.Header.Get("Mcp-Session-Id"); got != "" { + b.Fatalf("POST got unexpected session ID") + } + resp.Body.Close() + } +} diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 0f38a0f4..fa46a5b5 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -697,10 +697,11 @@ func TestStreamableServerTransport(t *testing.T) { } tests := []struct { - name string - replay bool // if set, use a MemoryEventStore to enable stream replay - tool func(*testing.T, context.Context, *ServerSession) - requests []streamableRequest // http requests + name string + replay bool // if set, use a MemoryEventStore to enable stream replay + tool func(*testing.T, context.Context, *ServerSession) + requests []streamableRequest // http requests + wantSessions int // number of sessions expected after the test }{ { name: "basic", @@ -714,6 +715,19 @@ func TestStreamableServerTransport(t *testing.T) { wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{Content: []Content{}}, nil)}, }, }, + wantSessions: 1, + }, + { + name: "uninitialized", + requests: []streamableRequest{ + { + method: "POST", + messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})}, + wantStatusCode: http.StatusOK, + wantBodyContaining: "invalid during session initialization", + }, + }, + wantSessions: 0, }, { name: "accept headers", @@ -748,6 +762,7 @@ func TestStreamableServerTransport(t *testing.T) { wantMessages: []jsonrpc.Message{resp(4, &CallToolResult{Content: []Content{}}, nil)}, }, }, + wantSessions: 1, }, { name: "protocol version headers", @@ -763,6 +778,7 @@ func TestStreamableServerTransport(t *testing.T) { wantSessionID: false, // could be true, but shouldn't matter }, }, + wantSessions: 1, }, { name: "batch rejected on 2025-06-18", @@ -782,6 +798,7 @@ func TestStreamableServerTransport(t *testing.T) { wantBodyContaining: "batch", }, }, + wantSessions: 1, }, { name: "batch accepted on 2025-03-26", @@ -804,6 +821,7 @@ func TestStreamableServerTransport(t *testing.T) { }, }, }, + wantSessions: 1, }, { name: "tool notification", @@ -828,6 +846,7 @@ func TestStreamableServerTransport(t *testing.T) { }, }, }, + wantSessions: 1, }, { name: "tool upcall", @@ -860,6 +879,7 @@ func TestStreamableServerTransport(t *testing.T) { }, }, }, + wantSessions: 1, }, { name: "background", @@ -922,6 +942,7 @@ func TestStreamableServerTransport(t *testing.T) { headers: map[string][]string{"Accept": nil}, }, }, + wantSessions: 0, // session deleted }, { name: "errors", @@ -953,6 +974,7 @@ func TestStreamableServerTransport(t *testing.T) { })}, }, }, + wantSessions: 0, }, } @@ -979,6 +1001,9 @@ func TestStreamableServerTransport(t *testing.T) { defer handler.closeAll() testStreamableHandler(t, handler, test.requests) + if got := len(slices.Collect(server.Sessions())); got != test.wantSessions { + t.Errorf("after test, got %d sessions, want %d", got, test.wantSessions) + } }) } }