diff --git a/mcp/event.go b/mcp/event.go index 1dd36f4e..281f5925 100644 --- a/mcp/event.go +++ b/mcp/event.go @@ -153,11 +153,9 @@ func scanEvents(r io.Reader) iter.Seq2[Event, error] { // // All of an EventStore's methods must be safe for use by multiple goroutines. type EventStore interface { - // Open prepares the event store for a given stream. It ensures that the - // underlying data structure for the stream is initialized, making it - // ready to store event streams. - // - // streamIDs must be globally unique. + // Open is called when a new stream is created. It may be used to ensure that + // the underlying data structure for the stream is initialized, making it + // ready to store and replay event streams. Open(_ context.Context, sessionID, streamID string) error // Append appends data for an outgoing event to given stream, which is part of the @@ -166,6 +164,7 @@ type EventStore interface { // After returns an iterator over the data for the given session and stream, beginning // just after the given index. + // // Once the iterator yields a non-nil error, it will stop. // After's iterator must return an error immediately if any data after index was // dropped; it must not return partial results. @@ -174,6 +173,7 @@ type EventStore interface { // SessionClosed informs the store that the given session is finished, along // with all of its streams. + // // A store cannot rely on this method being called for cleanup. It should institute // additional mechanisms, such as timeouts, to reclaim storage. SessionClosed(_ context.Context, sessionID string) error diff --git a/mcp/streamable.go b/mcp/streamable.go index 20eb13d5..b58e20b4 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -11,7 +11,6 @@ import ( "errors" "fmt" "io" - "iter" "log/slog" "math" "math/rand/v2" @@ -72,6 +71,9 @@ type StreamableHTTPOptions struct { // Logger specifies the logger to use. // If nil, do not log. 
Logger *slog.Logger + + // TODO(rfindley): file a proposal to export this option, or something equivalent. + configureTransport func(req *http.Request, transport *StreamableServerTransport) } // NewStreamableHTTPHandler returns a new [StreamableHTTPHandler]. @@ -238,6 +240,9 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque jsonResponse: h.opts.JSONResponse, logger: h.opts.Logger, } + if h.opts.configureTransport != nil { + h.opts.configureTransport(req, transport) + } // To support stateless mode, we initialize the session with a default // state, so that it doesn't reject subsequent requests. @@ -335,16 +340,16 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque // // Reads from the streamable server connection receive messages from http POST // requests from the client. Writes to the streamable server connection are -// sent either to the hanging POST response, or to the hanging GET, according -// to the following rules: +// sent either to the related stream, or to the standalone SSE stream, +// according to the following rules: // - JSON-RPC responses to incoming requests are always routed to the // appropriate HTTP response. // - Requests or notifications made with a context.Context value derived from // an incoming request handler, are routed to the HTTP response // corresponding to that request, unless it has already terminated, in -// which case they are routed to the hanging GET. +// which case they are routed to the standalone SSE stream. // - Requests or notifications made with a detached context.Context value are -// routed to the hanging GET. +// routed to the standalone SSE stream. type StreamableServerTransport struct { // SessionID is the ID of this session. // @@ -372,15 +377,16 @@ type StreamableServerTransport struct { // Specifically, responses will be application/json whenever incoming POST // request contain only a single message. 
In this case, notifications or // requests made within the context of a server request will be sent to the - // hanging GET request, if any. + // standalone SSE stream, if any. // // TODO(rfindley): jsonResponse should be exported, since - // StreamableHTTPOptions.JSONResponse is exported. + // StreamableHTTPOptions.JSONResponse is exported, and we want to allow users + // to write their own streamable HTTP handler. jsonResponse bool // optional logger provided through the [StreamableHTTPOptions.Logger]. // - // TODO(rfindley): logger should be exported, since we want to allow people + // TODO(rfindley): logger should be exported, since we want to allow users // to write their own streamable HTTP handler. logger *slog.Logger @@ -404,15 +410,12 @@ func (t *StreamableServerTransport) Connect(ctx context.Context) (Connection, er streams: make(map[string]*stream), requestStreams: make(map[jsonrpc.ID]string), } - if t.connection.eventStore == nil { - t.connection.eventStore = NewMemoryEventStore(nil) - } - // Stream 0 corresponds to the hanging 'GET'. + // Stream 0 corresponds to the standalone SSE stream. // // It is always text/event-stream, since it must carry arbitrarily many // messages. var err error - t.connection.streams[""], err = t.connection.newStream(ctx, "", false, false) + t.connection.streams[""], err = t.connection.newStream(ctx, nil, "") if err != nil { return nil, err } @@ -449,13 +452,18 @@ type streamableServerConn struct { // bound. If we deleted a stream when the response is sent, we would lose the ability // to replay if there was a cut just before the response was transmitted. // Perhaps we could have a TTL for streams that starts just after the response. + // + // TODO(rfindley): Once all responses have been received for a stream, we can + // remove it as it is no longer necessary, even if the client wants to replay + // messages. streams map[string]*stream // requestStreams maps incoming requests to their logical stream ID. 
// // Lifecycle: requestStreams persist for the duration of the session. // - // TODO: clean up once requests are handled. See the TODO for streams above. + // TODO(rfindley): clean up once requests are handled. See the TODO for + // streams above. requestStreams map[jsonrpc.ID]string } @@ -472,58 +480,45 @@ func (c *streamableServerConn) SessionID() string { // at any time. type stream struct { // id is the logical ID for the stream, unique within a session. - // an empty string is used for messages that don't correlate with an incoming request. + // + // The standalone SSE stream has id "". id string - // If isInitialize is set, the stream is in response to an initialize request, - // and therefore should include the session ID header. - isInitialize bool - - // jsonResponse records whether this stream should respond with application/json - // instead of text/event-stream. - // - // See [StreamableServerTransportOptions.JSONResponse]. - jsonResponse bool + // mu guards the fields below, as well as storage of new messages in the + // connection's event store (if any). + mu sync.Mutex - // signal is a 1-buffered channel, owned by an incoming HTTP request, that signals - // that there are messages available to write into the HTTP response. - // In addition, the presence of a channel guarantees that at most one HTTP response - // can receive messages for a logical stream. After claiming the stream, incoming - // requests should read from the event store, to ensure that no new messages are missed. - // - // To simplify locking, signal is an atomic. We need an atomic.Pointer, because - // you can't set an atomic.Value to nil. + // If non-nil, deliver writes data directly to the HTTP response. // - // Lifecycle: each channel value persists for the duration of an HTTP POST or - // GET request for the given streamID. - signal atomic.Pointer[chan struct{}] + // Only one HTTP response may receive messages at a given time. 
An active + // HTTP connection acquires ownership of the stream by setting this field. + deliver func(data []byte, final bool) error - // The following mutable fields are protected by the mutex of the containing - // StreamableServerTransport. - - // streamRequests is the set of unanswered incoming RPCs for the stream. + // streamRequests is the set of unanswered incoming requests for the stream. // - // Requests persist until their response data has been added to the event store. + // Requests are removed when their response has been received. requests map[jsonrpc.ID]struct{} } -func (c *streamableServerConn) newStream(ctx context.Context, id string, isInitialize, jsonResponse bool) (*stream, error) { - if err := c.eventStore.Open(ctx, c.sessionID, id); err != nil { - return nil, err +// doneLocked reports whether the stream is logically complete. +// +// s.mu must be held while calling this function. +func (s *stream) doneLocked() bool { + return len(s.requests) == 0 && s.id != "" +} + +func (c *streamableServerConn) newStream(ctx context.Context, requests map[jsonrpc.ID]struct{}, id string) (*stream, error) { + if c.eventStore != nil { + if err := c.eventStore.Open(ctx, c.sessionID, id); err != nil { + return nil, err + } } return &stream{ - id: id, - isInitialize: isInitialize, - jsonResponse: jsonResponse, - requests: make(map[jsonrpc.ID]struct{}), + id: id, + requests: requests, }, nil } -func signalChanPtr() *chan struct{} { - c := make(chan struct{}, 1) - return &c -} - // We track the incoming request ID inside the handler context using // idContextValue, so that notifications and server->client calls that occur in // the course of handling incoming requests are correlated with the incoming @@ -571,37 +566,157 @@ func (t *StreamableServerTransport) ServeHTTP(w http.ResponseWriter, req *http.R // // It returns an HTTP status code and error message. 
func (c *streamableServerConn) serveGET(w http.ResponseWriter, req *http.Request) { - // connID 0 corresponds to the default GET request. - id := "" + // streamID "" corresponds to the default GET request. + streamID := "" // By default, we haven't seen a last index. Since indices start at 0, we represent - // that by -1. This is incremented just before each event is written, in streamResponse - // around L407. + // that by -1. This is incremented just before each event is written. lastIdx := -1 if len(req.Header.Values("Last-Event-ID")) > 0 { eid := req.Header.Get("Last-Event-ID") var ok bool - id, lastIdx, ok = parseEventID(eid) + streamID, lastIdx, ok = parseEventID(eid) if !ok { http.Error(w, fmt.Sprintf("malformed Last-Event-ID %q", eid), http.StatusBadRequest) return } + if c.eventStore == nil { + http.Error(w, "stream replay unsupported", http.StatusBadRequest) + return + } + } + + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + + stream, done := c.acquireStream(ctx, w, streamID, &lastIdx) + if stream == nil { + return + } + // Release the stream when we're done. + defer func() { + stream.mu.Lock() + stream.deliver = nil + stream.mu.Unlock() + }() + + select { + case <-ctx.Done(): + // request cancelled + case <-done: + // request complete + case <-c.done: + // session closed + } +} + +// writeEvent writes an SSE event to w corresponding to the given stream, data, and index. +// lastIdx is incremented before writing, so that it continues to point to the index of the +// last event written to the stream. 
+func (c *streamableServerConn) writeEvent(w http.ResponseWriter, stream *stream, data []byte, lastIdx *int) error { + *lastIdx++ + e := Event{ + Name: "message", + Data: data, + } + if c.eventStore != nil { + e.ID = formatEventID(stream.id, *lastIdx) } + if _, err := writeEvent(w, e); err != nil { + return err + } + return nil +} +// acquireStream acquires the stream and replays all events since lastIdx, if +// any, updating lastIdx accordingly. If non-nil, the resulting stream will be +// registered for receiving new messages, and the resulting done channel will +// be closed when all related messages have been delivered. +// +// If any errors occur, they will be written to w and the resulting stream will +// be nil. The resulting stream may also be nil if the stream is complete. +// +// Importantly, this function must hold the stream mutex until done replaying +// all messages, so that no delivery or storage of new messages occurs while +// the stream is still replaying. +func (c *streamableServerConn) acquireStream(ctx context.Context, w http.ResponseWriter, streamID string, lastIdx *int) (*stream, chan struct{}) { c.mu.Lock() - stream, ok := c.streams[id] + stream, ok := c.streams[streamID] c.mu.Unlock() if !ok { http.Error(w, "unknown stream", http.StatusBadRequest) - return + return nil, nil } - if !stream.signal.CompareAndSwap(nil, signalChanPtr()) { - // The CAS returned false, meaning that the comparison failed: stream.signal is not nil. + + stream.mu.Lock() + defer stream.mu.Unlock() + if stream.deliver != nil { http.Error(w, "stream ID conflicts with ongoing stream", http.StatusConflict) - return + return nil, nil + } + + // Collect events to replay. Collect them all before writing, so that we + // have an opportunity to set the HTTP status code on an error. 
+ // + // As indicated above, we must do that while holding stream.mu, so that no + // new messages are added to the eventstore until we've replayed all previous + // messages, and registered our delivery function. + var toReplay [][]byte + if c.eventStore != nil { + for data, err := range c.eventStore.After(ctx, c.SessionID(), stream.id, *lastIdx) { + if err != nil { + // We can't replay events, perhaps because the underlying event store + // has garbage collected its storage. + // + // We must be careful here: any 404 will signal to the client that the + // *session* is not found, rather than the stream. + // + // 400 is not really accurate, but should at least have no side effects. + // Other SDKs (typescript) do not have a mechanism for events to be purged. + http.Error(w, "failed to replay events", http.StatusBadRequest) + return nil, nil + } + toReplay = append(toReplay, data) + } + } + + w.Header().Set("Cache-Control", "no-cache, no-transform") + w.Header().Set("Content-Type", "text/event-stream") // Accept checked in [StreamableHTTPHandler] + w.Header().Set("Connection", "keep-alive") + + if stream.id == "" { + // Issue #410: the standalone SSE stream is likely not to receive messages + // for a long time. Ensure that headers are flushed. + w.WriteHeader(http.StatusOK) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + } + + for _, data := range toReplay { + if err := c.writeEvent(w, stream, data, lastIdx); err != nil { + return nil, nil + } + } + + if stream.doneLocked() { + // Nothing more to do. + return nil, nil + } + + // Finally register a delivery function and unlock the stream, allowing the + // connection to write new events. 
+ done := make(chan struct{}) + stream.deliver = func(data []byte, final bool) error { + if err := ctx.Err(); err != nil { + return err + } + err := c.writeEvent(w, stream, data, lastIdx) + if final { + close(done) + } + return err } - defer stream.signal.Store(nil) - persistent := id == "" // Only the special stream "" is a hanging get. - c.respondSSE(stream, w, req, lastIdx, persistent) + return stream, done } // servePOST handles an incoming message, and replies with either an outgoing @@ -641,7 +756,14 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques return } - requests := make(map[jsonrpc.ID]struct{}) + // TODO(rfindley): no tests fail if we reject batch JSON requests entirely. + // We need to test this with older protocol versions. + // if isBatch && c.jsonResponse { + // http.Error(w, "server does not support batch requests", http.StatusBadRequest) + // return + // } + + calls := make(map[jsonrpc.ID]struct{}) tokenInfo := auth.TokenInfoFromContext(req.Context()) isInitialize := false for _, msg := range incoming { @@ -661,204 +783,134 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques Header: req.Header, } if jreq.IsCall() { - requests[jreq.ID] = struct{}{} + calls[jreq.ID] = struct{}{} } } } - var stream *stream // if non-nil, used to handle requests - - // If we have requests, we need to handle responses along with any - // notifications or server->client requests made in the course of handling. - // Update accounting for this incoming payload. 
- if len(requests) > 0 { - stream, err = c.newStream(req.Context(), randText(), isInitialize, c.jsonResponse) - if err != nil { - http.Error(w, fmt.Sprintf("storing stream: %v", err), http.StatusInternalServerError) - return - } - c.mu.Lock() - c.streams[stream.id] = stream - stream.requests = requests - for reqID := range requests { - c.requestStreams[reqID] = stream.id + // If we don't have any calls, we can just publish the incoming messages and return. + // No need to track a logical stream. + if len(calls) == 0 { + for _, msg := range incoming { + select { + case c.incoming <- msg: + case <-c.done: + // The session is closing. Since we haven't yet written any data to the + // response, we can signal to the client that the session is gone. + http.Error(w, "session is closing", http.StatusNotFound) + return + } } - c.mu.Unlock() - stream.signal.Store(signalChanPtr()) - defer stream.signal.Store(nil) - } - - // Publish incoming messages. - for _, msg := range incoming { - c.incoming <- msg - } - - if stream == nil { w.WriteHeader(http.StatusAccepted) return } - if stream.jsonResponse { - c.respondJSON(stream, w, req) - } else { - c.respondSSE(stream, w, req, -1, false) + // Invariant: we have at least one call. + // + // Create a logical stream to track its responses. + // Important: don't publish the incoming messages until the stream is + // registered, as the server may attempt to respond to incoming messages as + // soon as they're published. + stream, err := c.newStream(req.Context(), calls, randText()) + if err != nil { + http.Error(w, fmt.Sprintf("storing stream: %v", err), http.StatusInternalServerError) + return } -} -func (c *streamableServerConn) respondJSON(stream *stream, w http.ResponseWriter, req *http.Request) { + // Set response headers. Accept was checked in [StreamableHTTPHandler]. 
w.Header().Set("Cache-Control", "no-cache, no-transform") - w.Header().Set("Content-Type", "application/json") - if c.sessionID != "" && stream.isInitialize { + if c.jsonResponse { + w.Header().Set("Content-Type", "application/json") + } else { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Connection", "keep-alive") + } + if c.sessionID != "" && isInitialize { w.Header().Set(sessionIDHeader, c.sessionID) } - var msgs []json.RawMessage - ctx := req.Context() - for msg, err := range c.messages(ctx, stream, false, -1) { - if err != nil { - if ctx.Err() != nil { - w.WriteHeader(http.StatusNoContent) - return + // Message delivery has two paths, depending on whether we're responding with JSON or + // event stream. + done := make(chan struct{}) // closed after the final response is written + if c.jsonResponse { + var msgs []json.RawMessage + stream.deliver = func(data []byte, final bool) error { + // Collect messages until we've received the final response. + // + // In recent protocol versions, there should only be one message as + // batching is disabled, as checked above. + msgs = append(msgs, data) + if !final { + return nil + } + defer close(done) // final response + + // Write either the JSON object corresponding to the one response, or a + // JSON array corresponding to the batch response. + var toWrite []byte + if len(msgs) == 1 { + toWrite = []byte(msgs[0]) } else { - http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) - return + var err error + toWrite, err = json.Marshal(msgs) + if err != nil { + return err + } } + _, err = w.Write(toWrite) + return err } - msgs = append(msgs, msg) - } - var data []byte - if len(msgs) == 1 { - data = []byte(msgs[0]) } else { - // TODO: add tests for batch responses, or disallow them entirely. 
- var err error - data, err = json.Marshal(msgs) - if err != nil { - http.Error(w, fmt.Sprintf("internal error marshalling response: %v", err), http.StatusInternalServerError) - return + // Write events in the order we receive them. + lastIndex := -1 + stream.deliver = func(data []byte, final bool) error { + if final { + defer close(done) + } + return c.writeEvent(w, stream, data, &lastIndex) } } - _, _ = w.Write(data) // ignore error: client disconnected -} -// lastIndex is the index of the last seen event if resuming, else -1. -func (c *streamableServerConn) respondSSE(stream *stream, w http.ResponseWriter, req *http.Request, lastIndex int, persistent bool) { - // Accept was checked in [StreamableHTTPHandler] - w.Header().Set("Cache-Control", "no-cache, no-transform") - w.Header().Set("Content-Type", "text/event-stream") // Accept checked in [StreamableHTTPHandler] - w.Header().Set("Connection", "keep-alive") - if c.sessionID != "" && stream.isInitialize { - w.Header().Set(sessionIDHeader, c.sessionID) - } - if persistent { - // Issue #410: the hanging GET is likely not to receive messages for a long - // time. Ensure that headers are flushed. - // - // For non-persistent requests, delay the writing of the header in case we - // may want to set an error status. - // (see the TODO: this probably isn't worth it). - w.WriteHeader(http.StatusOK) - if f, ok := w.(http.Flusher); ok { - f.Flush() - } - } + // Release ownership of the stream by unsetting deliver. + defer func() { + stream.mu.Lock() + // TODO(rfindley): if we have no event store, we should really cancel all + // remaining requests here, since the client will never get the results. + stream.deliver = nil + stream.mu.Unlock() + }() - // write one event containing data. - writes := 0 - write := func(data []byte) bool { - lastIndex++ - e := Event{ - Name: "message", - ID: formatEventID(stream.id, lastIndex), - Data: data, - } - if _, err := writeEvent(w, e); err != nil { - // Connection closed or broken. 
- c.logger.Warn("error writing event", "error", err) - return false - } - writes++ - return true + // The stream is now set up to deliver messages. + // + // Register it before publishing incoming messages. + c.mu.Lock() + c.streams[stream.id] = stream + for reqID := range calls { + c.requestStreams[reqID] = stream.id } + c.mu.Unlock() - // Repeatedly collect pending outgoing events and send them. - ctx := req.Context() - for msg, err := range c.messages(ctx, stream, persistent, lastIndex) { - if err != nil { - if ctx.Err() == nil && writes == 0 && !persistent { - // If we haven't yet written the header, we have an opportunity to - // promote an error to an HTTP error. - // - // TODO: This may not matter in practice, in which case we should - // simplify. - http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) - } else { - if ctx.Err() != nil { - // Client disconnected or cancelled the request. - c.logger.Error("stream context done", "error", ctx.Err()) - } else { - // Some other error. - c.logger.Error("error receiving message", "error", err) - } - } - return - } - if !write(msg) { + // Publish incoming messages. + for _, msg := range incoming { + select { + case c.incoming <- msg: + // Note: don't select on req.Context().Done() here, since we've already + // received the requests and may have already published a response message + // or notification. The client could resume the stream. + case <-c.done: + // Session closed: we don't know if any data has been written, so it's + // too late to write a status code here. return } } -} - -// messages iterates over messages sent to the current stream. -// -// persistent indicates if it is the main GET listener, which should never be -// terminated. -// lastIndex is the index of the last seen event, iteration begins at lastIndex+1. -// -// The first iterated value is the received JSON message. 
The second iterated -// value is an error value indicating whether the stream terminated normally. -// Iteration ends at the first non-nil error. -// -// If the stream did not terminate normally, it is either because ctx was -// cancelled, or the connection is closed: check the ctx.Err() to differentiate -// these cases. -func (c *streamableServerConn) messages(ctx context.Context, stream *stream, persistent bool, lastIndex int) iter.Seq2[json.RawMessage, error] { - return func(yield func(json.RawMessage, error) bool) { - for { - c.mu.Lock() - nOutstanding := len(stream.requests) - c.mu.Unlock() - for data, err := range c.eventStore.After(ctx, c.SessionID(), stream.id, lastIndex) { - if err != nil { - yield(nil, err) - return - } - if !yield(data, nil) { - return - } - lastIndex++ - } - // If all requests have been handled and replied to, we should terminate this connection. - // "After the JSON-RPC response has been sent, the server SHOULD close the SSE stream." - // §6.4, https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server - // We only want to terminate POSTs, and GETs that are replaying. The general-purpose GET - // (stream ID 0) will never have requests, and should remain open indefinitely. - if nOutstanding == 0 && !persistent { - return - } - - select { - case <-*stream.signal.Load(): // there are new outgoing messages - // return to top of loop - case <-c.done: // session is closed - yield(nil, errors.New("session is closed")) - return - case <-ctx.Done(): - yield(nil, ctx.Err()) - return - } - } + select { + case <-req.Context().Done(): + // request cancelled + case <-done: + // request complete + case <-c.done: + // session is closed } } @@ -907,80 +959,93 @@ func (c *streamableServerConn) Read(ctx context.Context) (jsonrpc.Message, error // Write implements the [Connection] interface. 
func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) error { + // Throughout this function, note that any error that wraps ErrRejected + // does not cause the connection to break. + // + // Most errors don't break the connection: unlike a true bidirectional + // stream, a failure to deliver to a stream is not an indication that the + // logical session is broken. + data, err := jsonrpc2.EncodeMessage(msg) + if err != nil { + return err + } + if req, ok := msg.(*jsonrpc.Request); ok && req.ID.IsValid() && (c.stateless || c.sessionID == "") { // Requests aren't possible with stateless servers, or when there's no session ID. return fmt.Errorf("%w: stateless servers cannot make requests", jsonrpc2.ErrRejected) } + // Find the incoming request that this write relates to, if any. - var forRequest jsonrpc.ID - isResponse := false + var ( + relatedRequest jsonrpc.ID + responseTo jsonrpc.ID // if valid, the message is a response to this request + ) if resp, ok := msg.(*jsonrpc.Response); ok { // If the message is a response, it relates to its request (of course). - forRequest = resp.ID - isResponse = true + relatedRequest = resp.ID + responseTo = resp.ID } else { // Otherwise, we check to see if it request was made in the context of an // ongoing request. This may not be the case if the request was made with // an unrelated context. if v := ctx.Value(idContextKey{}); v != nil { - forRequest = v.(jsonrpc.ID) + relatedRequest = v.(jsonrpc.ID) } } - // Find the logical connection corresponding to this request. - // - // For messages sent outside of a request context, this is the default - // connection "". 
- var forStream string - if forRequest.IsValid() { - c.mu.Lock() - forStream = c.requestStreams[forRequest] - c.mu.Unlock() - } - - data, err := jsonrpc2.EncodeMessage(msg) - if err != nil { - return err + // If the stream is application/json, but the message is not a response, we + // must send it out of band to the standalone SSE stream. + if c.jsonResponse && !responseTo.IsValid() { + relatedRequest = jsonrpc.ID{} } + // Write the message to the stream. + var s *stream c.mu.Lock() - defer c.mu.Unlock() - if c.isDone { - return errors.New("session is closed") + if relatedRequest.IsValid() { + if streamID, ok := c.requestStreams[relatedRequest]; ok { + s = c.streams[streamID] + } + } else { + s = c.streams[""] } + sessionClosed := c.isDone + c.mu.Unlock() - stream := c.streams[forStream] - if stream == nil { - return fmt.Errorf("no stream with ID %s", forStream) + if s == nil { + return fmt.Errorf("%w: no stream for request", jsonrpc2.ErrRejected) } - - // Special case a few conditions where we fall back on stream 0 (the hanging GET): - // - // - if forStream is known, but the associated stream is logically complete - // - if the stream is application/json, but the message is not a response - // - // TODO(rfindley): either of these, particularly the first, might be - // considered a bug in the server. Report it through a side-channel? - if len(stream.requests) == 0 && forStream != "" || stream.jsonResponse && !isResponse { - stream = c.streams[""] + if sessionClosed { + return errors.New("session is closed") } - if err := c.eventStore.Append(ctx, c.SessionID(), stream.id, data); err != nil { - return fmt.Errorf("error storing event: %w", err) + s.mu.Lock() + defer s.mu.Unlock() + if s.doneLocked() { + return fmt.Errorf("%w: write to closed stream", jsonrpc2.ErrRejected) } - if isResponse { - // Once we've put the reply on the queue, it's no longer outstanding. 
- delete(stream.requests, forRequest) + if responseTo.IsValid() { + delete(s.requests, responseTo) } - // Signal streamResponse that new work is available. - signalp := stream.signal.Load() - if signalp != nil { - select { - case *signalp <- struct{}{}: - default: + delivered := false + if c.eventStore != nil { + if err := c.eventStore.Append(ctx, c.sessionID, s.id, data); err != nil { + // TODO: report a side-channel error. + } else { + delivered = true } } + if s.deliver != nil { + if err := s.deliver(data, s.doneLocked()); err != nil { + // TODO: report a side-channel error. + } else { + delivered = true + } + } + if !delivered { + return fmt.Errorf("%w: undelivered message", jsonrpc2.ErrRejected) + } return nil } @@ -991,9 +1056,11 @@ func (c *streamableServerConn) Close() error { if !c.isDone { c.isDone = true close(c.done) - // TODO: find a way to plumb a context here, or an event store with a long-running - // close operation can take arbitrary time. Alternative: impose a fixed timeout here. - return c.eventStore.SessionClosed(context.TODO(), c.sessionID) + if c.eventStore != nil { + // TODO: find a way to plumb a context here, or an event store with a long-running + // close operation can take arbitrary time. Alternative: impose a fixed timeout here. + return c.eventStore.SessionClosed(context.TODO(), c.sessionID) + } } return nil } @@ -1118,7 +1185,7 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) { c.initializedResult = state.InitializeResult c.mu.Unlock() - // Start the persistent SSE listener as soon as we have the initialized + // Start the standalone SSE stream as soon as we have the initialized // result. // // § 2.2: The client MAY issue an HTTP GET to the MCP endpoint. 
This can be @@ -1131,7 +1198,7 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) { // § 2.5: A server using the Streamable HTTP transport MAY assign a session // ID at initialization time, by including it in an Mcp-Session-Id header // on the HTTP response containing the InitializeResult. - go c.handleSSE("hanging GET", nil, true, nil) + go c.handleSSE("standalone SSE stream", nil, true, nil) } // fail handles an asynchronous error while reading. @@ -1272,6 +1339,7 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e case "text/event-stream": jsonReq, _ := msg.(*jsonrpc.Request) + // TODO: should we cancel this logical SSE request if/when jsonReq is canceled? go c.handleSSE(requestSummary, resp, false, jsonReq) default: @@ -1328,6 +1396,11 @@ func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *htt resp := initialResp var lastEventID string for { + // TODO: we should set a reasonable limit on the number of times we'll try + // getting a response for a given request. + // + // Eventually, if we don't get the response, we should stop trying and + // fail the request. if resp != nil { eventID, clientClosed := c.processStream(requestSummary, resp, forReq) lastEventID = eventID @@ -1359,17 +1432,17 @@ func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *htt // // [§2.2.3]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#listening-for-messages-from-the-server - // The server doesn't support the hanging GET. + // The server doesn't support the standalone SSE stream. resp.Body.Close() return } if resp.StatusCode == http.StatusNotFound && persistent && !c.strict { // modelcontextprotocol/gosdk#393: some servers return NotFound instead - // of MethodNotAllowed for the persistent GET. + // of MethodNotAllowed for the standalone SSE stream. // // Treat this like MethodNotAllowed in non-strict mode. 
if c.logger != nil { - c.logger.Warn("got 404 instead of 405 for hanging GET") + c.logger.Warn("got 404 instead of 405 for standalonw SSE stream") } resp.Body.Close() return @@ -1412,8 +1485,8 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R select { case c.incoming <- msg: if jsonResp, ok := msg.(*jsonrpc.Response); ok && forReq != nil { - // TODO: we should never get a response when forReq is nil (the hanging GET). - // We should detect this case, and eliminate the 'persistent' flag arguments. + // TODO: we should never get a response when forReq is nil (the standalone SSE request). + // We should detect this case. if jsonResp.ID == forReq.ID { return "", true } @@ -1433,7 +1506,7 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R func (c *streamableClientConn) reconnect(lastEventID string) (*http.Response, error) { var finalErr error - // We can reach the 'reconnect' path through the hanging GET, in which case + // We can reach the 'reconnect' path through the standalone SSE request, in which case // lastEventID will be "". // // In this case, we need an initial attempt. 
diff --git a/mcp/streamable_client_test.go b/mcp/streamable_client_test.go index 9116677b..faca04c6 100644 --- a/mcp/streamable_client_test.go +++ b/mcp/streamable_client_test.go @@ -246,7 +246,7 @@ func TestStreamableClientGETHandling(t *testing.T) { }{ {http.StatusOK, ""}, {http.StatusMethodNotAllowed, ""}, - {http.StatusBadRequest, "hanging GET"}, + {http.StatusBadRequest, "standalone SSE"}, } for _, test := range tests { diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 3c0b8be3..a0893689 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -41,8 +41,18 @@ func TestStreamableTransports(t *testing.T) { ctx := context.Background() - for _, useJSON := range []bool{false, true} { - t.Run(fmt.Sprintf("JSONResponse=%v", useJSON), func(t *testing.T) { + tests := []struct { + useJSON bool + replay bool + }{ + {false, false}, + {false, true}, + {true, false}, + {true, true}, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("JSONResponse=%v;replay=%v", test.useJSON, test.replay), func(t *testing.T) { // Create a server with some simple tools. server := NewServer(testImpl, nil) AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi) @@ -86,7 +96,12 @@ func TestStreamableTransports(t *testing.T) { // Start an httptest.Server with the StreamableHTTPHandler, wrapped in a // cookie-checking middleware. 
handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, &StreamableHTTPOptions{ - JSONResponse: useJSON, + JSONResponse: test.useJSON, + configureTransport: func(_ *http.Request, transport *StreamableServerTransport) { + if test.replay { + transport.EventStore = NewMemoryEventStore(nil) + } + }, }) var ( @@ -370,7 +385,11 @@ func testClientReplay(t *testing.T, test clientReplayTest) { return new(CallToolResult), nil, nil }) - realServer := httptest.NewServer(mustNotPanic(t, NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil))) + realServer := httptest.NewServer(mustNotPanic(t, NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{ + configureTransport: func(_ *http.Request, t *StreamableServerTransport) { + t.EventStore = NewMemoryEventStore(nil) // necessary for replay + }, + }))) t.Cleanup(func() { t.Log("Closing real HTTP server") realServer.Close() @@ -543,7 +562,16 @@ func TestServerInitiatedSSE(t *testing.T) { notifications := make(chan string) server := NewServer(testImpl, nil) - httpServer := httptest.NewServer(mustNotPanic(t, NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil))) + opts := &StreamableHTTPOptions{ + // TODO(#583): for now, this is required for guaranteed message delivery. + // However, it shouldn't be necessary to use replay here, as we should be + // guaranteed that the standalone SSE stream is started by the time the + // client is connected. 
+ configureTransport: func(_ *http.Request, transport *StreamableServerTransport) { + transport.EventStore = NewMemoryEventStore(nil) + }, + } + httpServer := httptest.NewServer(mustNotPanic(t, NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, opts))) defer httpServer.Close() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -640,6 +668,7 @@ func TestStreamableServerTransport(t *testing.T) { tests := []struct { name string + replay bool // if set, use a MemoryEventStore to enable stream replay tool func(*testing.T, context.Context, *ServerSession) requests []streamableRequest // http requests }{ @@ -804,10 +833,15 @@ func TestStreamableServerTransport(t *testing.T) { }, { name: "background", + // Enabling replay is necessary here because the standalone "GET" request + // is fully asynchronous. Replay is needed to guarantee message delivery. + replay: true, tool: func(t *testing.T, _ context.Context, ss *ServerSession) { // Perform operations on a background context, and ensure the client // receives it. - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + if err := ss.NotifyProgress(ctx, &ProgressNotificationParams{}); err != nil { t.Errorf("Notify failed: %v", err) } @@ -816,7 +850,7 @@ func TestStreamableServerTransport(t *testing.T) { // t.Errorf("Logging failed: %v", err) // } if _, err := ss.ListRoots(ctx, &ListRootsParams{}); err != nil { - t.Errorf("Notify failed: %v", err) + t.Errorf("ListRoots failed: %v", err) } }, requests: []streamableRequest{ @@ -906,8 +940,14 @@ func TestStreamableServerTransport(t *testing.T) { return &CallToolResult{}, nil }) + opts := &StreamableHTTPOptions{} + if test.replay { + opts.configureTransport = func(_ *http.Request, t *StreamableServerTransport) { + t.EventStore = NewMemoryEventStore(nil) + } + } // Start the streamable handler. 
- handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, opts) defer handler.closeAll() testStreamableHandler(t, handler, test.requests) diff --git a/mcp/transport.go b/mcp/transport.go index 1beab470..cacd65fd 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -46,7 +46,7 @@ type Connection interface { // Write writes a new message to the connection. // - // Write may be called concurrently, as calls or reponses may occur + // Write may be called concurrently, as calls or responses may occur // concurrently in user code. Write(context.Context, jsonrpc.Message) error