Merged
Changes from 3 commits
82 changes: 29 additions & 53 deletions mcp/streamable.go
@@ -490,7 +490,7 @@ type stream struct {
// that there are messages available to write into the HTTP response.
// In addition, the presence of a channel guarantees that at most one HTTP response
// can receive messages for a logical stream. After claiming the stream, incoming
// requests should read from outgoing, to ensure that no new messages are missed.
// requests should read from the event store, to ensure that no new messages are missed.
//
// To simplify locking, signal is an atomic. We need an atomic.Pointer, because
// you can't set an atomic.Value to nil.
@@ -502,13 +502,9 @@ type stream struct {
// The following mutable fields are protected by the mutex of the containing
// StreamableServerTransport.

// outgoing is the list of outgoing messages, enqueued by server methods that
// write notifications and responses, and dequeued by streamResponse.
outgoing [][]byte

// streamRequests is the set of unanswered incoming RPCs for the stream.
//
// Requests persist until their response data has been added to outgoing.
// Requests persist until their response data has been added to the event store.
requests map[jsonrpc.ID]struct{}
}
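
For illustration, a minimal sketch of the claiming pattern that the signal comment above describes: an atomic.Pointer[chan struct{}] lets at most one HTTP response claim a logical stream, and can be reset to nil (which atomic.Value cannot). All names here are hypothetical, not the SDK's actual API.

package main

import (
	"fmt"
	"sync/atomic"
)

type streamSketch struct {
	// signal is non-nil while some HTTP response has claimed the stream.
	signal atomic.Pointer[chan struct{}]
}

// claim installs a fresh channel, reporting false if another response
// already holds the stream.
func (s *streamSketch) claim() (<-chan struct{}, bool) {
	ch := make(chan struct{}, 1)
	if !s.signal.CompareAndSwap(nil, &ch) {
		return nil, false
	}
	return ch, true
}

// notify wakes the claiming response, if any, without blocking.
func (s *streamSketch) notify() {
	if ch := s.signal.Load(); ch != nil {
		select {
		case *ch <- struct{}{}:
		default: // already signaled
		}
	}
}

func main() {
	var s streamSketch
	if ch, ok := s.claim(); ok {
		s.notify()
		<-ch
		fmt.Println("stream claimed and signaled")
	}
}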

@@ -706,8 +702,8 @@ func (c *streamableServerConn) respondJSON(stream *stream, w http.ResponseWriter

var msgs []json.RawMessage
ctx := req.Context()
for msg, ok := range c.messages(ctx, stream, false) {
if !ok {
for msg, err := range c.messages(ctx, stream, false, -1) {
if err != nil {
if ctx.Err() != nil {
w.WriteHeader(http.StatusNoContent)
return
Expand Down Expand Up @@ -770,44 +766,20 @@ func (c *streamableServerConn) respondSSE(stream *stream, w http.ResponseWriter,
}
}

if lastIndex >= 0 {
// Resume.
for data, err := range c.eventStore.After(req.Context(), c.SessionID(), stream.id, lastIndex) {
if err != nil {
// TODO: reevaluate these status codes.
// Maybe distinguish between storage errors, which are 500s, and missing
// session or stream ID--can these arise from bad input?
status := http.StatusInternalServerError
if errors.Is(err, ErrEventsPurged) {
status = http.StatusInsufficientStorage
}
errorf(status, "failed to read events: %v", err)
return
}
// The iterator yields events beginning just after lastIndex, or it would have
// yielded an error.
if !write(data) {
return
}
}
}

// Repeatedly collect pending outgoing events and send them.
ctx := req.Context()
for msg, ok := range c.messages(ctx, stream, persistent) {
if !ok {
for msg, err := range c.messages(ctx, stream, persistent, lastIndex) {
if err != nil {
if ctx.Err() != nil && writes == 0 {
// This probably doesn't matter, but respond with NoContent if the client disconnected.
w.WriteHeader(http.StatusNoContent)
} else if errors.Is(err, ErrEventsPurged) {
errorf(http.StatusInsufficientStorage, "failed to read events: %v", err)
} else {
errorf(http.StatusGone, "stream terminated")
}
return
}
if err := c.eventStore.Append(req.Context(), c.SessionID(), stream.id, msg); err != nil {
errorf(http.StatusInternalServerError, "storing event: %v", err.Error())
return
}
if !write(msg) {
return
}
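
A minimal sketch of the SSE framing that the resumption logic above relies on: each event carries its index as the SSE id, so a reconnecting client can report it via Last-Event-ID and the server can replay everything after it from the event store. writeEvent is a hypothetical helper, not the SDK's write function.

package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

// writeEvent frames one SSE event, using the event's index as its id.
func writeEvent(w http.ResponseWriter, index int, data []byte) error {
	if _, err := fmt.Fprintf(w, "id: %d\ndata: %s\n\n", index, data); err != nil {
		return err
	}
	if f, ok := w.(http.Flusher); ok {
		f.Flush() // deliver the event immediately rather than buffering
	}
	return nil
}

func main() {
	rec := httptest.NewRecorder()
	_ = writeEvent(rec, 3, []byte(`{"jsonrpc":"2.0","method":"ping"}`))
	fmt.Print(rec.Body.String())
}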
@@ -816,27 +788,30 @@ func (c *streamableServerConn) messages(ctx context.Context, stream *stream, per

// messages iterates over messages sent to the current stream.
//
// persistent indicates whether this is the main GET listener.
// lastIndex is the index of the last event seen by the client, or -1 if none.
//
// The first iterated value is the received JSON message. The second iterated
// value is an OK value indicating whether the stream terminated normally.
// value is an error, which is non-nil if the stream did not terminate normally.
//
// If the stream did not terminate normally, it is either because ctx was
// cancelled or because the connection was closed: check ctx.Err() to
// differentiate these cases.
func (c *streamableServerConn) messages(ctx context.Context, stream *stream, persistent bool) iter.Seq2[json.RawMessage, bool] {
return func(yield func(json.RawMessage, bool) bool) {
func (c *streamableServerConn) messages(ctx context.Context, stream *stream, persistent bool, lastIndex int) iter.Seq2[json.RawMessage, error] {
return func(yield func(json.RawMessage, error) bool) {
for {
c.mu.Lock()
outgoing := stream.outgoing
stream.outgoing = nil
nOutstanding := len(stream.requests)
c.mu.Unlock()

for _, data := range outgoing {
if !yield(data, true) {
for data, err := range c.eventStore.After(ctx, c.SessionID(), stream.id, lastIndex) {
if err != nil {
break
}
if !yield(data, nil) {
return
}
lastIndex++
}

c.mu.Lock()
nOutstanding := len(stream.requests)
c.mu.Unlock()
// If all requests have been handled and replied to, we should terminate this connection.
// "After the JSON-RPC response has been sent, the server SHOULD close the SSE stream."
// §6.4, https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server
@@ -850,13 +825,14 @@ func (c *streamableServerConn) messages(ctx context.Context, stream *stream, per
case <-*stream.signal.Load(): // there are new outgoing messages
// return to top of loop
case <-c.done: // session is closed
yield(nil, false)
yield(nil, errors.New("session is closed"))
return
case <-ctx.Done():
yield(nil, false)
yield(nil, ctx.Err())
return
}
}

}
}
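
A small sketch of consuming an iterator with the shape messages returns: iterate until the error value is non-nil, then inspect it. fakeMessages is a stand-in producer, not the real streamableServerConn.

package main

import (
	"encoding/json"
	"errors"
	"fmt"
	"iter"
)

func fakeMessages() iter.Seq2[json.RawMessage, error] {
	return func(yield func(json.RawMessage, error) bool) {
		for _, m := range []string{`{"id":1}`, `{"id":2}`} {
			if !yield(json.RawMessage(m), nil) {
				return
			}
		}
		// A non-nil error reports abnormal termination, mirroring the
		// session-closed and context-cancelled cases above.
		yield(nil, errors.New("session is closed"))
	}
}

func main() {
	for msg, err := range fakeMessages() {
		if err != nil {
			fmt.Println("stream ended:", err)
			return
		}
		fmt.Println("got:", string(msg))
	}
}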

@@ -963,9 +939,9 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e
stream = c.streams[""]
}

// TODO: if there is nothing to send these messages to (as would happen, for example, if forConn == ""
// and the client never did a GET), then memory will grow without bound. Consider a mitigation.
stream.outgoing = append(stream.outgoing, data)
if err := c.eventStore.Append(ctx, c.SessionID(), stream.id, data); err != nil {
return fmt.Errorf("storing event: %w", err)
}
if isResponse {
// Once we've put the reply on the queue, it's no longer outstanding.
delete(stream.requests, forRequest)
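
To make the Append/After contract used above concrete, here is a toy in-memory store with the same shape. This is an assumption-laden sketch, not the SDK's EventStore interface or implementation; only the "replay events after a given index" behavior is taken from the diff.

package main

import (
	"context"
	"encoding/json"
	"fmt"
	"iter"
	"sync"
)

type memStore struct {
	mu     sync.Mutex
	events map[string][]json.RawMessage // keyed by session ID + stream ID
}

// Append stores one event at the next index for the given stream.
func (s *memStore) Append(_ context.Context, session, stream string, data json.RawMessage) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.events == nil {
		s.events = make(map[string][]json.RawMessage)
	}
	key := session + "\x00" + stream
	s.events[key] = append(s.events[key], data)
	return nil
}

// After yields stored events with indices strictly greater than lastIndex.
func (s *memStore) After(_ context.Context, session, stream string, lastIndex int) iter.Seq2[json.RawMessage, error] {
	return func(yield func(json.RawMessage, error) bool) {
		s.mu.Lock()
		evs := s.events[session+"\x00"+stream]
		s.mu.Unlock()
		for i := lastIndex + 1; i < len(evs); i++ {
			if !yield(evs[i], nil) {
				return
			}
		}
	}
}

func main() {
	var s memStore
	ctx := context.Background()
	_ = s.Append(ctx, "sess", "", json.RawMessage(`{"id":1}`))
	_ = s.Append(ctx, "sess", "", json.RawMessage(`{"id":2}`))
	for ev, err := range s.After(ctx, "sess", "", 0) {
		if err != nil {
			break
		}
		fmt.Println(string(ev)) // prints only the second event
	}
}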