Merged
Changes from 1 commit
152 changes: 86 additions & 66 deletions mcp/streamable.go
@@ -173,11 +173,10 @@ func NewStreamableServerTransport(sessionID string, opts *StreamableServerTransp
id: sessionID,
incoming: make(chan jsonrpc.Message, 10),
done: make(chan struct{}),
outgoing: make(map[StreamID][][]byte),
signals: make(map[StreamID]chan struct{}),
streams: make(map[StreamID]*stream),
requestStreams: make(map[jsonrpc.ID]StreamID),
streamRequests: make(map[StreamID]map[jsonrpc.ID]struct{}),
}
t.streams[0] = newStream(0)
if opts != nil {
t.opts = *opts
}
@@ -213,59 +212,66 @@ type StreamableServerTransport struct {
// perform the accounting described below when incoming HTTP requests are
// handled.
//
// The accounting is complicated. It is tempting to merge some of the maps
// below, but they each have different lifecycles, as indicated by Lifecycle:
// comments.
//
// TODO(rfindley): simplify.

// outgoing is the collection of outgoing messages, keyed by the logical
// stream ID where they should be delivered.
// streams holds the logical streams for this session, keyed by their ID.
streams map[StreamID]*stream

// requestStreams maps incoming requests to their logical stream ID.
//
// streamID 0 is used for messages that don't correlate with an incoming
// request.
// Lifecycle: requestStreams persists for the duration of the session.
//
// Lifecycle: persists for the duration of the session.
outgoing map[StreamID][][]byte
// TODO(rfindley): clean up once requests are handled.
requestStreams map[jsonrpc.ID]StreamID
}

// A stream is a single logical stream of SSE events within a server session.
// A stream begins with a client request, or with a client GET that has
// no Last-Event-ID header.
// A stream ends only when its session ends; we cannot determine its end otherwise,
// since a client may send a GET with a Last-Event-ID that references the stream
// at any time.
type stream struct {
// id is the logical ID for the stream, unique within a session.
// ID 0 is used for messages that don't correlate with an incoming request.
id StreamID

// signals maps a logical stream ID to a 1-buffered channel, owned by an
// These mutable fields are protected by the mutex of the corresponding StreamableServerTransport.

// outgoing is the list of outgoing messages, enqueued by server methods that
// write notifications and responses, and dequeued by streamResponse.
outgoing [][]byte

// signal is a 1-buffered channel, owned by an
// incoming HTTP request, that signals that there are messages available to
// write into the HTTP response. Signals guarantees that at most one HTTP
// write into the HTTP response. This guarantees that at most one HTTP
// response can receive messages for a logical stream. After claiming
// the stream, incoming requests should read from outgoing, to ensure
// that no new messages are missed.
//
// Lifecycle: signals persists for the duration of an HTTP POST or GET
// Lifecycle: persists for the duration of an HTTP POST or GET
// request for the given streamID.
signals map[StreamID]chan struct{}
signal chan struct{}

// requestStreams maps incoming requests to their logical stream ID.
//
// Lifecycle: requestStreams persists for the duration of the session.
// streamRequests is the set of unanswered incoming RPCs for the stream.
//
// TODO(rfindley): clean up once requests are handled.
requestStreams map[jsonrpc.ID]StreamID

// streamRequests tracks the set of unanswered incoming RPCs for each logical
// stream.
//
// When the server has responded to each request, the stream should be
// closed.
//
// Lifecycle: streamRequests values persist until the requests have been
// Lifecycle: requests values persist until the requests have been
// replied to by the server. Notably, NOT until they are sent to an HTTP
// response, as delivery is not guaranteed.
streamRequests map[StreamID]map[jsonrpc.ID]struct{}
requests map[jsonrpc.ID]struct{}
}

type StreamID int64

// a streamableMsg is an SSE event with an index into its logical stream.
type streamableMsg struct {
idx int
event Event
func newStream(id StreamID) *stream {
return &stream{
id: id,
requests: make(map[jsonrpc.ID]struct{}),
}
}

// A StreamID identifies a stream of SSE events. It is unique within the stream's
// [ServerSession].
type StreamID int64
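
The signal/outgoing contract documented above is the heart of this refactor: a producer appends to the stream's outgoing slice under the transport mutex and performs a non-blocking send on the 1-buffered signal channel, while the single HTTP response that has claimed the stream drains the whole backlog after each wake-up. A minimal, self-contained sketch of that pattern, using only the standard library (the names queue, Push, and Drain are illustrative, not part of this package):

package main

import (
	"fmt"
	"sync"
)

// queue mirrors the stream's outgoing/signal pair: outgoing is the backlog,
// signal is 1-buffered so at most one wake-up is ever pending.
type queue struct {
	mu       sync.Mutex
	outgoing [][]byte
	signal   chan struct{}
}

// Push enqueues data and wakes the consumer without ever blocking.
func (q *queue) Push(data []byte) {
	q.mu.Lock()
	q.outgoing = append(q.outgoing, data)
	q.mu.Unlock()
	select {
	case q.signal <- struct{}{}:
	default: // a wake-up is already pending; the consumer will still see this item
	}
}

// Drain removes and returns everything queued so far; the consumer calls it
// after every wake-up so no message enqueued before the signal is missed.
func (q *queue) Drain() [][]byte {
	q.mu.Lock()
	defer q.mu.Unlock()
	out := q.outgoing
	q.outgoing = nil
	return out
}

func main() {
	q := &queue{signal: make(chan struct{}, 1)}
	go q.Push([]byte(`{"jsonrpc":"2.0","method":"notifications/progress"}`))
	<-q.signal
	for _, msg := range q.Drain() {
		fmt.Printf("%s\n", msg)
	}
}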

// Connect implements the [Transport] interface.
//
// TODO(rfindley): Connect should return a new object.
@@ -328,16 +334,21 @@ func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Re
}

t.mu.Lock()
if _, ok := t.signals[id]; ok {
stream, ok := t.streams[id]
if !ok {
http.Error(w, "unknown stream", http.StatusBadRequest)
t.mu.Unlock()
return
}
if stream.signal != nil {
http.Error(w, "stream ID conflicts with ongoing stream", http.StatusBadRequest)
t.mu.Unlock()
return
}
signal := make(chan struct{}, 1)
t.signals[id] = signal
stream.signal = make(chan struct{}, 1)
t.mu.Unlock()

t.streamResponse(w, req, id, lastIdx, signal)
t.streamResponse(stream, w, req, lastIdx)
}
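
serveGET resumes delivery from the Last-Event-ID header, so each SSE event ID must identify both the logical stream and the event's index within it. This hunk shows only the call sites of formatEventID, not its encoding; the sketch below assumes a simple "streamID_index" pairing purely for illustration, and formatID/parseID are hypothetical names rather than the package's helpers:

package main

import (
	"fmt"
	"strconv"
	"strings"
)

type StreamID int64

// formatID encodes a stream ID and a per-stream event index into one event ID.
func formatID(sid StreamID, idx int) string {
	return fmt.Sprintf("%d_%d", sid, idx)
}

// parseID recovers the stream ID and index from a Last-Event-ID value.
func parseID(s string) (StreamID, int, error) {
	sidStr, idxStr, ok := strings.Cut(s, "_")
	if !ok {
		return 0, 0, fmt.Errorf("malformed event ID %q", s)
	}
	sid, err := strconv.ParseInt(sidStr, 10, 64)
	if err != nil {
		return 0, 0, err
	}
	idx, err := strconv.Atoi(idxStr)
	if err != nil {
		return 0, 0, err
	}
	return StreamID(sid), idx, nil
}

func main() {
	id := formatID(3, 7)           // the value carried in the SSE "id:" field
	sid, idx, err := parseID(id)   // what a resuming GET would recover
	fmt.Println(id, sid, idx, err) // 3_7 3 7 <nil>
}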

func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.Request) {
@@ -369,17 +380,17 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R
}

// Update accounting for this request.
id := StreamID(t.nextStreamID.Add(1))
signal := make(chan struct{}, 1)
stream := newStream(StreamID(t.nextStreamID.Add(1)))
t.mu.Lock()
t.streams[stream.id] = stream
if len(requests) > 0 {
t.streamRequests[id] = make(map[jsonrpc.ID]struct{})
stream.requests = make(map[jsonrpc.ID]struct{})
}
for reqID := range requests {
t.requestStreams[reqID] = id
t.streamRequests[id][reqID] = struct{}{}
t.requestStreams[reqID] = stream.id
stream.requests[reqID] = struct{}{}
}
t.signals[id] = signal
stream.signal = make(chan struct{}, 1)
t.mu.Unlock()

// Publish incoming messages.
@@ -390,29 +401,37 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R
// TODO(rfindley): consider optimizing for a single incoming request, by
// responding with application/json when there is only a single message in
// the response.
t.streamResponse(w, req, id, -1, signal)
t.streamResponse(stream, w, req, -1)
}
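
servePOST gives every POST its own logical stream, records each incoming request ID in requestStreams, and tracks the unanswered set in stream.requests so that responses can be routed back and the stream closed once the set is empty. A distilled sketch of that accounting, with illustrative names (router, track, reply) and string request IDs standing in for jsonrpc.ID:

package main

import "fmt"

type StreamID int64

// router distills the requestStreams / stream.requests bookkeeping.
type router struct {
	nextStream     StreamID
	requestStreams map[string]StreamID          // request ID -> owning stream
	openRequests   map[StreamID]map[string]bool // stream -> unanswered requests
}

// track opens a new logical stream for one HTTP POST carrying reqIDs.
func (r *router) track(reqIDs []string) StreamID {
	r.nextStream++
	sid := r.nextStream
	r.openRequests[sid] = make(map[string]bool)
	for _, id := range reqIDs {
		r.requestStreams[id] = sid
		r.openRequests[sid][id] = true
	}
	return sid
}

// reply routes a response back to its stream and reports whether every request
// on that stream has now been answered (i.e. the SSE stream may close).
func (r *router) reply(reqID string) (StreamID, bool) {
	sid := r.requestStreams[reqID]
	delete(r.openRequests[sid], reqID)
	return sid, len(r.openRequests[sid]) == 0
}

func main() {
	r := &router{
		requestStreams: make(map[string]StreamID),
		openRequests:   make(map[StreamID]map[string]bool),
	}
	sid := r.track([]string{"1", "2"})
	_, done := r.reply("1")
	fmt.Println(sid, done) // 1 false
	_, done = r.reply("2")
	fmt.Println(sid, done) // 1 true
}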

// lastIndex is the index of the last seen event if resuming, else -1.
func (t *StreamableServerTransport) streamResponse(w http.ResponseWriter, req *http.Request, id StreamID, lastIndex int, signal chan struct{}) {
func (t *StreamableServerTransport) streamResponse(stream *stream, w http.ResponseWriter, req *http.Request, lastIndex int) {
defer func() {
t.mu.Lock()
delete(t.signals, id)
stream.signal = nil
t.mu.Unlock()
}()

t.mu.Lock()
// Although there is a gap in locking between when stream.signal is set and here,
// it cannot change, because it is changed only when non-nil, and it is only
// set to nil in the defer above.
signal := stream.signal
t.mu.Unlock()

writes := 0

// write one event containing data.
write := func(data []byte) bool {
lastIndex++
e := Event{
Name: "message",
ID: formatEventID(id, lastIndex),
ID: formatEventID(stream.id, lastIndex),
Data: data,
}
if _, err := writeEvent(w, e); err != nil {
// Connection closed or broken.
// TODO: log when we add server-side logging.
return false
}
writes++
@@ -426,7 +445,7 @@ func (t *StreamableServerTransport) streamResponse(w http.ResponseWriter, req *h

if lastIndex >= 0 {
// Resume.
for data, err := range t.opts.EventStore.After(req.Context(), t.SessionID(), id, lastIndex) {
for data, err := range t.opts.EventStore.After(req.Context(), t.SessionID(), stream.id, lastIndex) {
if err != nil {
// TODO: reevaluate these status codes.
// Maybe distinguish between storage errors, which are 500s, and missing
@@ -450,12 +469,12 @@ stream:
// Repeatedly collect pending outgoing events and send them.
for {
t.mu.Lock()
outgoing := t.outgoing[id]
t.outgoing[id] = nil
outgoing := stream.outgoing
stream.outgoing = nil
t.mu.Unlock()

for _, data := range outgoing {
if err := t.opts.EventStore.Append(req.Context(), t.id, id, data); err != nil {
if err := t.opts.EventStore.Append(req.Context(), t.SessionID(), stream.id, data); err != nil {
Copilot AI Jul 24, 2025

Inconsistent method call: using t.SessionID() here but t.id was used in the original code. This may cause incorrect session identification in the event store.

Suggested change
if err := t.opts.EventStore.Append(req.Context(), t.SessionID(), stream.id, data); err != nil {
if err := t.opts.EventStore.Append(req.Context(), t.SessionID(), t.SessionID(), data); err != nil {

http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -465,7 +484,7 @@ stream:
}

t.mu.Lock()
nOutstanding := len(t.streamRequests[id])
nOutstanding := len(stream.requests)
t.mu.Unlock()
// If all requests have been handled and replied to, we should terminate this connection.
// "After the JSON-RPC response has been sent, the server SHOULD close the SSE stream."
@@ -579,30 +598,31 @@ func (t *StreamableServerTransport) Write(ctx context.Context, msg jsonrpc.Messa
t.mu.Lock()
defer t.mu.Unlock()
if t.isDone {
return fmt.Errorf("session is closed") // TODO: should this be EOF?
return errors.New("session is closed") // TODO: should this be EOF?
}

if _, ok := t.streamRequests[forConn]; !ok && forConn != 0 {
stream := t.streams[forConn]
if stream == nil {
return fmt.Errorf("no stream with ID %d", forConn)
}
if len(stream.requests) == 0 && forConn != 0 {
// No outstanding requests for this connection, which means it is logically
// done. This is a sequencing violation from the server, so we should report
// a side-channel error here. Put the message on the general queue to avoid
// dropping messages.
forConn = 0
stream = t.streams[0]
}

t.outgoing[forConn] = append(t.outgoing[forConn], data)
stream.outgoing = append(stream.outgoing, data)
if replyTo.IsValid() {
// Once we've put the reply on the queue, it's no longer outstanding.
delete(t.streamRequests[forConn], replyTo)
if len(t.streamRequests[forConn]) == 0 {
delete(t.streamRequests, forConn)
}
delete(stream.requests, replyTo)
}

// Signal work.
if c, ok := t.signals[forConn]; ok {
if stream.signal != nil {
select {
case c <- struct{}{}:
case stream.signal <- struct{}{}:
default:
}
}
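
The messages queued here by Write are eventually flushed to the client by streamResponse through writeEvent, whose body is not part of this diff. The sketch below writes a standard text/event-stream frame (id/event/data lines terminated by a blank line) and flushes, purely to illustrate the wire format; the package's writeEvent may differ, for example in how it splits multi-line data, and writeSSE is an illustrative name:

package main

import (
	"fmt"
	"io"
	"net/http"
	"os"
)

type Event struct {
	Name string
	ID   string
	Data []byte
}

// writeSSE writes one SSE frame and flushes it so the client sees it promptly.
func writeSSE(w io.Writer, e Event) error {
	if _, err := fmt.Fprintf(w, "id: %s\nevent: %s\ndata: %s\n\n", e.ID, e.Name, e.Data); err != nil {
		return err
	}
	if f, ok := w.(http.Flusher); ok {
		f.Flush()
	}
	return nil
}

func main() {
	_ = writeSSE(os.Stdout, Event{
		Name: "message",
		ID:   "0_0", // hypothetical streamID_index encoding, as in the earlier sketch
		Data: []byte(`{"jsonrpc":"2.0","method":"notifications/message"}`),
	})
}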