Merged

Changes from 4 commits
27 changes: 25 additions & 2 deletions mcp/event.go
@@ -153,6 +153,11 @@ func scanEvents(r io.Reader) iter.Seq2[Event, error] {
//
// All of an EventStore's methods must be safe for use by multiple goroutines.
type EventStore interface {
// Open prepares the event store for a given session. It ensures that the
// underlying data structure for the sessionID is initialized, making it
// ready to store event streams.
Open(_ context.Context, sessionID string, streamID StreamID) error
Contributor
Does this need to be called once per session, or once per stream? If the latter, the doc is wrong.

Contributor
Looks like once per stream.

Contributor Author
Done

Contributor
Maybe we should call this OpenStream or StartStream to make that clearer? I'd be fine with either.

Contributor
Everything is XXXStream. Append to stream, get messages from stream. The only session method has "Session" in the name. So "Open" is right.

Contributor
I said that badly. I meant that EventStore.M is about streams unless M explicitly says otherwise.

Contributor
Also, Open takes a stream ID, so of course it is opening a stream.


// Append appends data for an outgoing event to the given stream, which is part of the
// given session.
Append(_ context.Context, sessionID string, _ StreamID, data []byte) error
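
To make the per-stream contract settled in the thread above concrete, here is a minimal usage sketch. It assumes it sits alongside these types in package mcp; the helper name and the session and stream IDs are invented for illustration.

// exampleEventStoreUse is a hypothetical helper, not part of the change:
// Open is called once per stream, before the first Append for that stream.
func exampleEventStoreUse(ctx context.Context, store EventStore) error {
	const sessionID = "session-1"    // illustrative ID
	streamID := StreamID("stream-1") // illustrative ID
	if err := store.Open(ctx, sessionID, streamID); err != nil {
		return err
	}
	// After Open, Append may be called any number of times for that stream.
	return store.Append(ctx, sessionID, streamID, []byte(`{"jsonrpc":"2.0","method":"notifications/initialized"}`))
}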
@@ -256,11 +261,21 @@ func NewMemoryEventStore(opts *MemoryEventStoreOptions) *MemoryEventStore {
}
}

// Append implements [EventStore.Append] by recording data in memory.
func (s *MemoryEventStore) Append(_ context.Context, sessionID string, streamID StreamID, data []byte) error {
// Open implements [EventStore.Open]. It ensures that the underlying data
// structures for the given session and stream IDs are initialized and ready
// for use.
func (s *MemoryEventStore) Open(_ context.Context, sessionID string, streamID StreamID) error {
s.mu.Lock()
defer s.mu.Unlock()
s.init(sessionID, streamID)
return nil
}

// init is an internal helper function that ensures the nested map structure for a
// given sessionID and streamID exists, creating it if necessary. It returns the
// dataList associated with the specified IDs.
// This function must be called within a locked context.
func (s *MemoryEventStore) init(sessionID string, streamID StreamID) *dataList {
streamMap, ok := s.store[sessionID]
if !ok {
streamMap = make(map[StreamID]*dataList)
@@ -271,6 +286,14 @@ func (s *MemoryEventStore) Append(_ context.Context, sessionID string, streamID
dl = &dataList{}
streamMap[streamID] = dl
}
return dl
}

// Append implements [EventStore.Append] by recording data in memory.
func (s *MemoryEventStore) Append(_ context.Context, sessionID string, streamID StreamID, data []byte) error {
s.mu.Lock()
defer s.mu.Unlock()
dl := s.init(sessionID, streamID)
// Purge before adding, so at least the current data item will be present.
// (That could result in nBytes > maxBytes, but we'll live with that.)
s.purge()
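Because EventStore is an interface, a custom store can also layer behavior on top of an existing implementation rather than reimplementing it. A hypothetical decorator sketch (loggingEventStore is invented, not part of the SDK; it assumes package mcp and the standard log package):

// loggingEventStore wraps another EventStore and logs each Open call,
// delegating every other method through the embedded interface.
type loggingEventStore struct {
	EventStore
}

func (l loggingEventStore) Open(ctx context.Context, sessionID string, streamID StreamID) error {
	log.Printf("opening stream %q of session %q", streamID, sessionID)
	return l.EventStore.Open(ctx, sessionID, streamID)
}

Since the wrapper adds no state of its own, it inherits the interface's "safe for multiple goroutines" requirement from the wrapped store.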
102 changes: 45 additions & 57 deletions mcp/streamable.go
@@ -401,7 +401,7 @@ func NewStreamableServerTransport(sessionID string, opts *StreamableServerTransp
}

// Connect implements the [Transport] interface.
func (t *StreamableServerTransport) Connect(context.Context) (Connection, error) {
func (t *StreamableServerTransport) Connect(ctx context.Context) (Connection, error) {
if t.connection != nil {
return nil, fmt.Errorf("transport already connected")
}
@@ -415,13 +415,17 @@ func (t *StreamableServerTransport) Connect(context.Context) (Connection, error)
streams: make(map[StreamID]*stream),
requestStreams: make(map[jsonrpc.ID]StreamID),
}
if t.connection.eventStore == nil {
t.connection.eventStore = NewMemoryEventStore(nil)
}
// Stream 0 corresponds to the hanging 'GET'.
//
// It is always text/event-stream, since it must carry arbitrarily many
// messages.
t.connection.streams[""] = newStream("", false)
if t.connection.eventStore == nil {
t.connection.eventStore = NewMemoryEventStore(nil)
var err error
t.connection.streams[""], err = t.connection.newStream(ctx, "", false)
if err != nil {
return nil, err
}
return t.connection, nil
}
@@ -490,7 +494,7 @@ type stream struct {
// that there are messages available to write into the HTTP response.
// In addition, the presence of a channel guarantees that at most one HTTP response
// can receive messages for a logical stream. After claiming the stream, incoming
// requests should read from outgoing, to ensure that no new messages are missed.
// requests should read from the event store, to ensure that no new messages are missed.
//
// To simplify locking, signal is an atomic. We need an atomic.Pointer, because
// you can't set an atomic.Value to nil.
@@ -502,22 +506,21 @@ type stream struct {
// The following mutable fields are protected by the mutex of the containing
// StreamableServerTransport.

// outgoing is the list of outgoing messages, enqueued by server methods that
// write notifications and responses, and dequeued by streamResponse.
outgoing [][]byte

// streamRequests is the set of unanswered incoming RPCs for the stream.
//
// Requests persist until their response data has been added to outgoing.
// Requests persist until their response data has been added to the event store.
requests map[jsonrpc.ID]struct{}
}

func newStream(id StreamID, jsonResponse bool) *stream {
func (c *streamableServerConn) newStream(ctx context.Context, id StreamID, jsonResponse bool) (*stream, error) {
if err := c.eventStore.Open(ctx, c.sessionID, id); err != nil {
return nil, err
}
return &stream{
id: id,
jsonResponse: jsonResponse,
requests: make(map[jsonrpc.ID]struct{}),
}
}, nil
}

func signalChanPtr() *chan struct{} {
@@ -668,7 +671,11 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
// notifications or server->client requests made in the course of handling.
// Update accounting for this incoming payload.
if len(requests) > 0 {
stream = newStream(StreamID(randText()), c.jsonResponse)
stream, err = c.newStream(req.Context(), StreamID(randText()), c.jsonResponse)
if err != nil {
http.Error(w, fmt.Sprintf("storing stream: %v", err), http.StatusInternalServerError)
return
}
c.mu.Lock()
c.streams[stream.id] = stream
stream.requests = requests
@@ -706,8 +713,8 @@ func (c *streamableServerConn) respondJSON(stream *stream, w http.ResponseWriter

var msgs []json.RawMessage
ctx := req.Context()
for msg, ok := range c.messages(ctx, stream, false) {
if !ok {
for msg, err := range c.messages(ctx, stream, false, -1) {
if err != nil {
if ctx.Err() != nil {
w.WriteHeader(http.StatusNoContent)
return
@@ -770,44 +777,20 @@ func (c *streamableServerConn) respondSSE(stream *stream, w http.ResponseWriter,
}
}

if lastIndex >= 0 {
// Resume.
for data, err := range c.eventStore.After(req.Context(), c.SessionID(), stream.id, lastIndex) {
if err != nil {
// TODO: reevaluate these status codes.
// Maybe distinguish between storage errors, which are 500s, and missing
// session or stream ID--can these arise from bad input?
status := http.StatusInternalServerError
if errors.Is(err, ErrEventsPurged) {
status = http.StatusInsufficientStorage
}
errorf(status, "failed to read events: %v", err)
return
}
// The iterator yields events beginning just after lastIndex, or it would have
// yielded an error.
if !write(data) {
return
}
}
}

// Repeatedly collect pending outgoing events and send them.
ctx := req.Context()
for msg, ok := range c.messages(ctx, stream, persistent) {
if !ok {
for msg, err := range c.messages(ctx, stream, persistent, lastIndex) {
if err != nil {
if ctx.Err() != nil && writes == 0 {
// This probably doesn't matter, but respond with NoContent if the client disconnected.
w.WriteHeader(http.StatusNoContent)
} else if errors.Is(err, ErrEventsPurged) {
errorf(http.StatusInsufficientStorage, "failed to read events: %v", err)
} else {
errorf(http.StatusGone, "stream terminated")
}
return
}
if err := c.eventStore.Append(req.Context(), c.SessionID(), stream.id, msg); err != nil {
errorf(http.StatusInternalServerError, "storing event: %v", err.Error())
return
}
if !write(msg) {
return
}
@@ -816,27 +799,31 @@ func (c *streamableServerConn) respondSSE(stream *stream, w http.ResponseWriter,

// messages iterates over messages sent to the current stream.
//
// persistent indicates if it is the main GET listener.
// lastIndex is the index of the last seen event.
//
// The first iterated value is the received JSON message. The second iterated
// value is an OK value indicating whether the stream terminated normally.
// value is an error value indicating whether the stream terminated normally.
//
// If the stream did not terminate normally, it is either because ctx was
// cancelled, or the connection is closed: check the ctx.Err() to differentiate
// these cases.
func (c *streamableServerConn) messages(ctx context.Context, stream *stream, persistent bool) iter.Seq2[json.RawMessage, bool] {
return func(yield func(json.RawMessage, bool) bool) {
func (c *streamableServerConn) messages(ctx context.Context, stream *stream, persistent bool, lastIndex int) iter.Seq2[json.RawMessage, error] {
return func(yield func(json.RawMessage, error) bool) {
for {
c.mu.Lock()
outgoing := stream.outgoing
stream.outgoing = nil
nOutstanding := len(stream.requests)
c.mu.Unlock()

for _, data := range outgoing {
if !yield(data, true) {
for data, err := range c.eventStore.After(ctx, c.SessionID(), stream.id, lastIndex) {
if err != nil {
yield(nil, err)
return
}
if !yield(data, nil) {
return
}
lastIndex++
}

// If all requests have been handled and replied to, we should terminate this connection.
// "After the JSON-RPC response has been sent, the server SHOULD close the SSE stream."
// §6.4, https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server
Expand All @@ -850,13 +837,14 @@ func (c *streamableServerConn) messages(ctx context.Context, stream *stream, per
case <-*stream.signal.Load(): // there are new outgoing messages
// return to top of loop
case <-c.done: // session is closed
yield(nil, false)
yield(nil, errors.New("session is closed"))
return
case <-ctx.Done():
yield(nil, false)
yield(nil, ctx.Err())
return
}
}

}
}
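
The loop above folds resumption and live tailing into one path: each pass replays events from the store starting just after lastIndex, then waits on the stream's signal. A rough sketch of that primitive, assuming package mcp (replayFrom is an invented helper; After's signature is taken from its use in this diff):

// replayFrom collects a stream's events starting just after lastIndex, the
// way messages does on each pass before waiting for new events.
func replayFrom(ctx context.Context, store EventStore, sessionID string, streamID StreamID, lastIndex int) ([]json.RawMessage, error) {
	var msgs []json.RawMessage
	for data, err := range store.After(ctx, sessionID, streamID, lastIndex) {
		if err != nil {
			// ErrEventsPurged reports that the requested history was already purged.
			return nil, err
		}
		msgs = append(msgs, data)
	}
	return msgs, nil
}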

@@ -963,9 +951,9 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e
stream = c.streams[""]
}

// TODO: if there is nothing to send these messages to (as would happen, for example, if forConn == ""
// and the client never did a GET), then memory will grow without bound. Consider a mitigation.
stream.outgoing = append(stream.outgoing, data)
if err := c.eventStore.Append(ctx, c.SessionID(), stream.id, data); err != nil {
return fmt.Errorf("error storing event: %w", err)
}
if isResponse {
// Once we've put the reply on the queue, it's no longer outstanding.
delete(stream.requests, forRequest)