From db153ed9d4eb0ad1c7538204448a2fc5ee7578d4 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Tue, 14 Oct 2025 21:09:09 +0000 Subject: [PATCH] mcp: allow disabling stream replay; disable by default This CL implements two changes that must be logically connected. The first is to provide a mechanism for disabling stream replay. Previously, our documentation stated that if StreamableServerTransport.EventStore was nil, a default in-memory event store would be used. In that case, there would be no way to completely disable stream replay, and as described in #576, there are many cases where replay is undesirable. But of course changing the behavior of the transport zero value implicitly affects its default behavior, and so this CL also implements the proposal #580: stream replay is disabled by default. Implementing this change required a significant refactoring, as previously we were relying on the event store for serialized message delivery from the JSON-RPC layer to the MCP layer: the connection would write to the event store, and independently the stream (be it an incoming POST or replay GET) would iterate and serve messages in the stream. In order to achieve the goals of this CL, it was necessary to decouple storage from delivery. The 'stream' abstraction now tracks a delivery callback that writes straight to the HTTP response. It would have been convenient to store the ongoing http.ResponseWriter directly in the stream (this is how typescript does it), but due to the design of our EventStore API, only the HTTP handler knows the next event index, so a 'deliver' abstraction was an unfortunate requirement (suggestions for how to further simplify this are welcome). More simplification is possible: in particular, as a result of this refactoring it should be entirely possible to clean up streams once we've received all responses. Any replay would only need access to the EventStore, if at all. This is left to a follow-up CL, to limit this already significant change. Furthermore, a nice consequence of this refactoring is that, when not using event storage, servers can get synchronous feedback that message delivery failed, which should avoid unnecessary work. We can additionally cancel ongoing requests on early client termination, but that is also left to a follow-up CL. Throughout, the terminology 'standalone SSE stream' replaced 'hanging GET' when referring to the non-replay GET request issued by the client. This is consistent with other SDKs. Fixes #576 Updates #580 --- mcp/event.go | 10 +- mcp/streamable.go | 657 +++++++++++++++++++--------------- mcp/streamable_client_test.go | 2 +- mcp/streamable_test.go | 56 ++- mcp/transport.go | 2 +- 5 files changed, 420 insertions(+), 307 deletions(-) diff --git a/mcp/event.go b/mcp/event.go index 1dd36f4e..281f5925 100644 --- a/mcp/event.go +++ b/mcp/event.go @@ -153,11 +153,9 @@ func scanEvents(r io.Reader) iter.Seq2[Event, error] { // // All of an EventStore's methods must be safe for use by multiple goroutines. type EventStore interface { - // Open prepares the event store for a given stream. It ensures that the - // underlying data structure for the stream is initialized, making it - // ready to store event streams. - // - // streamIDs must be globally unique. + // Open is called when a new stream is created. It may be used to ensure that + // the underlying data structure for the stream is initialized, making it + // ready to store and replay event streams. Open(_ context.Context, sessionID, streamID string) error // Append appends data for an outgoing event to given stream, which is part of the @@ -166,6 +164,7 @@ type EventStore interface { // After returns an iterator over the data for the given session and stream, beginning // just after the given index. + // // Once the iterator yields a non-nil error, it will stop. // After's iterator must return an error immediately if any data after index was // dropped; it must not return partial results. @@ -174,6 +173,7 @@ type EventStore interface { // SessionClosed informs the store that the given session is finished, along // with all of its streams. + // // A store cannot rely on this method being called for cleanup. It should institute // additional mechanisms, such as timeouts, to reclaim storage. SessionClosed(_ context.Context, sessionID string) error diff --git a/mcp/streamable.go b/mcp/streamable.go index 20eb13d5..b58e20b4 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -11,7 +11,6 @@ import ( "errors" "fmt" "io" - "iter" "log/slog" "math" "math/rand/v2" @@ -72,6 +71,9 @@ type StreamableHTTPOptions struct { // Logger specifies the logger to use. // If nil, do not log. Logger *slog.Logger + + // TODO(rfindley): file a proposal to export this option, or something equivalent. + configureTransport func(req *http.Request, transport *StreamableServerTransport) } // NewStreamableHTTPHandler returns a new [StreamableHTTPHandler]. @@ -238,6 +240,9 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque jsonResponse: h.opts.JSONResponse, logger: h.opts.Logger, } + if h.opts.configureTransport != nil { + h.opts.configureTransport(req, transport) + } // To support stateless mode, we initialize the session with a default // state, so that it doesn't reject subsequent requests. @@ -335,16 +340,16 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque // // Reads from the streamable server connection receive messages from http POST // requests from the client. Writes to the streamable server connection are -// sent either to the hanging POST response, or to the hanging GET, according -// to the following rules: +// sent either to the related stream, or to the standalone SSE stream, +// according to the following rules: // - JSON-RPC responses to incoming requests are always routed to the // appropriate HTTP response. // - Requests or notifications made with a context.Context value derived from // an incoming request handler, are routed to the HTTP response // corresponding to that request, unless it has already terminated, in -// which case they are routed to the hanging GET. +// which case they are routed to the standalone SSE stream. // - Requests or notifications made with a detached context.Context value are -// routed to the hanging GET. +// routed to the standalone SSE stream. type StreamableServerTransport struct { // SessionID is the ID of this session. // @@ -372,15 +377,16 @@ type StreamableServerTransport struct { // Specifically, responses will be application/json whenever incoming POST // request contain only a single message. In this case, notifications or // requests made within the context of a server request will be sent to the - // hanging GET request, if any. + // standalone SSE stream, if any. // // TODO(rfindley): jsonResponse should be exported, since - // StreamableHTTPOptions.JSONResponse is exported. + // StreamableHTTPOptions.JSONResponse is exported, and we want to allow users + // to write their own streamable HTTP handler. jsonResponse bool // optional logger provided through the [StreamableHTTPOptions.Logger]. // - // TODO(rfindley): logger should be exported, since we want to allow people + // TODO(rfindley): logger should be exported, since we want to allow users // to write their own streamable HTTP handler. logger *slog.Logger @@ -404,15 +410,12 @@ func (t *StreamableServerTransport) Connect(ctx context.Context) (Connection, er streams: make(map[string]*stream), requestStreams: make(map[jsonrpc.ID]string), } - if t.connection.eventStore == nil { - t.connection.eventStore = NewMemoryEventStore(nil) - } - // Stream 0 corresponds to the hanging 'GET'. + // Stream 0 corresponds to the standalone SSE stream. // // It is always text/event-stream, since it must carry arbitrarily many // messages. var err error - t.connection.streams[""], err = t.connection.newStream(ctx, "", false, false) + t.connection.streams[""], err = t.connection.newStream(ctx, nil, "") if err != nil { return nil, err } @@ -449,13 +452,18 @@ type streamableServerConn struct { // bound. If we deleted a stream when the response is sent, we would lose the ability // to replay if there was a cut just before the response was transmitted. // Perhaps we could have a TTL for streams that starts just after the response. + // + // TODO(rfindley): Once all responses have been received for a stream, we can + // remove it as it is no longer necessary, even if the client wants to replay + // messages. streams map[string]*stream // requestStreams maps incoming requests to their logical stream ID. // // Lifecycle: requestStreams persist for the duration of the session. // - // TODO: clean up once requests are handled. See the TODO for streams above. + // TODO(rfindley): clean up once requests are handled. See the TODO for + // streams above. requestStreams map[jsonrpc.ID]string } @@ -472,58 +480,45 @@ func (c *streamableServerConn) SessionID() string { // at any time. type stream struct { // id is the logical ID for the stream, unique within a session. - // an empty string is used for messages that don't correlate with an incoming request. + // + // The standalone SSE stream has id "". id string - // If isInitialize is set, the stream is in response to an initialize request, - // and therefore should include the session ID header. - isInitialize bool - - // jsonResponse records whether this stream should respond with application/json - // instead of text/event-stream. - // - // See [StreamableServerTransportOptions.JSONResponse]. - jsonResponse bool + // mu guards the fields below, as well as storage of new messages in the + // connection's event store (if any). + mu sync.Mutex - // signal is a 1-buffered channel, owned by an incoming HTTP request, that signals - // that there are messages available to write into the HTTP response. - // In addition, the presence of a channel guarantees that at most one HTTP response - // can receive messages for a logical stream. After claiming the stream, incoming - // requests should read from the event store, to ensure that no new messages are missed. - // - // To simplify locking, signal is an atomic. We need an atomic.Pointer, because - // you can't set an atomic.Value to nil. + // If non-nil, deliver writes data directly to the HTTP response. // - // Lifecycle: each channel value persists for the duration of an HTTP POST or - // GET request for the given streamID. - signal atomic.Pointer[chan struct{}] + // Only one HTTP response may receive messages at a given time. An active + // HTTP connection acquires ownership of the stream by setting this field. + deliver func(data []byte, final bool) error - // The following mutable fields are protected by the mutex of the containing - // StreamableServerTransport. - - // streamRequests is the set of unanswered incoming RPCs for the stream. + // streamRequests is the set of unanswered incoming requests for the stream. // - // Requests persist until their response data has been added to the event store. + // Requests are removed when their response has been received. requests map[jsonrpc.ID]struct{} } -func (c *streamableServerConn) newStream(ctx context.Context, id string, isInitialize, jsonResponse bool) (*stream, error) { - if err := c.eventStore.Open(ctx, c.sessionID, id); err != nil { - return nil, err +// doneLocked reports whether the stream is logically complete. +// +// s.mu must be held while calling this function. +func (s *stream) doneLocked() bool { + return len(s.requests) == 0 && s.id != "" +} + +func (c *streamableServerConn) newStream(ctx context.Context, requests map[jsonrpc.ID]struct{}, id string) (*stream, error) { + if c.eventStore != nil { + if err := c.eventStore.Open(ctx, c.sessionID, id); err != nil { + return nil, err + } } return &stream{ - id: id, - isInitialize: isInitialize, - jsonResponse: jsonResponse, - requests: make(map[jsonrpc.ID]struct{}), + id: id, + requests: requests, }, nil } -func signalChanPtr() *chan struct{} { - c := make(chan struct{}, 1) - return &c -} - // We track the incoming request ID inside the handler context using // idContextValue, so that notifications and server->client calls that occur in // the course of handling incoming requests are correlated with the incoming @@ -571,37 +566,157 @@ func (t *StreamableServerTransport) ServeHTTP(w http.ResponseWriter, req *http.R // // It returns an HTTP status code and error message. func (c *streamableServerConn) serveGET(w http.ResponseWriter, req *http.Request) { - // connID 0 corresponds to the default GET request. - id := "" + // streamID "" corresponds to the default GET request. + streamID := "" // By default, we haven't seen a last index. Since indices start at 0, we represent - // that by -1. This is incremented just before each event is written, in streamResponse - // around L407. + // that by -1. This is incremented just before each event is written. lastIdx := -1 if len(req.Header.Values("Last-Event-ID")) > 0 { eid := req.Header.Get("Last-Event-ID") var ok bool - id, lastIdx, ok = parseEventID(eid) + streamID, lastIdx, ok = parseEventID(eid) if !ok { http.Error(w, fmt.Sprintf("malformed Last-Event-ID %q", eid), http.StatusBadRequest) return } + if c.eventStore == nil { + http.Error(w, "stream replay unsupported", http.StatusBadRequest) + return + } + } + + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + + stream, done := c.acquireStream(ctx, w, streamID, &lastIdx) + if stream == nil { + return + } + // Release the stream when we're done. + defer func() { + stream.mu.Lock() + stream.deliver = nil + stream.mu.Unlock() + }() + + select { + case <-ctx.Done(): + // request cancelled + case <-done: + // request complete + case <-c.done: + // session closed + } +} + +// writeEvent writes an SSE event to w corresponding to the given stream, data, and index. +// lastIdx is incremented before writing, so that it continues to point to the index of the +// last event written to the stream. +func (c *streamableServerConn) writeEvent(w http.ResponseWriter, stream *stream, data []byte, lastIdx *int) error { + *lastIdx++ + e := Event{ + Name: "message", + Data: data, + } + if c.eventStore != nil { + e.ID = formatEventID(stream.id, *lastIdx) } + if _, err := writeEvent(w, e); err != nil { + return err + } + return nil +} +// acquireStream acquires the stream and replays all events since lastIdx, if +// any, updating lastIdx accordingly. If non-nil, the resulting stream will be +// registered for receiving new messages, and the resulting done channel will +// be closed when all related messages have been delivered. +// +// If any errors occur, they will be written to w and the resulting stream will +// be nil. The resulting stream may also be nil if the stream is complete. +// +// Importantly, this function must hold the stream mutex until done replaying +// all messages, so that no delivery or storage of new messages occurs while +// the stream is still replaying. +func (c *streamableServerConn) acquireStream(ctx context.Context, w http.ResponseWriter, streamID string, lastIdx *int) (*stream, chan struct{}) { c.mu.Lock() - stream, ok := c.streams[id] + stream, ok := c.streams[streamID] c.mu.Unlock() if !ok { http.Error(w, "unknown stream", http.StatusBadRequest) - return + return nil, nil } - if !stream.signal.CompareAndSwap(nil, signalChanPtr()) { - // The CAS returned false, meaning that the comparison failed: stream.signal is not nil. + + stream.mu.Lock() + defer stream.mu.Unlock() + if stream.deliver != nil { http.Error(w, "stream ID conflicts with ongoing stream", http.StatusConflict) - return + return nil, nil + } + + // Collect events to replay. Collect them all before writing, so that we + // have an opportunity to set the HTTP status code on an error. + // + // As indicated above, we must do that while holding stream.mu, so that no + // new messages are added to the eventstore until we've replayed all previous + // messages, and registered our delivery function. + var toReplay [][]byte + if c.eventStore != nil { + for data, err := range c.eventStore.After(ctx, c.SessionID(), stream.id, *lastIdx) { + if err != nil { + // We can't replay events, perhaps because the underlying event store + // has garbage collected its storage. + // + // We must be careful here: any 404 will signal to the client that the + // *session* is not found, rather than the stream. + // + // 400 is not really accurate, but should at least have no side effects. + // Other SDKs (typescript) do not have a mechanism for events to be purged. + http.Error(w, "failed to replay events", http.StatusBadRequest) + return nil, nil + } + toReplay = append(toReplay, data) + } + } + + w.Header().Set("Cache-Control", "no-cache, no-transform") + w.Header().Set("Content-Type", "text/event-stream") // Accept checked in [StreamableHTTPHandler] + w.Header().Set("Connection", "keep-alive") + + if stream.id == "" { + // Issue #410: the standalone SSE stream is likely not to receive messages + // for a long time. Ensure that headers are flushed. + w.WriteHeader(http.StatusOK) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + } + + for _, data := range toReplay { + if err := c.writeEvent(w, stream, data, lastIdx); err != nil { + return nil, nil + } + } + + if stream.doneLocked() { + // Nothing more to do. + return nil, nil + } + + // Finally register a delivery function and unlock the stream, allowing the + // connection to write new events. + done := make(chan struct{}) + stream.deliver = func(data []byte, final bool) error { + if err := ctx.Err(); err != nil { + return err + } + err := c.writeEvent(w, stream, data, lastIdx) + if final { + close(done) + } + return err } - defer stream.signal.Store(nil) - persistent := id == "" // Only the special stream "" is a hanging get. - c.respondSSE(stream, w, req, lastIdx, persistent) + return stream, done } // servePOST handles an incoming message, and replies with either an outgoing @@ -641,7 +756,14 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques return } - requests := make(map[jsonrpc.ID]struct{}) + // TODO(rfindley): no tests fail if we reject batch JSON requests entirely. + // We need to test this with older protocol versions. + // if isBatch && c.jsonResponse { + // http.Error(w, "server does not support batch requests", http.StatusBadRequest) + // return + // } + + calls := make(map[jsonrpc.ID]struct{}) tokenInfo := auth.TokenInfoFromContext(req.Context()) isInitialize := false for _, msg := range incoming { @@ -661,204 +783,134 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques Header: req.Header, } if jreq.IsCall() { - requests[jreq.ID] = struct{}{} + calls[jreq.ID] = struct{}{} } } } - var stream *stream // if non-nil, used to handle requests - - // If we have requests, we need to handle responses along with any - // notifications or server->client requests made in the course of handling. - // Update accounting for this incoming payload. - if len(requests) > 0 { - stream, err = c.newStream(req.Context(), randText(), isInitialize, c.jsonResponse) - if err != nil { - http.Error(w, fmt.Sprintf("storing stream: %v", err), http.StatusInternalServerError) - return - } - c.mu.Lock() - c.streams[stream.id] = stream - stream.requests = requests - for reqID := range requests { - c.requestStreams[reqID] = stream.id + // If we don't have any calls, we can just publish the incoming messages and return. + // No need to track a logical stream. + if len(calls) == 0 { + for _, msg := range incoming { + select { + case c.incoming <- msg: + case <-c.done: + // The session is closing. Since we haven't yet written any data to the + // response, we can signal to the client that the session is gone. + http.Error(w, "session is closing", http.StatusNotFound) + return + } } - c.mu.Unlock() - stream.signal.Store(signalChanPtr()) - defer stream.signal.Store(nil) - } - - // Publish incoming messages. - for _, msg := range incoming { - c.incoming <- msg - } - - if stream == nil { w.WriteHeader(http.StatusAccepted) return } - if stream.jsonResponse { - c.respondJSON(stream, w, req) - } else { - c.respondSSE(stream, w, req, -1, false) + // Invariant: we have at least one call. + // + // Create a logical stream to track its responses. + // Important: don't publish the incoming messages until the stream is + // registered, as the server may attempt to respond to imcoming messages as + // soon as they're published. + stream, err := c.newStream(req.Context(), calls, randText()) + if err != nil { + http.Error(w, fmt.Sprintf("storing stream: %v", err), http.StatusInternalServerError) + return } -} -func (c *streamableServerConn) respondJSON(stream *stream, w http.ResponseWriter, req *http.Request) { + // Set response headers. Accept was checked in [StreamableHTTPHandler]. w.Header().Set("Cache-Control", "no-cache, no-transform") - w.Header().Set("Content-Type", "application/json") - if c.sessionID != "" && stream.isInitialize { + if c.jsonResponse { + w.Header().Set("Content-Type", "application/json") + } else { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Connection", "keep-alive") + } + if c.sessionID != "" && isInitialize { w.Header().Set(sessionIDHeader, c.sessionID) } - var msgs []json.RawMessage - ctx := req.Context() - for msg, err := range c.messages(ctx, stream, false, -1) { - if err != nil { - if ctx.Err() != nil { - w.WriteHeader(http.StatusNoContent) - return + // Message delivery has two paths, depending on whether we're responding with JSON or + // event stream. + done := make(chan struct{}) // closed after the final response is written + if c.jsonResponse { + var msgs []json.RawMessage + stream.deliver = func(data []byte, final bool) error { + // Collect messages until we've received the final response. + // + // In recent protocol versions, there should only be one message as + // batching is disabled, as checked above. + msgs = append(msgs, data) + if !final { + return nil + } + defer close(done) // final response + + // Write either the JSON object corresponding to the one response, or a + // JSON array corresponding to the batch response. + var toWrite []byte + if len(msgs) == 1 { + toWrite = []byte(msgs[0]) } else { - http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) - return + var err error + toWrite, err = json.Marshal(msgs) + if err != nil { + return err + } } + _, err = w.Write(toWrite) + return err } - msgs = append(msgs, msg) - } - var data []byte - if len(msgs) == 1 { - data = []byte(msgs[0]) } else { - // TODO: add tests for batch responses, or disallow them entirely. - var err error - data, err = json.Marshal(msgs) - if err != nil { - http.Error(w, fmt.Sprintf("internal error marshalling response: %v", err), http.StatusInternalServerError) - return + // Write events in the order we receive them. + lastIndex := -1 + stream.deliver = func(data []byte, final bool) error { + if final { + defer close(done) + } + return c.writeEvent(w, stream, data, &lastIndex) } } - _, _ = w.Write(data) // ignore error: client disconnected -} -// lastIndex is the index of the last seen event if resuming, else -1. -func (c *streamableServerConn) respondSSE(stream *stream, w http.ResponseWriter, req *http.Request, lastIndex int, persistent bool) { - // Accept was checked in [StreamableHTTPHandler] - w.Header().Set("Cache-Control", "no-cache, no-transform") - w.Header().Set("Content-Type", "text/event-stream") // Accept checked in [StreamableHTTPHandler] - w.Header().Set("Connection", "keep-alive") - if c.sessionID != "" && stream.isInitialize { - w.Header().Set(sessionIDHeader, c.sessionID) - } - if persistent { - // Issue #410: the hanging GET is likely not to receive messages for a long - // time. Ensure that headers are flushed. - // - // For non-persistent requests, delay the writing of the header in case we - // may want to set an error status. - // (see the TODO: this probably isn't worth it). - w.WriteHeader(http.StatusOK) - if f, ok := w.(http.Flusher); ok { - f.Flush() - } - } + // Release ownership of the stream by unsetting deliver. + defer func() { + stream.mu.Lock() + // TODO(rfindley): if we have no event store, we should really cancel all + // remaining requests here, since the client will never get the results. + stream.deliver = nil + stream.mu.Unlock() + }() - // write one event containing data. - writes := 0 - write := func(data []byte) bool { - lastIndex++ - e := Event{ - Name: "message", - ID: formatEventID(stream.id, lastIndex), - Data: data, - } - if _, err := writeEvent(w, e); err != nil { - // Connection closed or broken. - c.logger.Warn("error writing event", "error", err) - return false - } - writes++ - return true + // The stream is now set up to deliver messages. + // + // Register it before publishing incoming messages. + c.mu.Lock() + c.streams[stream.id] = stream + for reqID := range calls { + c.requestStreams[reqID] = stream.id } + c.mu.Unlock() - // Repeatedly collect pending outgoing events and send them. - ctx := req.Context() - for msg, err := range c.messages(ctx, stream, persistent, lastIndex) { - if err != nil { - if ctx.Err() == nil && writes == 0 && !persistent { - // If we haven't yet written the header, we have an opportunity to - // promote an error to an HTTP error. - // - // TODO: This may not matter in practice, in which case we should - // simplify. - http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) - } else { - if ctx.Err() != nil { - // Client disconnected or cancelled the request. - c.logger.Error("stream context done", "error", ctx.Err()) - } else { - // Some other error. - c.logger.Error("error receiving message", "error", err) - } - } - return - } - if !write(msg) { + // Publish incoming messages. + for _, msg := range incoming { + select { + case c.incoming <- msg: + // Note: don't select on req.Context().Done() here, since we've already + // received the requests and may have already published a response message + // or notification. The client could resume the stream. + case <-c.done: + // Session closed: we don't know if any data has been written, so it's + // too late to write a status code here. return } } -} - -// messages iterates over messages sent to the current stream. -// -// persistent indicates if it is the main GET listener, which should never be -// terminated. -// lastIndex is the index of the last seen event, iteration begins at lastIndex+1. -// -// The first iterated value is the received JSON message. The second iterated -// value is an error value indicating whether the stream terminated normally. -// Iteration ends at the first non-nil error. -// -// If the stream did not terminate normally, it is either because ctx was -// cancelled, or the connection is closed: check the ctx.Err() to differentiate -// these cases. -func (c *streamableServerConn) messages(ctx context.Context, stream *stream, persistent bool, lastIndex int) iter.Seq2[json.RawMessage, error] { - return func(yield func(json.RawMessage, error) bool) { - for { - c.mu.Lock() - nOutstanding := len(stream.requests) - c.mu.Unlock() - for data, err := range c.eventStore.After(ctx, c.SessionID(), stream.id, lastIndex) { - if err != nil { - yield(nil, err) - return - } - if !yield(data, nil) { - return - } - lastIndex++ - } - // If all requests have been handled and replied to, we should terminate this connection. - // "After the JSON-RPC response has been sent, the server SHOULD close the SSE stream." - // §6.4, https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server - // We only want to terminate POSTs, and GETs that are replaying. The general-purpose GET - // (stream ID 0) will never have requests, and should remain open indefinitely. - if nOutstanding == 0 && !persistent { - return - } - - select { - case <-*stream.signal.Load(): // there are new outgoing messages - // return to top of loop - case <-c.done: // session is closed - yield(nil, errors.New("session is closed")) - return - case <-ctx.Done(): - yield(nil, ctx.Err()) - return - } - } + select { + case <-req.Context().Done(): + // request cancelled + case <-done: + // request complete + case <-c.done: + // session is closed } } @@ -907,80 +959,93 @@ func (c *streamableServerConn) Read(ctx context.Context) (jsonrpc.Message, error // Write implements the [Connection] interface. func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) error { + // Throughout this function, note that any error that wraps ErrRejected + // indicates a does not cause the connection to break. + // + // Most errors don't break the connection: unlike a true bidirectional + // stream, a failure to deliver to a stream is not an indication that the + // logical session is broken. + data, err := jsonrpc2.EncodeMessage(msg) + if err != nil { + return err + } + if req, ok := msg.(*jsonrpc.Request); ok && req.ID.IsValid() && (c.stateless || c.sessionID == "") { // Requests aren't possible with stateless servers, or when there's no session ID. return fmt.Errorf("%w: stateless servers cannot make requests", jsonrpc2.ErrRejected) } + // Find the incoming request that this write relates to, if any. - var forRequest jsonrpc.ID - isResponse := false + var ( + relatedRequest jsonrpc.ID + responseTo jsonrpc.ID // if valid, the message is a response to this request + ) if resp, ok := msg.(*jsonrpc.Response); ok { // If the message is a response, it relates to its request (of course). - forRequest = resp.ID - isResponse = true + relatedRequest = resp.ID + responseTo = resp.ID } else { // Otherwise, we check to see if it request was made in the context of an // ongoing request. This may not be the case if the request was made with // an unrelated context. if v := ctx.Value(idContextKey{}); v != nil { - forRequest = v.(jsonrpc.ID) + relatedRequest = v.(jsonrpc.ID) } } - // Find the logical connection corresponding to this request. - // - // For messages sent outside of a request context, this is the default - // connection "". - var forStream string - if forRequest.IsValid() { - c.mu.Lock() - forStream = c.requestStreams[forRequest] - c.mu.Unlock() - } - - data, err := jsonrpc2.EncodeMessage(msg) - if err != nil { - return err + // If the stream is application/json, but the message is not a response, we + // must send it out of band to the standalone SSE stream. + if c.jsonResponse && !responseTo.IsValid() { + relatedRequest = jsonrpc.ID{} } + // Write the message to the stream. + var s *stream c.mu.Lock() - defer c.mu.Unlock() - if c.isDone { - return errors.New("session is closed") + if relatedRequest.IsValid() { + if streamID, ok := c.requestStreams[relatedRequest]; ok { + s = c.streams[streamID] + } + } else { + s = c.streams[""] } + sessionClosed := c.isDone + c.mu.Unlock() - stream := c.streams[forStream] - if stream == nil { - return fmt.Errorf("no stream with ID %s", forStream) + if s == nil { + return fmt.Errorf("%w: no stream for request", jsonrpc2.ErrRejected) } - - // Special case a few conditions where we fall back on stream 0 (the hanging GET): - // - // - if forStream is known, but the associated stream is logically complete - // - if the stream is application/json, but the message is not a response - // - // TODO(rfindley): either of these, particularly the first, might be - // considered a bug in the server. Report it through a side-channel? - if len(stream.requests) == 0 && forStream != "" || stream.jsonResponse && !isResponse { - stream = c.streams[""] + if sessionClosed { + return errors.New("session is closed") } - if err := c.eventStore.Append(ctx, c.SessionID(), stream.id, data); err != nil { - return fmt.Errorf("error storing event: %w", err) + s.mu.Lock() + defer s.mu.Unlock() + if s.doneLocked() { + return fmt.Errorf("%w: write to closed stream", jsonrpc2.ErrRejected) } - if isResponse { - // Once we've put the reply on the queue, it's no longer outstanding. - delete(stream.requests, forRequest) + if responseTo.IsValid() { + delete(s.requests, responseTo) } - // Signal streamResponse that new work is available. - signalp := stream.signal.Load() - if signalp != nil { - select { - case *signalp <- struct{}{}: - default: + delivered := false + if c.eventStore != nil { + if err := c.eventStore.Append(ctx, c.sessionID, s.id, data); err != nil { + // TODO: report a side-channel error. + } else { + delivered = true } } + if s.deliver != nil { + if err := s.deliver(data, s.doneLocked()); err != nil { + // TODO: report a side-channel error. + } else { + delivered = true + } + } + if !delivered { + return fmt.Errorf("%w: undelivered message", jsonrpc2.ErrRejected) + } return nil } @@ -991,9 +1056,11 @@ func (c *streamableServerConn) Close() error { if !c.isDone { c.isDone = true close(c.done) - // TODO: find a way to plumb a context here, or an event store with a long-running - // close operation can take arbitrary time. Alternative: impose a fixed timeout here. - return c.eventStore.SessionClosed(context.TODO(), c.sessionID) + if c.eventStore != nil { + // TODO: find a way to plumb a context here, or an event store with a long-running + // close operation can take arbitrary time. Alternative: impose a fixed timeout here. + return c.eventStore.SessionClosed(context.TODO(), c.sessionID) + } } return nil } @@ -1118,7 +1185,7 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) { c.initializedResult = state.InitializeResult c.mu.Unlock() - // Start the persistent SSE listener as soon as we have the initialized + // Start the standalone SSE stream as soon as we have the initialized // result. // // § 2.2: The client MAY issue an HTTP GET to the MCP endpoint. This can be @@ -1131,7 +1198,7 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) { // § 2.5: A server using the Streamable HTTP transport MAY assign a session // ID at initialization time, by including it in an Mcp-Session-Id header // on the HTTP response containing the InitializeResult. - go c.handleSSE("hanging GET", nil, true, nil) + go c.handleSSE("standalone SSE stream", nil, true, nil) } // fail handles an asynchronous error while reading. @@ -1272,6 +1339,7 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e case "text/event-stream": jsonReq, _ := msg.(*jsonrpc.Request) + // TODO: should we cancel this logical SSE request if/when jsonReq is canceled? go c.handleSSE(requestSummary, resp, false, jsonReq) default: @@ -1328,6 +1396,11 @@ func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *htt resp := initialResp var lastEventID string for { + // TODO: we should set a reasonable limit on the number of times we'll try + // getting a response for a given request. + // + // Eventually, if we don't get the response, we should stop trying and + // fail the request. if resp != nil { eventID, clientClosed := c.processStream(requestSummary, resp, forReq) lastEventID = eventID @@ -1359,17 +1432,17 @@ func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *htt // // [§2.2.3]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#listening-for-messages-from-the-server - // The server doesn't support the hanging GET. + // The server doesn't support the standalone SSE stream. resp.Body.Close() return } if resp.StatusCode == http.StatusNotFound && persistent && !c.strict { // modelcontextprotocol/gosdk#393: some servers return NotFound instead - // of MethodNotAllowed for the persistent GET. + // of MethodNotAllowed for the standalone SSE stream. // // Treat this like MethodNotAllowed in non-strict mode. if c.logger != nil { - c.logger.Warn("got 404 instead of 405 for hanging GET") + c.logger.Warn("got 404 instead of 405 for standalonw SSE stream") } resp.Body.Close() return @@ -1412,8 +1485,8 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R select { case c.incoming <- msg: if jsonResp, ok := msg.(*jsonrpc.Response); ok && forReq != nil { - // TODO: we should never get a response when forReq is nil (the hanging GET). - // We should detect this case, and eliminate the 'persistent' flag arguments. + // TODO: we should never get a response when forReq is nil (the standalone SSE request). + // We should detect this case. if jsonResp.ID == forReq.ID { return "", true } @@ -1433,7 +1506,7 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R func (c *streamableClientConn) reconnect(lastEventID string) (*http.Response, error) { var finalErr error - // We can reach the 'reconnect' path through the hanging GET, in which case + // We can reach the 'reconnect' path through the standlone SSE request, in which case // lastEventID will be "". // // In this case, we need an initial attempt. diff --git a/mcp/streamable_client_test.go b/mcp/streamable_client_test.go index 9116677b..faca04c6 100644 --- a/mcp/streamable_client_test.go +++ b/mcp/streamable_client_test.go @@ -246,7 +246,7 @@ func TestStreamableClientGETHandling(t *testing.T) { }{ {http.StatusOK, ""}, {http.StatusMethodNotAllowed, ""}, - {http.StatusBadRequest, "hanging GET"}, + {http.StatusBadRequest, "standalone SSE"}, } for _, test := range tests { diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 3c0b8be3..a0893689 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -41,8 +41,18 @@ func TestStreamableTransports(t *testing.T) { ctx := context.Background() - for _, useJSON := range []bool{false, true} { - t.Run(fmt.Sprintf("JSONResponse=%v", useJSON), func(t *testing.T) { + tests := []struct { + useJSON bool + replay bool + }{ + {false, false}, + {false, true}, + {true, false}, + {true, true}, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("JSONResponse=%v;replay=%v", test.useJSON, test.replay), func(t *testing.T) { // Create a server with some simple tools. server := NewServer(testImpl, nil) AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi) @@ -86,7 +96,12 @@ func TestStreamableTransports(t *testing.T) { // Start an httptest.Server with the StreamableHTTPHandler, wrapped in a // cookie-checking middleware. handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, &StreamableHTTPOptions{ - JSONResponse: useJSON, + JSONResponse: test.useJSON, + configureTransport: func(_ *http.Request, transport *StreamableServerTransport) { + if test.replay { + transport.EventStore = NewMemoryEventStore(nil) + } + }, }) var ( @@ -370,7 +385,11 @@ func testClientReplay(t *testing.T, test clientReplayTest) { return new(CallToolResult), nil, nil }) - realServer := httptest.NewServer(mustNotPanic(t, NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil))) + realServer := httptest.NewServer(mustNotPanic(t, NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{ + configureTransport: func(_ *http.Request, t *StreamableServerTransport) { + t.EventStore = NewMemoryEventStore(nil) // necessary for replay + }, + }))) t.Cleanup(func() { t.Log("Closing real HTTP server") realServer.Close() @@ -543,7 +562,16 @@ func TestServerInitiatedSSE(t *testing.T) { notifications := make(chan string) server := NewServer(testImpl, nil) - httpServer := httptest.NewServer(mustNotPanic(t, NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil))) + opts := &StreamableHTTPOptions{ + // TODO(#583): for now, this is required for guaranteed message delivery. + // However, it shouldn't be necessary to use replay here, as we should be + // guaranteed that the standalone SSE stream is started by the time the + // client is connected. + configureTransport: func(_ *http.Request, transport *StreamableServerTransport) { + transport.EventStore = NewMemoryEventStore(nil) + }, + } + httpServer := httptest.NewServer(mustNotPanic(t, NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, opts))) defer httpServer.Close() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -640,6 +668,7 @@ func TestStreamableServerTransport(t *testing.T) { tests := []struct { name string + replay bool // if set, use a MemoryEventStore to enable stream replay tool func(*testing.T, context.Context, *ServerSession) requests []streamableRequest // http requests }{ @@ -804,10 +833,15 @@ func TestStreamableServerTransport(t *testing.T) { }, { name: "background", + // Enabling replay is necessary here because the standalone "GET" request + // is fully asynronous. Replay is needed to guarantee message delivery. + replay: true, tool: func(t *testing.T, _ context.Context, ss *ServerSession) { // Perform operations on a background context, and ensure the client // receives it. - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + if err := ss.NotifyProgress(ctx, &ProgressNotificationParams{}); err != nil { t.Errorf("Notify failed: %v", err) } @@ -816,7 +850,7 @@ func TestStreamableServerTransport(t *testing.T) { // t.Errorf("Logging failed: %v", err) // } if _, err := ss.ListRoots(ctx, &ListRootsParams{}); err != nil { - t.Errorf("Notify failed: %v", err) + t.Errorf("ListRoots failed: %v", err) } }, requests: []streamableRequest{ @@ -906,8 +940,14 @@ func TestStreamableServerTransport(t *testing.T) { return &CallToolResult{}, nil }) + opts := &StreamableHTTPOptions{} + if test.replay { + opts.configureTransport = func(_ *http.Request, t *StreamableServerTransport) { + t.EventStore = NewMemoryEventStore(nil) + } + } // Start the streamable handler. - handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, opts) defer handler.closeAll() testStreamableHandler(t, handler, test.requests) diff --git a/mcp/transport.go b/mcp/transport.go index 1beab470..cacd65fd 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -46,7 +46,7 @@ type Connection interface { // Write writes a new message to the connection. // - // Write may be called concurrently, as calls or reponses may occur + // Write may be called concurrently, as calls or responses may occur // concurrently in user code. Write(context.Context, jsonrpc.Message) error