diff --git a/mcp/streamable.go b/mcp/streamable.go index e15c6a69..7b814012 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -204,11 +204,11 @@ type StreamableServerTransport struct { id string opts StreamableServerTransportOptions incoming chan jsonrpc.Message // messages from the client to the server + done chan struct{} mu sync.Mutex // Sessions are closed exactly once. isDone bool - done chan struct{} // Sessions can have multiple logical connections, corresponding to HTTP // requests. Additionally, logical sessions may be resumed by subsequent HTTP @@ -242,23 +242,26 @@ type stream struct { // ID 0 is used for messages that don't correlate with an incoming request. id StreamID - // These mutable fields are protected by the mutex of the corresponding StreamableServerTransport. + // signal is a 1-buffered channel, owned by an incoming HTTP request, that signals + // that there are messages available to write into the HTTP response. + // In addition, the presence of a channel guarantees that at most one HTTP response + // can receive messages for a logical stream. After claiming the stream, incoming + // requests should read from outgoing, to ensure that no new messages are missed. + // + // To simplify locking, signal is an atomic. We need an atomic.Pointer, because + // you can't set an atomic.Value to nil. + // + // Lifecycle: each channel value persists for the duration of an HTTP POST or + // GET request for the given streamID. + signal atomic.Pointer[chan struct{}] + + // The following mutable fields are protected by the mutex of the containing + // StreamableServerTransport. // outgoing is the list of outgoing messages, enqueued by server methods that // write notifications and responses, and dequeued by streamResponse. outgoing [][]byte - // signal is a 1-buffered channel, owned by an - // incoming HTTP request, that signals that there are messages available to - // write into the HTTP response. This guarantees that at most one HTTP - // response can receive messages for a logical stream. After claiming - // the stream, incoming requests should read from outgoing, to ensure - // that no new messages are missed. - // - // Lifecycle: persists for the duration of an HTTP POST or GET - // request for the given streamID. - signal chan struct{} - // streamRequests is the set of unanswered incoming RPCs for the stream. // // Lifecycle: requests values persist until the requests have been @@ -274,6 +277,11 @@ func newStream(id StreamID) *stream { } } +func signalChanPtr() *chan struct{} { + c := make(chan struct{}, 1) + return &c +} + // A StreamID identifies a stream of SSE events. It is unique within the stream's // [ServerSession]. type StreamID int64 @@ -310,19 +318,25 @@ type idContextKey struct{} // ServeHTTP handles a single HTTP request for the session. func (t *StreamableServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) { + status := 0 + message := "" switch req.Method { case http.MethodGet: - t.serveGET(w, req) + status, message = t.serveGET(w, req) case http.MethodPost: - t.servePOST(w, req) + status, message = t.servePOST(w, req) default: // Should not be reached, as this is checked in StreamableHTTPHandler.ServeHTTP. w.Header().Set("Allow", "GET, POST") - http.Error(w, "unsupported method", http.StatusMethodNotAllowed) + status = http.StatusMethodNotAllowed + message = "unsupported method" + } + if status != 0 && status != http.StatusOK { + http.Error(w, message, status) } } -func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Request) { +func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Request) (int, string) { // connID 0 corresponds to the default GET request. id := StreamID(0) // By default, we haven't seen a last index. Since indices start at 0, we represent @@ -334,49 +348,39 @@ func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Re var ok bool id, lastIdx, ok = parseEventID(eid) if !ok { - http.Error(w, fmt.Sprintf("malformed Last-Event-ID %q", eid), http.StatusBadRequest) - return + return http.StatusBadRequest, fmt.Sprintf("malformed Last-Event-ID %q", eid) } } t.mu.Lock() stream, ok := t.streams[id] + t.mu.Unlock() if !ok { - http.Error(w, "unknown stream", http.StatusBadRequest) - t.mu.Unlock() - return + return http.StatusBadRequest, "unknown stream" } - if stream.signal != nil { - http.Error(w, "stream ID conflicts with ongoing stream", http.StatusBadRequest) - t.mu.Unlock() - return + if !stream.signal.CompareAndSwap(nil, signalChanPtr()) { + // The CAS returned false, meaning that the comparison failed: stream.signal is not nil. + return http.StatusBadRequest, "stream ID conflicts with ongoing stream" } - stream.signal = make(chan struct{}, 1) - t.mu.Unlock() - - t.streamResponse(stream, w, req, lastIdx) + return t.streamResponse(stream, w, req, lastIdx) } -func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.Request) { +func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.Request) (int, string) { if len(req.Header.Values("Last-Event-ID")) > 0 { - http.Error(w, "can't send Last-Event-ID for POST request", http.StatusBadRequest) - return + return http.StatusBadRequest, "can't send Last-Event-ID for POST request" } // Read incoming messages. body, err := io.ReadAll(req.Body) if err != nil { - http.Error(w, "failed to read body", http.StatusBadRequest) - return + return http.StatusBadRequest, "failed to read body" } if len(body) == 0 { - http.Error(w, "POST requires a non-empty body", http.StatusBadRequest) - return + return http.StatusBadRequest, "POST requires a non-empty body" } incoming, _, err := readBatch(body) if err != nil { - http.Error(w, fmt.Sprintf("malformed payload: %v", err), http.StatusBadRequest) - return + return http.StatusBadRequest, fmt.Sprintf("malformed payload: %v", err) } requests := make(map[jsonrpc.ID]struct{}) for _, msg := range incoming { @@ -396,8 +400,8 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R t.requestStreams[reqID] = stream.id stream.requests[reqID] = struct{}{} } - stream.signal = make(chan struct{}, 1) t.mu.Unlock() + stream.signal.Store(signalChanPtr()) // Publish incoming messages. for _, msg := range incoming { @@ -407,23 +411,12 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R // TODO(rfindley): consider optimizing for a single incoming request, by // responding with application/json when there is only a single message in // the response. - t.streamResponse(stream, w, req, -1) + return t.streamResponse(stream, w, req, -1) } // lastIndex is the index of the last seen event if resuming, else -1. -func (t *StreamableServerTransport) streamResponse(stream *stream, w http.ResponseWriter, req *http.Request, lastIndex int) { - defer func() { - t.mu.Lock() - stream.signal = nil - t.mu.Unlock() - }() - - t.mu.Lock() - // Although there is a gap in locking between when stream.signal is set and here, - // it cannot change, because it is changed only when non-nil, and it is only - // set to nil in the defer above. - signal := stream.signal - t.mu.Unlock() +func (t *StreamableServerTransport) streamResponse(stream *stream, w http.ResponseWriter, req *http.Request, lastIndex int) (int, string) { + defer stream.signal.Store(nil) writes := 0 @@ -437,7 +430,7 @@ func (t *StreamableServerTransport) streamResponse(stream *stream, w http.Respon } if _, err := writeEvent(w, e); err != nil { // Connection closed or broken. - // TODO: log when we add server-side logging. + // TODO(#170): log when we add server-side logging. return false } writes++ @@ -460,13 +453,12 @@ func (t *StreamableServerTransport) streamResponse(stream *stream, w http.Respon if errors.Is(err, ErrEventsPurged) { status = http.StatusInsufficientStorage } - http.Error(w, err.Error(), status) - return + return status, err.Error() } // The iterator yields events beginning just after lastIndex, or it would have // yielded an error. if !write(data) { - return + return 0, "" } } } @@ -481,11 +473,10 @@ stream: for _, data := range outgoing { if err := t.opts.EventStore.Append(req.Context(), t.SessionID(), stream.id, data); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return + return http.StatusInternalServerError, err.Error() } if !write(data) { - return + return 0, "" } } @@ -495,22 +486,22 @@ stream: // If all requests have been handled and replied to, we should terminate this connection. // "After the JSON-RPC response has been sent, the server SHOULD close the SSE stream." // ยง6.4, https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server - // TODO(jba): why not terminate regardless of http method? + // TODO(jba,findleyr): why not terminate regardless of http method? if req.Method == http.MethodPost && nOutstanding == 0 { if writes == 0 { // Spec: If the server accepts the input, the server MUST return HTTP // status code 202 Accepted with no body. w.WriteHeader(http.StatusAccepted) } - return + return 0, "" } select { - case <-signal: // there are new outgoing messages + case <-*stream.signal.Load(): // there are new outgoing messages // return to top of loop case <-t.done: // session is closed if writes == 0 { - http.Error(w, "session terminated", http.StatusGone) + return http.StatusGone, "session terminated" } break stream case <-req.Context().Done(): @@ -520,6 +511,7 @@ stream: break stream } } + return 0, "" } // Event IDs: encode both the logical connection ID and the index, as @@ -625,10 +617,11 @@ func (t *StreamableServerTransport) Write(ctx context.Context, msg jsonrpc.Messa delete(stream.requests, replyTo) } - // Signal work. - if stream.signal != nil { + // Signal streamResponse that new work is available. + signalp := stream.signal.Load() + if signalp != nil { select { - case stream.signal <- struct{}{}: + case *signalp <- struct{}{}: default: } }