Skip to content

Commit ddaf35e

Browse files
mcp/streamable: use event store to fix unbounded memory issues (#335)
This CL uses the event store to write outgoing messages and removes the unbounded outgoing data structure. It also adds a new interface method, [EventStore.Open]. For #190.
1 parent f6118aa commit ddaf35e

File tree

2 files changed

+72
-61
lines changed

2 files changed

+72
-61
lines changed

mcp/event.go

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,11 @@ func scanEvents(r io.Reader) iter.Seq2[Event, error] {
153153
//
154154
// All of an EventStore's methods must be safe for use by multiple goroutines.
155155
type EventStore interface {
156+
// Open prepares the event store for a given stream. It ensures that the
157+
// underlying data structure for the stream is initialized, making it
158+
// ready to store event streams.
159+
Open(_ context.Context, sessionID string, streamID StreamID) error
160+
156161
// Append appends data for an outgoing event to the given stream, which is part of the
157162
// given session.
158163
Append(_ context.Context, sessionID string, _ StreamID, data []byte) error
@@ -162,6 +167,7 @@ type EventStore interface {
162167
// Once the iterator yields a non-nil error, it will stop.
163168
// After's iterator must return an error immediately if any data after index was
164169
// dropped; it must not return partial results.
170+
// The stream must have been opened previously (see [EventStore.Open]).
165171
After(_ context.Context, sessionID string, _ StreamID, index int) iter.Seq2[[]byte, error]
166172

167173
// SessionClosed informs the store that the given session is finished, along
@@ -256,11 +262,20 @@ func NewMemoryEventStore(opts *MemoryEventStoreOptions) *MemoryEventStore {
256262
}
257263
}
258264

259-
// Append implements [EventStore.Append] by recording data in memory.
260-
func (s *MemoryEventStore) Append(_ context.Context, sessionID string, streamID StreamID, data []byte) error {
265+
// Open implements [EventStore.Open]. It ensures that the underlying data
266+
// structures for the given session and stream are initialized and ready for use.
267+
func (s *MemoryEventStore) Open(_ context.Context, sessionID string, streamID StreamID) error {
261268
s.mu.Lock()
262269
defer s.mu.Unlock()
270+
s.init(sessionID, streamID)
271+
return nil
272+
}
263273

274+
// init is an internal helper function that ensures the nested map structure for a
275+
// given sessionID and streamID exists, creating it if necessary. It returns the
276+
// dataList associated with the specified IDs.
277+
// Requires s.mu.
278+
func (s *MemoryEventStore) init(sessionID string, streamID StreamID) *dataList {
264279
streamMap, ok := s.store[sessionID]
265280
if !ok {
266281
streamMap = make(map[StreamID]*dataList)
@@ -271,6 +286,14 @@ func (s *MemoryEventStore) Append(_ context.Context, sessionID string, streamID
271286
dl = &dataList{}
272287
streamMap[streamID] = dl
273288
}
289+
return dl
290+
}
291+
292+
// Append implements [EventStore.Append] by recording data in memory.
293+
func (s *MemoryEventStore) Append(_ context.Context, sessionID string, streamID StreamID, data []byte) error {
294+
s.mu.Lock()
295+
defer s.mu.Unlock()
296+
dl := s.init(sessionID, streamID)
274297
// Purge before adding, so at least the current data item will be present.
275298
// (That could result in nBytes > maxBytes, but we'll live with that.)
276299
s.purge()

mcp/streamable.go

Lines changed: 47 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,7 @@ func NewStreamableServerTransport(sessionID string, opts *StreamableServerTransp
401401
}
402402

403403
// Connect implements the [Transport] interface.
404-
func (t *StreamableServerTransport) Connect(context.Context) (Connection, error) {
404+
func (t *StreamableServerTransport) Connect(ctx context.Context) (Connection, error) {
405405
if t.connection != nil {
406406
return nil, fmt.Errorf("transport already connected")
407407
}
@@ -415,13 +415,17 @@ func (t *StreamableServerTransport) Connect(context.Context) (Connection, error)
415415
streams: make(map[StreamID]*stream),
416416
requestStreams: make(map[jsonrpc.ID]StreamID),
417417
}
418+
if t.connection.eventStore == nil {
419+
t.connection.eventStore = NewMemoryEventStore(nil)
420+
}
418421
// Stream 0 corresponds to the hanging 'GET'.
419422
//
420423
// It is always text/event-stream, since it must carry arbitrarily many
421424
// messages.
422-
t.connection.streams[""] = newStream("", false)
423-
if t.connection.eventStore == nil {
424-
t.connection.eventStore = NewMemoryEventStore(nil)
425+
var err error
426+
t.connection.streams[""], err = t.connection.newStream(ctx, "", false)
427+
if err != nil {
428+
return nil, err
425429
}
426430
return t.connection, nil
427431
}
@@ -490,7 +494,7 @@ type stream struct {
490494
// that there are messages available to write into the HTTP response.
491495
// In addition, the presence of a channel guarantees that at most one HTTP response
492496
// can receive messages for a logical stream. After claiming the stream, incoming
493-
// requests should read from outgoing, to ensure that no new messages are missed.
497+
// requests should read from the event store, to ensure that no new messages are missed.
494498
//
495499
// To simplify locking, signal is an atomic. We need an atomic.Pointer, because
496500
// you can't set an atomic.Value to nil.
@@ -502,22 +506,21 @@ type stream struct {
502506
// The following mutable fields are protected by the mutex of the containing
503507
// StreamableServerTransport.
504508

505-
// outgoing is the list of outgoing messages, enqueued by server methods that
506-
// write notifications and responses, and dequeued by streamResponse.
507-
outgoing [][]byte
508-
509509
// requests is the set of unanswered incoming RPCs for the stream.
510510
//
511-
// Requests persist until their response data has been added to outgoing.
511+
// Requests persist until their response data has been added to the event store.
512512
requests map[jsonrpc.ID]struct{}
513513
}
514514

515-
func newStream(id StreamID, jsonResponse bool) *stream {
515+
func (c *streamableServerConn) newStream(ctx context.Context, id StreamID, jsonResponse bool) (*stream, error) {
516+
if err := c.eventStore.Open(ctx, c.sessionID, id); err != nil {
517+
return nil, err
518+
}
516519
return &stream{
517520
id: id,
518521
jsonResponse: jsonResponse,
519522
requests: make(map[jsonrpc.ID]struct{}),
520-
}
523+
}, nil
521524
}
522525

523526
func signalChanPtr() *chan struct{} {
@@ -668,7 +671,11 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
668671
// notifications or server->client requests made in the course of handling.
669672
// Update accounting for this incoming payload.
670673
if len(requests) > 0 {
671-
stream = newStream(StreamID(randText()), c.jsonResponse)
674+
stream, err = c.newStream(req.Context(), StreamID(randText()), c.jsonResponse)
675+
if err != nil {
676+
http.Error(w, fmt.Sprintf("storing stream: %v", err), http.StatusInternalServerError)
677+
return
678+
}
672679
c.mu.Lock()
673680
c.streams[stream.id] = stream
674681
stream.requests = requests
@@ -706,13 +713,13 @@ func (c *streamableServerConn) respondJSON(stream *stream, w http.ResponseWriter
706713

707714
var msgs []json.RawMessage
708715
ctx := req.Context()
709-
for msg, ok := range c.messages(ctx, stream, false) {
710-
if !ok {
716+
for msg, err := range c.messages(ctx, stream, false, -1) {
717+
if err != nil {
711718
if ctx.Err() != nil {
712719
w.WriteHeader(http.StatusNoContent)
713720
return
714721
} else {
715-
http.Error(w, http.StatusText(http.StatusGone), http.StatusGone)
722+
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
716723
return
717724
}
718725
}
@@ -770,44 +777,18 @@ func (c *streamableServerConn) respondSSE(stream *stream, w http.ResponseWriter,
770777
}
771778
}
772779

773-
if lastIndex >= 0 {
774-
// Resume.
775-
for data, err := range c.eventStore.After(req.Context(), c.SessionID(), stream.id, lastIndex) {
776-
if err != nil {
777-
// TODO: reevaluate these status codes.
778-
// Maybe distinguish between storage errors, which are 500s, and missing
779-
// session or stream ID--can these arise from bad input?
780-
status := http.StatusInternalServerError
781-
if errors.Is(err, ErrEventsPurged) {
782-
status = http.StatusInsufficientStorage
783-
}
784-
errorf(status, "failed to read events: %v", err)
785-
return
786-
}
787-
// The iterator yields events beginning just after lastIndex, or it would have
788-
// yielded an error.
789-
if !write(data) {
790-
return
791-
}
792-
}
793-
}
794-
795780
// Repeatedly collect pending outgoing events and send them.
796781
ctx := req.Context()
797-
for msg, ok := range c.messages(ctx, stream, persistent) {
798-
if !ok {
782+
for msg, err := range c.messages(ctx, stream, persistent, lastIndex) {
783+
if err != nil {
799784
if ctx.Err() != nil && writes == 0 {
800785
// This probably doesn't matter, but respond with NoContent if the client disconnected.
801786
w.WriteHeader(http.StatusNoContent)
802787
} else {
803-
errorf(http.StatusGone, "stream terminated")
788+
errorf(http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError))
804789
}
805790
return
806791
}
807-
if err := c.eventStore.Append(req.Context(), c.SessionID(), stream.id, msg); err != nil {
808-
errorf(http.StatusInternalServerError, "storing event: %v", err.Error())
809-
return
810-
}
811792
if !write(msg) {
812793
return
813794
}
@@ -816,27 +797,33 @@ func (c *streamableServerConn) respondSSE(stream *stream, w http.ResponseWriter,
816797

817798
// messages iterates over messages sent to the current stream.
818799
//
800+
// persistent indicates if it is the main GET listener, which should never be
801+
// terminated.
802+
// lastIndex is the index of the last seen event, iteration begins at lastIndex+1.
803+
//
819804
// The first iterated value is the received JSON message. The second iterated
820-
// value is an OK value indicating whether the stream terminated normally.
805+
// value is an error, which is non-nil if the stream did not terminate normally.
806+
// Iteration ends at the first non-nil error.
821807
//
822808
// If the stream did not terminate normally, it is either because ctx was
823809
// cancelled, or the connection is closed: check the ctx.Err() to differentiate
824810
// these cases.
825-
func (c *streamableServerConn) messages(ctx context.Context, stream *stream, persistent bool) iter.Seq2[json.RawMessage, bool] {
826-
return func(yield func(json.RawMessage, bool) bool) {
811+
func (c *streamableServerConn) messages(ctx context.Context, stream *stream, persistent bool, lastIndex int) iter.Seq2[json.RawMessage, error] {
812+
return func(yield func(json.RawMessage, error) bool) {
827813
for {
828814
c.mu.Lock()
829-
outgoing := stream.outgoing
830-
stream.outgoing = nil
831815
nOutstanding := len(stream.requests)
832816
c.mu.Unlock()
833-
834-
for _, data := range outgoing {
835-
if !yield(data, true) {
817+
for data, err := range c.eventStore.After(ctx, c.SessionID(), stream.id, lastIndex) {
818+
if err != nil {
819+
yield(nil, err)
836820
return
837821
}
822+
if !yield(data, nil) {
823+
return
824+
}
825+
lastIndex++
838826
}
839-
840827
// If all requests have been handled and replied to, we should terminate this connection.
841828
// "After the JSON-RPC response has been sent, the server SHOULD close the SSE stream."
842829
// §6.4, https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server
@@ -850,13 +837,14 @@ func (c *streamableServerConn) messages(ctx context.Context, stream *stream, per
850837
case <-*stream.signal.Load(): // there are new outgoing messages
851838
// return to top of loop
852839
case <-c.done: // session is closed
853-
yield(nil, false)
840+
yield(nil, errors.New("session is closed"))
854841
return
855842
case <-ctx.Done():
856-
yield(nil, false)
843+
yield(nil, ctx.Err())
857844
return
858845
}
859846
}
847+
860848
}
861849
}
862850

@@ -963,9 +951,9 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e
963951
stream = c.streams[""]
964952
}
965953

966-
// TODO: if there is nothing to send these messages to (as would happen, for example, if forConn == ""
967-
// and the client never did a GET), then memory will grow without bound. Consider a mitigation.
968-
stream.outgoing = append(stream.outgoing, data)
954+
if err := c.eventStore.Append(ctx, c.SessionID(), stream.id, data); err != nil {
955+
return fmt.Errorf("error storing event: %w", err)
956+
}
969957
if isResponse {
970958
// Once we've put the reply on the queue, it's no longer outstanding.
971959
delete(stream.requests, forRequest)

0 commit comments

Comments
 (0)