Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 61 additions & 68 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,11 +204,11 @@ type StreamableServerTransport struct {
id string
opts StreamableServerTransportOptions
incoming chan jsonrpc.Message // messages from the client to the server
done chan struct{}

mu sync.Mutex
// Sessions are closed exactly once.
isDone bool
done chan struct{}

// Sessions can have multiple logical connections, corresponding to HTTP
// requests. Additionally, logical sessions may be resumed by subsequent HTTP
Expand Down Expand Up @@ -242,23 +242,26 @@ type stream struct {
// ID 0 is used for messages that don't correlate with an incoming request.
id StreamID

// These mutable fields are protected by the mutex of the corresponding StreamableServerTransport.
// signal is a 1-buffered channel, owned by an incoming HTTP request, that signals
// that there are messages available to write into the HTTP response.
// In addition, the presence of a channel guarantees that at most one HTTP response
// can receive messages for a logical stream. After claiming the stream, incoming
// requests should read from outgoing, to ensure that no new messages are missed.
//
// To simplify locking, signal is an atomic. We need an atomic.Pointer, because
// you can't set an atomic.Value to nil.
//
// Lifecycle: each channel value persists for the duration of an HTTP POST or
// GET request for the given streamID.
signal atomic.Pointer[chan struct{}]

// The following mutable fields are protected by the mutex of the containing
// StreamableServerTransport.

// outgoing is the list of outgoing messages, enqueued by server methods that
// write notifications and responses, and dequeued by streamResponse.
outgoing [][]byte

// signal is a 1-buffered channel, owned by an
// incoming HTTP request, that signals that there are messages available to
// write into the HTTP response. This guarantees that at most one HTTP
// response can receive messages for a logical stream. After claiming
// the stream, incoming requests should read from outgoing, to ensure
// that no new messages are missed.
//
// Lifecycle: persists for the duration of an HTTP POST or GET
// request for the given streamID.
signal chan struct{}

// streamRequests is the set of unanswered incoming RPCs for the stream.
//
// Lifecycle: requests values persist until the requests have been
Expand All @@ -274,6 +277,11 @@ func newStream(id StreamID) *stream {
}
}

func signalChanPtr() *chan struct{} {
c := make(chan struct{}, 1)
return &c
}

// A StreamID identifies a stream of SSE events. It is unique within the stream's
// [ServerSession].
type StreamID int64
Expand Down Expand Up @@ -310,19 +318,25 @@ type idContextKey struct{}

// ServeHTTP handles a single HTTP request for the session.
func (t *StreamableServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) {
status := 0
message := ""
switch req.Method {
case http.MethodGet:
t.serveGET(w, req)
status, message = t.serveGET(w, req)
case http.MethodPost:
t.servePOST(w, req)
status, message = t.servePOST(w, req)
default:
// Should not be reached, as this is checked in StreamableHTTPHandler.ServeHTTP.
w.Header().Set("Allow", "GET, POST")
http.Error(w, "unsupported method", http.StatusMethodNotAllowed)
status = http.StatusMethodNotAllowed
message = "unsupported method"
}
if status != 0 && status != http.StatusOK {
http.Error(w, message, status)
}
}

func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Request) {
func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Request) (int, string) {
// connID 0 corresponds to the default GET request.
id := StreamID(0)
// By default, we haven't seen a last index. Since indices start at 0, we represent
Expand All @@ -334,49 +348,39 @@ func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Re
var ok bool
id, lastIdx, ok = parseEventID(eid)
if !ok {
http.Error(w, fmt.Sprintf("malformed Last-Event-ID %q", eid), http.StatusBadRequest)
return
return http.StatusBadRequest, fmt.Sprintf("malformed Last-Event-ID %q", eid)
}
}

t.mu.Lock()
stream, ok := t.streams[id]
t.mu.Unlock()
if !ok {
http.Error(w, "unknown stream", http.StatusBadRequest)
t.mu.Unlock()
return
return http.StatusBadRequest, "unknown stream"
}
if stream.signal != nil {
http.Error(w, "stream ID conflicts with ongoing stream", http.StatusBadRequest)
t.mu.Unlock()
return
if !stream.signal.CompareAndSwap(nil, signalChanPtr()) {
// The CAS returned false, meaning that the comparison failed: stream.signal is not nil.
return http.StatusBadRequest, "stream ID conflicts with ongoing stream"
}
stream.signal = make(chan struct{}, 1)
t.mu.Unlock()

t.streamResponse(stream, w, req, lastIdx)
return t.streamResponse(stream, w, req, lastIdx)
}

func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.Request) {
func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.Request) (int, string) {
if len(req.Header.Values("Last-Event-ID")) > 0 {
http.Error(w, "can't send Last-Event-ID for POST request", http.StatusBadRequest)
return
return http.StatusBadRequest, "can't send Last-Event-ID for POST request"
}

// Read incoming messages.
body, err := io.ReadAll(req.Body)
if err != nil {
http.Error(w, "failed to read body", http.StatusBadRequest)
return
return http.StatusBadRequest, "failed to read body"
}
if len(body) == 0 {
http.Error(w, "POST requires a non-empty body", http.StatusBadRequest)
return
return http.StatusBadRequest, "POST requires a non-empty body"
}
incoming, _, err := readBatch(body)
if err != nil {
http.Error(w, fmt.Sprintf("malformed payload: %v", err), http.StatusBadRequest)
return
return http.StatusBadRequest, fmt.Sprintf("malformed payload: %v", err)
}
requests := make(map[jsonrpc.ID]struct{})
for _, msg := range incoming {
Expand All @@ -396,8 +400,8 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R
t.requestStreams[reqID] = stream.id
stream.requests[reqID] = struct{}{}
}
stream.signal = make(chan struct{}, 1)
t.mu.Unlock()
stream.signal.Store(signalChanPtr())

// Publish incoming messages.
for _, msg := range incoming {
Expand All @@ -407,23 +411,12 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R
// TODO(rfindley): consider optimizing for a single incoming request, by
// responding with application/json when there is only a single message in
// the response.
t.streamResponse(stream, w, req, -1)
return t.streamResponse(stream, w, req, -1)
}

// lastIndex is the index of the last seen event if resuming, else -1.
func (t *StreamableServerTransport) streamResponse(stream *stream, w http.ResponseWriter, req *http.Request, lastIndex int) {
defer func() {
t.mu.Lock()
stream.signal = nil
t.mu.Unlock()
}()

t.mu.Lock()
// Although there is a gap in locking between when stream.signal is set and here,
// it cannot change, because it is changed only when non-nil, and it is only
// set to nil in the defer above.
signal := stream.signal
t.mu.Unlock()
func (t *StreamableServerTransport) streamResponse(stream *stream, w http.ResponseWriter, req *http.Request, lastIndex int) (int, string) {
defer stream.signal.Store(nil)

writes := 0

Expand All @@ -437,7 +430,7 @@ func (t *StreamableServerTransport) streamResponse(stream *stream, w http.Respon
}
if _, err := writeEvent(w, e); err != nil {
// Connection closed or broken.
// TODO: log when we add server-side logging.
// TODO(#170): log when we add server-side logging.
return false
}
writes++
Expand All @@ -460,13 +453,12 @@ func (t *StreamableServerTransport) streamResponse(stream *stream, w http.Respon
if errors.Is(err, ErrEventsPurged) {
status = http.StatusInsufficientStorage
}
http.Error(w, err.Error(), status)
return
return status, err.Error()
}
// The iterator yields events beginning just after lastIndex, or it would have
// yielded an error.
if !write(data) {
return
return 0, ""
}
}
}
Expand All @@ -481,11 +473,10 @@ stream:

for _, data := range outgoing {
if err := t.opts.EventStore.Append(req.Context(), t.SessionID(), stream.id, data); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
return http.StatusInternalServerError, err.Error()
}
if !write(data) {
return
return 0, ""
}
}

Expand All @@ -495,22 +486,22 @@ stream:
// If all requests have been handled and replied to, we should terminate this connection.
// "After the JSON-RPC response has been sent, the server SHOULD close the SSE stream."
// §6.4, https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server
// TODO(jba): why not terminate regardless of http method?
// TODO(jba,findleyr): why not terminate regardless of http method?
if req.Method == http.MethodPost && nOutstanding == 0 {
if writes == 0 {
// Spec: If the server accepts the input, the server MUST return HTTP
// status code 202 Accepted with no body.
w.WriteHeader(http.StatusAccepted)
}
return
return 0, ""
}

select {
case <-signal: // there are new outgoing messages
case <-*stream.signal.Load(): // there are new outgoing messages
// return to top of loop
case <-t.done: // session is closed
if writes == 0 {
http.Error(w, "session terminated", http.StatusGone)
return http.StatusGone, "session terminated"
}
break stream
case <-req.Context().Done():
Expand All @@ -520,6 +511,7 @@ stream:
break stream
}
}
return 0, ""
}

// Event IDs: encode both the logical connection ID and the index, as
Expand Down Expand Up @@ -625,10 +617,11 @@ func (t *StreamableServerTransport) Write(ctx context.Context, msg jsonrpc.Messa
delete(stream.requests, replyTo)
}

// Signal work.
if stream.signal != nil {
// Signal streamResponse that new work is available.
signalp := stream.signal.Load()
if signalp != nil {
select {
case stream.signal <- struct{}{}:
case *signalp <- struct{}{}:
default:
}
}
Expand Down
Loading