From eb05386484a8a56431fbd2130f55ca025be9f0ad Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Thu, 14 Aug 2025 23:49:25 +0000 Subject: [PATCH 1/3] mcp: factor out reusable functions from TestStreamableServerTransport The next CL will test stateless and distributable server transport configurations, using the HTTP testing strategy of TestStreamableServerTransport. --- mcp/server.go | 2 + mcp/streamable_test.go | 367 ++++++++++++++++++++--------------------- 2 files changed, 182 insertions(+), 187 deletions(-) diff --git a/mcp/server.go b/mcp/server.go index c3fbd9e3..88021336 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -355,6 +355,8 @@ func (s *Server) callTool(ctx context.Context, req *ServerRequest[*CallToolParam if !ok { return nil, fmt.Errorf("%s: unknown tool %q", jsonrpc2.ErrInvalidParams, req.Params.Name) } + // TODO: if handler returns nil content, it will serialize as null. + // Add a test and fix. return st.handler(ctx, req) } diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 52c47998..7c77938e 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -382,30 +382,31 @@ func readNotifications(t *testing.T, ctx context.Context, notifications chan str } } +// JSON-RPC message constructors. +func req(id int64, method string, params any) *jsonrpc.Request { + r := &jsonrpc.Request{ + Method: method, + Params: mustMarshal(params), + } + if id > 0 { + r.ID = jsonrpc2.Int64ID(id) + } + return r +} + +func resp(id int64, result any, err error) *jsonrpc.Response { + return &jsonrpc.Response{ + ID: jsonrpc2.Int64ID(id), + Result: mustMarshal(result), + Error: err, + } +} + func TestStreamableServerTransport(t *testing.T) { // This test checks detailed behavior of the streamable server transport, by // faking the behavior of a streamable client using a sequence of HTTP // requests. - // JSON-RPC message constructors. - req := func(id int64, method string, params any) *jsonrpc.Request { - r := &jsonrpc.Request{ - Method: method, - Params: mustMarshal(t, params), - } - if id > 0 { - r.ID = jsonrpc2.Int64ID(id) - } - return r - } - resp := func(id int64, result any, err error) *jsonrpc.Response { - return &jsonrpc.Response{ - ID: jsonrpc2.Int64ID(id), - Result: mustMarshal(t, result), - Error: err, - } - } - // Predefined steps, to avoid repetition below. initReq := req(1, methodInitialize, &InitializeParams{}) initResp := resp(1, &InitializeResult{ @@ -422,21 +423,23 @@ func TestStreamableServerTransport(t *testing.T) { messages: []jsonrpc.Message{initReq}, wantStatusCode: http.StatusOK, wantMessages: []jsonrpc.Message{initResp}, + wantSessionID: true, } initialized := streamableRequest{ method: "POST", messages: []jsonrpc.Message{initializedMsg}, wantStatusCode: http.StatusAccepted, + wantSessionID: false, // TODO: should this be true? } tests := []struct { - name string - tool func(*testing.T, context.Context, *ServerSession) - steps []streamableRequest + name string + tool func(*testing.T, context.Context, *ServerSession) + requests []streamableRequest // http requests }{ { name: "basic", - steps: []streamableRequest{ + requests: []streamableRequest{ initialize, initialized, { @@ -444,12 +447,13 @@ func TestStreamableServerTransport(t *testing.T) { messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusOK, wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{}, nil)}, + wantSessionID: true, }, }, }, { name: "accept headers", - steps: []streamableRequest{ + requests: []streamableRequest{ initialize, initialized, // Test various accept headers. @@ -458,12 +462,14 @@ func TestStreamableServerTransport(t *testing.T) { accept: []string{"text/plain", "application/*"}, messages: []jsonrpc.Message{req(3, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusBadRequest, // missing text/event-stream + wantSessionID: false, }, { method: "POST", accept: []string{"text/event-stream"}, messages: []jsonrpc.Message{req(3, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusBadRequest, // missing application/json + wantSessionID: false, }, { method: "POST", @@ -471,6 +477,7 @@ func TestStreamableServerTransport(t *testing.T) { messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusOK, wantMessages: []jsonrpc.Message{resp(4, &CallToolResult{}, nil)}, + wantSessionID: true, }, { method: "POST", @@ -478,6 +485,7 @@ func TestStreamableServerTransport(t *testing.T) { messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusOK, wantMessages: []jsonrpc.Message{resp(4, &CallToolResult{}, nil)}, + wantSessionID: true, }, }, }, @@ -489,7 +497,7 @@ func TestStreamableServerTransport(t *testing.T) { t.Errorf("Notify failed: %v", err) } }, - steps: []streamableRequest{ + requests: []streamableRequest{ initialize, initialized, { @@ -502,6 +510,7 @@ func TestStreamableServerTransport(t *testing.T) { req(0, "notifications/progress", &ProgressNotificationParams{}), resp(2, &CallToolResult{}, nil), }, + wantSessionID: true, }, }, }, @@ -513,7 +522,7 @@ func TestStreamableServerTransport(t *testing.T) { t.Errorf("Call failed: %v", err) } }, - steps: []streamableRequest{ + requests: []streamableRequest{ initialize, initialized, { @@ -523,6 +532,7 @@ func TestStreamableServerTransport(t *testing.T) { resp(1, &ListRootsResult{}, nil), }, wantStatusCode: http.StatusAccepted, + wantSessionID: false, }, { method: "POST", @@ -534,6 +544,7 @@ func TestStreamableServerTransport(t *testing.T) { req(1, "roots/list", &ListRootsParams{}), resp(2, &CallToolResult{}, nil), }, + wantSessionID: true, }, }, }, @@ -554,7 +565,7 @@ func TestStreamableServerTransport(t *testing.T) { t.Errorf("Notify failed: %v", err) } }, - steps: []streamableRequest{ + requests: []streamableRequest{ initialize, initialized, { @@ -564,6 +575,7 @@ func TestStreamableServerTransport(t *testing.T) { resp(1, &ListRootsResult{}, nil), }, wantStatusCode: http.StatusAccepted, + wantSessionID: false, }, { method: "GET", @@ -574,6 +586,7 @@ func TestStreamableServerTransport(t *testing.T) { req(0, "notifications/progress", &ProgressNotificationParams{}), req(1, "roots/list", &ListRootsParams{}), }, + wantSessionID: true, }, { method: "POST", @@ -584,12 +597,13 @@ func TestStreamableServerTransport(t *testing.T) { wantMessages: []jsonrpc.Message{ resp(2, &CallToolResult{}, nil), }, + wantSessionID: true, }, }, }, { name: "errors", - steps: []streamableRequest{ + requests: []streamableRequest{ { method: "PUT", wantStatusCode: http.StatusMethodNotAllowed, @@ -615,6 +629,7 @@ func TestStreamableServerTransport(t *testing.T) { wantMessages: []jsonrpc.Message{resp(2, nil, &jsonrpc2.WireError{ Message: `method "tools/call" is invalid during session initialization`, })}, + wantSessionID: true, // TODO: this is probably wrong; we don't have a valid session }, }, }, @@ -636,118 +651,127 @@ func TestStreamableServerTransport(t *testing.T) { handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) defer handler.closeAll() - httpServer := httptest.NewServer(handler) - defer httpServer.Close() + testStreamableHandler(t, handler, test.requests) + }) + } +} - // blocks records request blocks by jsonrpc. ID. - // - // When an OnRequest step is encountered, it waits on the corresponding - // block. When a request with that ID is received, the block is closed. - var mu sync.Mutex - blocks := make(map[int64]chan struct{}) - for _, step := range test.steps { - if step.onRequest > 0 { - blocks[step.onRequest] = make(chan struct{}) - } - } +func testStreamableHandler(t *testing.T, handler http.Handler, requests []streamableRequest) { + httpServer := httptest.NewServer(handler) + defer httpServer.Close() - // signal when all synchronous requests have executed, so we can fail - // async requests that are blocked. - syncRequestsDone := make(chan struct{}) + // blocks records request blocks by jsonrpc. ID. + // + // When an OnRequest step is encountered, it waits on the corresponding + // block. When a request with that ID is received, the block is closed. + var mu sync.Mutex + blocks := make(map[int64]chan struct{}) + for _, req := range requests { + if req.onRequest > 0 { + blocks[req.onRequest] = make(chan struct{}) + } + } - // To avoid complicated accounting for session ID, just set the first - // non-empty session ID from a response. - var sessionID atomic.Value - sessionID.Store("") + // signal when all synchronous requests have executed, so we can fail + // async requests that are blocked. + syncRequestsDone := make(chan struct{}) + + // To avoid complicated accounting for session ID, just set the first + // non-empty session ID from a response. + var sessionID atomic.Value + sessionID.Store("") + + // doStep executes a single step. + doStep := func(t *testing.T, i int, request streamableRequest) { + if request.onRequest > 0 { + // Block the step until we've received the server->client request. + mu.Lock() + block := blocks[request.onRequest] + mu.Unlock() + select { + case <-block: + case <-syncRequestsDone: + t.Errorf("after all sync requests are complete, request still blocked on %d", request.onRequest) + return + } + } - // doStep executes a single step. - doStep := func(t *testing.T, step streamableRequest) { - if step.onRequest > 0 { - // Block the step until we've received the server->client request. + // Collect messages received during this request, unblock other steps + // when requests are received. + var got []jsonrpc.Message + out := make(chan jsonrpc.Message) + // Cancel the step if we encounter a request that isn't going to be + // handled. + ctx, cancel := context.WithCancel(context.Background()) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + + for m := range out { + if req, ok := m.(*jsonrpc.Request); ok && req.IsCall() { + // Encountered a server->client request. We should have a + // response queued. Otherwise, we may deadlock. mu.Lock() - block := blocks[step.onRequest] - mu.Unlock() - select { - case <-block: - case <-syncRequestsDone: - t.Errorf("after all sync requests are complete, request still blocked on %d", step.onRequest) - return + if block, ok := blocks[req.ID.Raw().(int64)]; ok { + close(block) + } else { + t.Errorf("no queued response for %v", req.ID) + cancel() } + mu.Unlock() } + got = append(got, m) + if request.closeAfter > 0 && len(got) == request.closeAfter { + cancel() + } + } + }() - // Collect messages received during this request, unblock other steps - // when requests are received. - var got []jsonrpc.Message - out := make(chan jsonrpc.Message) - // Cancel the step if we encounter a request that isn't going to be - // handled. - ctx, cancel := context.WithCancel(context.Background()) - - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - - for m := range out { - if req, ok := m.(*jsonrpc.Request); ok && req.IsCall() { - // Encountered a server->client request. We should have a - // response queued. Otherwise, we may deadlock. - mu.Lock() - if block, ok := blocks[req.ID.Raw().(int64)]; ok { - close(block) - } else { - t.Errorf("no queued response for %v", req.ID) - cancel() - } - mu.Unlock() - } - got = append(got, m) - if step.closeAfter > 0 && len(got) == step.closeAfter { - cancel() - } - } - }() - - gotSessionID, gotStatusCode, err := step.do(ctx, httpServer.URL, sessionID.Load().(string), out) + gotSessionID, gotStatusCode, err := request.do(ctx, httpServer.URL, sessionID.Load().(string), out) - // Don't fail on cancelled requests: error (if any) is handled - // elsewhere. - if err != nil && ctx.Err() == nil { - t.Fatal(err) - } + // Don't fail on cancelled requests: error (if any) is handled + // elsewhere. + if err != nil && ctx.Err() == nil { + t.Fatal(err) + } - if gotStatusCode != step.wantStatusCode { - t.Errorf("got status %d, want %d", gotStatusCode, step.wantStatusCode) - } - wg.Wait() + if gotStatusCode != request.wantStatusCode { + t.Errorf("request #%d: got status %d, want %d", i, gotStatusCode, request.wantStatusCode) + } + if got := gotSessionID != ""; got != request.wantSessionID { + t.Errorf("request #%d: got session id: %t, want %t", i, got, request.wantSessionID) + } + wg.Wait() - transform := cmpopts.AcyclicTransformer("jsonrpcid", func(id jsonrpc.ID) any { return id.Raw() }) - if diff := cmp.Diff(step.wantMessages, got, transform); diff != "" { - t.Errorf("received unexpected messages (-want +got):\n%s", diff) - } - sessionID.CompareAndSwap("", gotSessionID) + if !request.ignoreResponse { + transform := cmpopts.AcyclicTransformer("jsonrpcid", func(id jsonrpc.ID) any { return id.Raw() }) + if diff := cmp.Diff(request.wantMessages, got, transform); diff != "" { + t.Errorf("received unexpected messages (-want +got):\n%s", diff) } + } + sessionID.CompareAndSwap("", gotSessionID) + } - var wg sync.WaitGroup - for _, step := range test.steps { - if step.async || step.onRequest > 0 { - wg.Add(1) - go func() { - defer wg.Done() - doStep(t, step) - }() - } else { - doStep(t, step) - } - } + var wg sync.WaitGroup + for i, request := range requests { + if request.async || request.onRequest > 0 { + wg.Add(1) + go func() { + defer wg.Done() + doStep(t, i, request) + }() + } else { + doStep(t, i, request) + } + } - // Fail any blocked responses if they weren't needed by a synchronous - // request. - close(syncRequestsDone) + // Fail any blocked responses if they weren't needed by a synchronous + // request. + close(syncRequestsDone) - wg.Wait() - }) - } + wg.Wait() } // A streamableRequest describes a single streamable HTTP request, consisting @@ -768,13 +792,15 @@ type streamableRequest struct { async bool // Request attributes - method string // HTTP request method + method string // HTTP request method (required) accept []string // if non-empty, the Accept header to use; otherwise the default header is used messages []jsonrpc.Message // messages to send closeAfter int // if nonzero, close after receiving this many messages wantStatusCode int // expected status code + ignoreResponse bool // if set, don't check the response messages wantMessages []jsonrpc.Message // expected messages to receive + wantSessionID bool // whether or not a session ID is expected in the response } // streamingRequest makes a request to the given streamable server with the @@ -840,7 +866,8 @@ func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string, newSessionID := resp.Header.Get("Mcp-Session-Id") - if strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") { + contentType := resp.Header.Get("Content-Type") + if strings.HasPrefix(contentType, "text/event-stream") { for evt, err := range scanEvents(resp.Body) { if err != nil { return newSessionID, resp.StatusCode, fmt.Errorf("reading events: %v", err) @@ -853,7 +880,7 @@ func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string, } out <- msg } - } else if strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") { + } else if strings.HasPrefix(contentType, "application/json") { data, err := io.ReadAll(resp.Body) if err != nil { return newSessionID, resp.StatusCode, fmt.Errorf("reading json body: %w", err) @@ -868,14 +895,13 @@ func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string, return newSessionID, resp.StatusCode, nil } -func mustMarshal(t *testing.T, v any) json.RawMessage { +func mustMarshal(v any) json.RawMessage { if v == nil { return nil } - t.Helper() data, err := json.Marshal(v) if err != nil { - t.Fatal(err) + panic(err) } return data } @@ -886,7 +912,7 @@ func TestStreamableClientTransportApplicationJSON(t *testing.T) { resp := func(id int64, result any, err error) *jsonrpc.Response { return &jsonrpc.Response{ ID: jsonrpc2.Int64ID(id), - Result: mustMarshal(t, result), + Result: mustMarshal(result), Error: err, } } @@ -970,13 +996,10 @@ func TestEventID(t *testing.T) { } func TestStreamableStateless(t *testing.T) { - // Test stateless mode behavior - ctx := context.Background() - // This version of sayHi doesn't make a ping request (we can't respond to // that request from our client). - sayHi := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[hiParams]]) (*CallToolResultFor[any], error) { - return &CallToolResultFor[any]{Content: []Content{&TextContent{Text: "hi " + req.Params.Arguments.Name}}}, nil + sayHi := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[hiParams]]) (*CallToolResult, error) { + return &CallToolResult{Content: []Content{&TextContent{Text: "hi " + req.Params.Arguments.Name}}}, nil } server := NewServer(testImpl, nil) AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi) @@ -985,57 +1008,27 @@ func TestStreamableStateless(t *testing.T) { handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{ GetSessionID: func() string { return "" }, }) - httpServer := httptest.NewServer(handler) - defer httpServer.Close() - checkRequest := func(body string) { - // Verify we can call tools/list directly without initialization in stateless mode - req, err := http.NewRequestWithContext(ctx, http.MethodPost, httpServer.URL, strings.NewReader(body)) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json, text/event-stream") - - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - // Verify that no session ID header is returned in stateless mode - if sessionID := resp.Header.Get(sessionIDHeader); sessionID != "" { - t.Errorf("%s = %s, want no session ID header", sessionIDHeader, sessionID) - } - if resp.StatusCode != http.StatusOK { - t.Errorf("Status code = %d; want successful response", resp.StatusCode) - } - - var events []Event - for event, err := range scanEvents(resp.Body) { - if err != nil { - t.Fatal(err) - } - events = append(events, event) - } - if len(events) != 1 { - t.Fatalf("got %d SSE events, want 1; events: %v", len(events), events) - } - msg, err := jsonrpc.DecodeMessage(events[0].Data) - if err != nil { - t.Fatal(err) - } - jsonResp, ok := msg.(*jsonrpc.Response) - if !ok { - t.Errorf("event is %T, want response", jsonResp) - } - if jsonResp.Error != nil { - t.Errorf("request failed: %v", jsonResp.Error) - } + requests := []streamableRequest{ + { + method: "POST", + wantStatusCode: http.StatusOK, + messages: []jsonrpc.Message{req(1, "tools/list", struct{}{})}, + ignoreResponse: true, + wantSessionID: false, + }, + { + method: "POST", + wantStatusCode: http.StatusOK, + messages: []jsonrpc.Message{ + req(2, "tools/call", &CallToolParams{Name: "greet", Arguments: hiParams{Name: "World"}}), + }, + wantMessages: []jsonrpc.Message{ + resp(2, &CallToolResult{Content: []Content{&TextContent{Text: "hi World"}}}, nil), + }, + wantSessionID: false, + }, } - checkRequest(`{"jsonrpc":"2.0","method":"tools/list","id":1,"params":{}}`) - - // Verify we can make another request without session ID - checkRequest(`{"jsonrpc":"2.0","method":"tools/call","id":2,"params":{"name":"greet","arguments":{"name":"World"}}}`) + testStreamableHandler(t, handler, requests) } From a4deb28c2a728620ef46213b535c36a43a051096 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Fri, 15 Aug 2025 03:13:53 +0000 Subject: [PATCH 2/3] mcp: improvements for 'stateless' streamable servers; 'distributed' mode Several improvements for the stateless streamable mode, plus support for a 'distributed' (or rather, distributable) version of the stateless server. - Add a 'Stateless' option to StreamableHTTPOptions and StreamableServerTransport, which controls stateless behavior. GetSessionID may still return a non-empty session ID. - Audit validation of stateless mode to allow requests with a session id. Propagate this session ID to the temporary connection. - Peek at requests to allow 'initialize' requests to go through to the session, so that version negotiation can occur (FIXME: add tests). Fixes #284 For #148 --- internal/jsonrpc2/conn.go | 18 +++++- internal/jsonrpc2/wire.go | 11 ++++ mcp/streamable.go | 125 ++++++++++++++++++++++++++++++-------- mcp/streamable_test.go | 82 ++++++++++++++++++++++--- mcp/transport.go | 4 +- 5 files changed, 203 insertions(+), 37 deletions(-) diff --git a/internal/jsonrpc2/conn.go b/internal/jsonrpc2/conn.go index 6bacfa7e..537be47a 100644 --- a/internal/jsonrpc2/conn.go +++ b/internal/jsonrpc2/conn.go @@ -739,7 +739,23 @@ func (c *Connection) processResult(from any, req *incomingRequest, result any, e // write is used by all things that write outgoing messages, including replies. // it makes sure that writes are atomic func (c *Connection) write(ctx context.Context, msg Message) error { - err := c.writer.Write(ctx, msg) + var err error + // Fail writes immediately if the connection is shutting down. + // + // TODO(rfindley): should we allow cancellation notifations through? It could + // be the case that writes can still succeed. + c.updateInFlight(func(s *inFlightState) { + err = s.shuttingDown(ErrServerClosing) + }) + if err == nil { + err = c.writer.Write(ctx, msg) + } + + // For rejected requests, we don't set the writeErr (which would break the + // connection). They can just be returned to the caller. + if errors.Is(err, ErrRejected) { + return err + } if err != nil && ctx.Err() == nil { // The call to Write failed, and since ctx.Err() is nil we can't attribute diff --git a/internal/jsonrpc2/wire.go b/internal/jsonrpc2/wire.go index b143dcd3..8be2872e 100644 --- a/internal/jsonrpc2/wire.go +++ b/internal/jsonrpc2/wire.go @@ -37,6 +37,17 @@ var ( ErrServerClosing = NewError(-32004, "server is closing") // ErrClientClosing is a dummy error returned for calls initiated while the client is closing. ErrClientClosing = NewError(-32003, "client is closing") + + // The following errors have special semantics for MCP transports + + // ErrRejected may be wrapped to return errors from calls to Writer.Write + // that signal that the request was rejected by the transport layer as + // invalid. + // + // Such failures do not indicate that the connection is broken, but rather + // should be returned to the caller to indicate that the specific request is + // invalid in the current context. + ErrRejected = NewError(-32004, "rejected by transport") ) const wireVersion = "2.0" diff --git a/mcp/streamable.go b/mcp/streamable.go index 526ee515..789c7a0b 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -38,18 +38,29 @@ type StreamableHTTPHandler struct { getServer func(*http.Request) *Server opts StreamableHTTPOptions - mu sync.Mutex + mu sync.Mutex + // TODO: we should store the ServerSession along with the transport, because + // we need to cancel keepalive requests when closing the transport. transports map[string]*StreamableServerTransport // keyed by IDs (from Mcp-Session-Id header) } // StreamableHTTPOptions configures the StreamableHTTPHandler. type StreamableHTTPOptions struct { // GetSessionID provides the next session ID to use for an incoming request. + // If nil, a default randomly generated ID will be used. // - // If GetSessionID returns an empty string, the session is 'stateless', - // meaning it is not persisted and no session validation is performed. + // As a special case, if GetSessionID returns the empty string, the + // Mcp-Session-Id header will not be set. GetSessionID func() string + // Stateless controls whether the session is 'stateless'. + // + // A stateless server does not validate the Mcp-Session-Id header, and uses a + // temporary session with default initialization parameters. Any + // server->client request is rejected immediately as there's no way for the + // client to respond. + Stateless bool + // TODO: support session retention (?) // jsonResponse is forwarded to StreamableServerTransport.jsonResponse. @@ -118,36 +129,39 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque return } + sessionID := req.Header.Get(sessionIDHeader) var transport *StreamableServerTransport - if id := req.Header.Get(sessionIDHeader); id != "" { + if sessionID != "" { h.mu.Lock() - transport = h.transports[id] + transport, _ = h.transports[sessionID] h.mu.Unlock() - if transport == nil { + if transport == nil && !h.opts.Stateless { + // In stateless mode we allow a missing transport. + // + // A synthetic transport will be created below for the transient session. http.Error(w, "session not found", http.StatusNotFound) return } } - // TODO(rfindley): simplify the locking so that each request has only one - // critical section. if req.Method == http.MethodDelete { - if transport == nil { - // => Mcp-Session-Id was not set; else we'd have returned NotFound above. + if sessionID == "" { http.Error(w, "DELETE requires an Mcp-Session-Id header", http.StatusBadRequest) return } - h.mu.Lock() - delete(h.transports, transport.SessionID) - h.mu.Unlock() - transport.connection.Close() + if transport != nil { // transport may be nil in stateless mode + h.mu.Lock() + delete(h.transports, transport.SessionID) + h.mu.Unlock() + transport.connection.Close() + } w.WriteHeader(http.StatusNoContent) return } switch req.Method { case http.MethodPost, http.MethodGet: - if req.Method == http.MethodGet && transport == nil { + if req.Method == http.MethodGet && sessionID == "" { http.Error(w, "GET requires an active session", http.StatusMethodNotAllowed) return } @@ -164,37 +178,83 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque http.Error(w, "no server available", http.StatusBadRequest) return } - sessionID := h.opts.GetSessionID() - s := &StreamableServerTransport{SessionID: sessionID, jsonResponse: h.opts.jsonResponse} + if sessionID == "" { + // In stateless mode, sessionID may be nonempty even if there's no + // existing transport. + sessionID = h.opts.GetSessionID() + } + transport = &StreamableServerTransport{ + SessionID: sessionID, + Stateless: h.opts.Stateless, + jsonResponse: h.opts.jsonResponse, + } // To support stateless mode, we initialize the session with a default // state, so that it doesn't reject subsequent requests. var connectOpts *ServerSessionOptions - if sessionID == "" { + if h.opts.Stateless { + // Peek at the body to see if it is initialize or initialized. + // We want those to be handled as usual. + var hasInitialize, hasInitialized bool + { + // TODO: verify that this allows protocol version negotiation for + // stateless servers. + body, err := io.ReadAll(req.Body) + if err != nil { + http.Error(w, "failed to read body", http.StatusBadRequest) + return + } + req.Body.Close() + + // Reset the body so that it can be read later. + req.Body = io.NopCloser(bytes.NewBuffer(body)) + + msgs, _, err := readBatch(body) + if err == nil { + for _, msg := range msgs { + if req, ok := msg.(*jsonrpc.Request); ok { + switch req.Method { + case methodInitialize: + hasInitialize = true + case notificationInitialized: + hasInitialized = true + } + } + } + } + } + + // If we don't have InitializeParams or InitializedParams in the request, + // set the initial state to a default value. + state := new(ServerSessionState) + if !hasInitialize { + state.InitializeParams = new(InitializeParams) + } + if !hasInitialized { + state.InitializedParams = new(InitializedParams) + } connectOpts = &ServerSessionOptions{ - State: &ServerSessionState{ - InitializeParams: new(InitializeParams), - InitializedParams: new(InitializedParams), - }, + State: state, } } + // Pass req.Context() here, to allow middleware to add context values. // The context is detached in the jsonrpc2 library when handling the // long-running stream. - ss, err := server.Connect(req.Context(), s, connectOpts) + ss, err := server.Connect(req.Context(), transport, connectOpts) if err != nil { http.Error(w, "failed connection", http.StatusInternalServerError) return } - if sessionID == "" { + if h.opts.Stateless { // Stateless mode: close the session when the request exits. defer ss.Close() // close the fake session after handling the request } else { + // Otherwise, save the transport so that it can be reused h.mu.Lock() - h.transports[s.SessionID] = s + h.transports[transport.SessionID] = transport h.mu.Unlock() } - transport = s } transport.ServeHTTP(w, req) @@ -225,6 +285,13 @@ type StreamableServerTransport struct { // generator to produce one, as with [crypto/rand.Text].) SessionID string + // Stateless controls whether the eventstore is 'Stateless'. Servers sessions + // connected to a stateless transport are disallowed from making outgoing + // requests. + // + // See also [StreamableHTTPOptions.Stateless]. + Stateless bool + // Storage for events, to enable stream resumption. // If nil, a [MemoryEventStore] with the default maximum size will be used. EventStore EventStore @@ -265,6 +332,7 @@ func (t *StreamableServerTransport) Connect(context.Context) (Connection, error) } t.connection = &streamableServerConn{ sessionID: t.SessionID, + stateless: t.Stateless, eventStore: t.EventStore, jsonResponse: t.jsonResponse, incoming: make(chan jsonrpc.Message, 10), @@ -285,6 +353,7 @@ func (t *StreamableServerTransport) Connect(context.Context) (Connection, error) type streamableServerConn struct { sessionID string + stateless bool jsonResponse bool eventStore EventStore @@ -755,6 +824,10 @@ func (c *streamableServerConn) Read(ctx context.Context) (jsonrpc.Message, error // Write implements the [Connection] interface. func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) error { + if req, ok := msg.(*jsonrpc.Request); ok && req.ID.IsValid() && (c.stateless || c.sessionID == "") { + // Requests aren't possible with stateless servers, or when there's no session ID. + return fmt.Errorf("%w: stateless servers cannot make requests", jsonrpc2.ErrRejected) + } // Find the incoming request that this write relates to, if any. var forRequest jsonrpc.ID isResponse := false diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 7c77938e..8334bc0d 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -748,7 +748,7 @@ func testStreamableHandler(t *testing.T, handler http.Handler, requests []stream if !request.ignoreResponse { transform := cmpopts.AcyclicTransformer("jsonrpcid", func(id jsonrpc.ID) any { return id.Raw() }) if diff := cmp.Diff(request.wantMessages, got, transform); diff != "" { - t.Errorf("received unexpected messages (-want +got):\n%s", diff) + t.Errorf("request #%d: received unexpected messages (-want +got):\n%s", i, diff) } } sessionID.CompareAndSwap("", gotSessionID) @@ -996,19 +996,18 @@ func TestEventID(t *testing.T) { } func TestStreamableStateless(t *testing.T) { - // This version of sayHi doesn't make a ping request (we can't respond to + // This version of sayHi expects // that request from our client). sayHi := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[hiParams]]) (*CallToolResult, error) { + if err := req.Session.Ping(ctx, nil); err == nil { + // ping should fail, but not break the connection + t.Errorf("ping succeeded unexpectedly") + } return &CallToolResult{Content: []Content{&TextContent{Text: "hi " + req.Params.Arguments.Name}}}, nil } server := NewServer(testImpl, nil) AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi) - // Test stateless mode. - handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{ - GetSessionID: func() string { return "" }, - }) - requests := []streamableRequest{ { method: "POST", @@ -1028,7 +1027,74 @@ func TestStreamableStateless(t *testing.T) { }, wantSessionID: false, }, + { + method: "POST", + wantStatusCode: http.StatusOK, + messages: []jsonrpc.Message{ + req(2, "tools/call", &CallToolParams{Name: "greet", Arguments: hiParams{Name: "foo"}}), + }, + wantMessages: []jsonrpc.Message{ + resp(2, &CallToolResult{Content: []Content{&TextContent{Text: "hi foo"}}}, nil), + }, + wantSessionID: false, + }, + } + + testClientCompatibility := func(t *testing.T, handler http.Handler) { + ctx := context.Background() + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + cs, err := NewClient(testImpl, nil).Connect(ctx, &StreamableClientTransport{Endpoint: httpServer.URL}, nil) + if err != nil { + t.Fatal(err) + } + res, err := cs.CallTool(ctx, &CallToolParams{Name: "greet", Arguments: hiParams{Name: "bar"}}) + if err != nil { + t.Fatal(err) + } + if got, want := textContent(t, res), "hi bar"; got != want { + t.Errorf("Result = %q, want %q", got, want) + } } - testStreamableHandler(t, handler, requests) + handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{ + GetSessionID: func() string { return "" }, + Stateless: true, + }) + + // Test the default stateless mode. + t.Run("stateless", func(t *testing.T) { + testStreamableHandler(t, handler, requests) + testClientCompatibility(t, handler) + }) + + // Test a "distributed" variant of stateless mode, where it has non-empty + // session IDs, but is otherwise stateless. + // + // This can be used by tools to look up application state preserved across + // subsequent requests. + for i, req := range requests { + // Now, we want a session for all requests. + req.wantSessionID = true + requests[i] = req + } + distributableHandler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{ + Stateless: true, + }) + t.Run("distributed", func(t *testing.T) { + testStreamableHandler(t, distributableHandler, requests) + testClientCompatibility(t, handler) + }) +} + +func textContent(t *testing.T, res *CallToolResult) string { + t.Helper() + if len(res.Content) != 1 { + t.Fatalf("len(Content) = %d, want 1", len(res.Content)) + } + text, ok := res.Content[0].(*TextContent) + if !ok { + t.Fatalf("Content[0] is %T, want *TextContent", res.Content[0]) + } + return text.Text } diff --git a/mcp/transport.go b/mcp/transport.go index 6d25de33..8018910b 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -40,8 +40,8 @@ type Transport interface { type Connection interface { // Read reads the next message to process off the connection. // - // Read need not be safe for concurrent use: Read is called in a - // concurrency-safe manner by the JSON-RPC library. + // Connections must allow Read to be called concurrently with Close. In + // particular, calling Close should unblock a Read waiting for input. Read(context.Context) (jsonrpc.Message, error) // Write writes a new message to the connection. From 59516ac28698de3ebe460e256a62d4f0fdf8c01b Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Mon, 18 Aug 2025 20:14:00 +0000 Subject: [PATCH 3/3] mcp: fix reconnect semantics for hanging GET A few problems with reconnection cropped up in the review of PR #307. We should allow for the hanging GET to fail with StatusMethodNotAllowed. This simply means that the server does not support sending notifications or requests over the GET, which is allowed in the spec. Also, we should fix the initial delay of the hanging GET request: it should start with 0 delay. Fix the math for this and subsequent attempts. Incidentally, this makes the tests take 3s on my machine, down from 9s. Also address some comments from #307. --- internal/jsonrpc2/conn.go | 4 +-- mcp/streamable.go | 68 +++++++++++++++++++++++---------------- 2 files changed, 43 insertions(+), 29 deletions(-) diff --git a/internal/jsonrpc2/conn.go b/internal/jsonrpc2/conn.go index 537be47a..49902b00 100644 --- a/internal/jsonrpc2/conn.go +++ b/internal/jsonrpc2/conn.go @@ -742,8 +742,8 @@ func (c *Connection) write(ctx context.Context, msg Message) error { var err error // Fail writes immediately if the connection is shutting down. // - // TODO(rfindley): should we allow cancellation notifations through? It could - // be the case that writes can still succeed. + // TODO(rfindley): should we allow cancellation notifications through? It + // could be the case that writes can still succeed. c.updateInFlight(func(s *inFlightState) { err = s.shuttingDown(ErrServerClosing) }) diff --git a/mcp/streamable.go b/mcp/streamable.go index 789c7a0b..572fe5de 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -49,6 +49,9 @@ type StreamableHTTPOptions struct { // GetSessionID provides the next session ID to use for an incoming request. // If nil, a default randomly generated ID will be used. // + // Session IDs should be globally unique across the scope of the server, + // which may span multiple processes in the case of distributed servers. + // // As a special case, if GetSessionID returns the empty string, the // Mcp-Session-Id header will not be set. GetSessionID func() string @@ -58,7 +61,9 @@ type StreamableHTTPOptions struct { // A stateless server does not validate the Mcp-Session-Id header, and uses a // temporary session with default initialization parameters. Any // server->client request is rejected immediately as there's no way for the - // client to respond. + // client to respond. Server->Client notifications may reach the client if + // they are made in the context of an incoming request, as described in the + // documentation for [StreamableServerTransport]. Stateless bool // TODO: support session retention (?) @@ -133,12 +138,13 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque var transport *StreamableServerTransport if sessionID != "" { h.mu.Lock() - transport, _ = h.transports[sessionID] + transport = h.transports[sessionID] h.mu.Unlock() if transport == nil && !h.opts.Stateless { - // In stateless mode we allow a missing transport. + // Unless we're in 'stateless' mode, which doesn't perform any Session-ID + // validation, we require that the session ID matches a known session. // - // A synthetic transport will be created below for the transient session. + // In stateless mode, a temporary transport is be created below. http.Error(w, "session not found", http.StatusNotFound) return } @@ -201,7 +207,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque // stateless servers. body, err := io.ReadAll(req.Body) if err != nil { - http.Error(w, "failed to read body", http.StatusBadRequest) + http.Error(w, "failed to read body", http.StatusInternalServerError) return } req.Body.Close() @@ -272,9 +278,22 @@ type StreamableServerTransportOptions struct { // A StreamableServerTransport implements the server side of the MCP streamable // transport. // -// Each StreamableServerTransport may be connected (via [Server.Connect]) at +// Each StreamableServerTransport must be connected (via [Server.Connect]) at // most once, since [StreamableServerTransport.ServeHTTP] serves messages to // the connected session. +// +// Reads from the streamable server connection receive messages from http POST +// requests from the client. Writes to the streamable server connection are +// sent either to the hanging POST response, or to the hanging GET, according +// to the following rules: +// - JSON-RPC responses to incoming requests are always routed to the +// appropriate HTTP response. +// - Requests or notifications made with a context.Context value derived from +// an incoming request handler, are routed to the HTTP response +// corresponding to that request, unless it has already terminated, in +// which case they are routed to the hanging GET. +// - Requests or notifications made with a detached context.Context value are +// routed to the hanging GET. type StreamableServerTransport struct { // SessionID is the ID of this session. // @@ -285,7 +304,7 @@ type StreamableServerTransport struct { // generator to produce one, as with [crypto/rand.Text].) SessionID string - // Stateless controls whether the eventstore is 'Stateless'. Servers sessions + // Stateless controls whether the eventstore is 'Stateless'. Server sessions // connected to a stateless transport are disallowed from making outgoing // requests. // @@ -1225,9 +1244,18 @@ func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent c.fail(err) return } - - // Reconnection was successful. Continue the loop with the new response. resp = newResp + if resp.StatusCode == http.StatusMethodNotAllowed && persistent { + // The server doesn't support the hanging GET. + resp.Body.Close() + return + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + resp.Body.Close() + c.fail(fmt.Errorf("failed to reconnect: %v", http.StatusText(resp.StatusCode))) + return + } + // Reconnection was successful. Continue the loop with the new response. } } @@ -1295,13 +1323,6 @@ func (c *streamableClientConn) reconnect(lastEventID string) (*http.Response, er finalErr = err // Store the error and try again. continue } - - if !isResumable(resp) { - // The server indicated we should not continue. - resp.Body.Close() - return nil, fmt.Errorf("reconnection failed with unresumable status: %s", resp.Status) - } - return resp, nil } } @@ -1312,16 +1333,6 @@ func (c *streamableClientConn) reconnect(lastEventID string) (*http.Response, er return nil, fmt.Errorf("connection failed after %d attempts", c.maxRetries) } -// isResumable checks if an HTTP response indicates a valid SSE stream that can be processed. -func isResumable(resp *http.Response) bool { - // Per the spec, a 405 response means the server doesn't support SSE streams at this endpoint. - if resp.StatusCode == http.StatusMethodNotAllowed { - return false - } - - return strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") -} - // Close implements the [Connection] interface. func (c *streamableClientConn) Close() error { c.closeOnce.Do(func() { @@ -1361,8 +1372,11 @@ func (c *streamableClientConn) establishSSE(lastEventID string) (*http.Response, // calculateReconnectDelay calculates a delay using exponential backoff with full jitter. func calculateReconnectDelay(attempt int) time.Duration { + if attempt == 0 { + return 0 + } // Calculate the exponential backoff using the grow factor. - backoffDuration := time.Duration(float64(reconnectInitialDelay) * math.Pow(reconnectGrowFactor, float64(attempt))) + backoffDuration := time.Duration(float64(reconnectInitialDelay) * math.Pow(reconnectGrowFactor, float64(attempt-1))) // Cap the backoffDuration at maxDelay. backoffDuration = min(backoffDuration, reconnectMaxDelay)