Skip to content

Commit 0f55f3e

Browse files
committed
mcp/streamable: use event store to fix unbounded memory issues
This CL utilizes the event store to write outgoing messages and removes the unbounded outgoing data structure. For #190
1 parent 3db848a commit 0f55f3e

File tree

2 files changed

+46
-55
lines changed

2 files changed

+46
-55
lines changed

mcp/event.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,12 @@ func (s *MemoryEventStore) Append(_ context.Context, sessionID string, streamID
283283
// index is no longer available.
284284
var ErrEventsPurged = errors.New("data purged")
285285

286+
// ErrUnknownSession is the error that [EventStore.After] should return if the session ID is unknown.
287+
var ErrUnknownSession = errors.New("unknown session ID")
288+
289+
// ErrUnknownSession is the error that [EventStore.After] should return if the stream ID is unknown.
290+
var ErrUnknownStream = errors.New("unknown stream ID")
291+
286292
// After implements [EventStore.After].
287293
func (s *MemoryEventStore) After(_ context.Context, sessionID string, streamID StreamID, index int) iter.Seq2[[]byte, error] {
288294
// Return the data items to yield.
@@ -292,11 +298,11 @@ func (s *MemoryEventStore) After(_ context.Context, sessionID string, streamID S
292298
defer s.mu.Unlock()
293299
streamMap, ok := s.store[sessionID]
294300
if !ok {
295-
return nil, fmt.Errorf("MemoryEventStore.After: unknown session ID %q", sessionID)
301+
return nil, fmt.Errorf("MemoryEventStore.After: session ID %v: %w", sessionID, ErrUnknownSession)
296302
}
297303
dl, ok := streamMap[streamID]
298304
if !ok {
299-
return nil, fmt.Errorf("MemoryEventStore.After: unknown stream ID %v in session %q", streamID, sessionID)
305+
return nil, fmt.Errorf("MemoryEventStore.After: stream ID %v in session %q: %w", streamID, sessionID, ErrUnknownStream)
300306
}
301307
start := index + 1
302308
if dl.first > start {

mcp/streamable.go

Lines changed: 38 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ type stream struct {
348348
// that there are messages available to write into the HTTP response.
349349
// In addition, the presence of a channel guarantees that at most one HTTP response
350350
// can receive messages for a logical stream. After claiming the stream, incoming
351-
// requests should read from outgoing, to ensure that no new messages are missed.
351+
// requests should read from the event store, to ensure that no new messages are missed.
352352
//
353353
// To simplify locking, signal is an atomic. We need an atomic.Pointer, because
354354
// you can't set an atomic.Value to nil.
@@ -360,22 +360,23 @@ type stream struct {
360360
// The following mutable fields are protected by the mutex of the containing
361361
// StreamableServerTransport.
362362

363-
// outgoing is the list of outgoing messages, enqueued by server methods that
364-
// write notifications and responses, and dequeued by streamResponse.
365-
outgoing [][]byte
366-
367363
// streamRequests is the set of unanswered incoming RPCs for the stream.
368364
//
369-
// Requests persist until their response data has been added to outgoing.
365+
// Requests persist until their response data has been added to the event store.
370366
requests map[jsonrpc.ID]struct{}
367+
368+
// lastWriteIndex tracks the index of the last message written to the event store for this stream.
369+
lastWriteIndex atomic.Int64
371370
}
372371

373372
func newStream(id StreamID, jsonResponse bool) *stream {
374-
return &stream{
373+
s := &stream{
375374
id: id,
376375
jsonResponse: jsonResponse,
377376
requests: make(map[jsonrpc.ID]struct{}),
378377
}
378+
s.lastWriteIndex.Store(-1)
379+
return s
379380
}
380381

381382
func signalChanPtr() *chan struct{} {
@@ -559,8 +560,8 @@ func (c *streamableServerConn) respondJSON(stream *stream, w http.ResponseWriter
559560

560561
var msgs []json.RawMessage
561562
ctx := req.Context()
562-
for msg, ok := range c.messages(ctx, stream, false) {
563-
if !ok {
563+
for msg, err := range c.messages(ctx, stream, false, -1) {
564+
if err != nil {
564565
if ctx.Err() != nil {
565566
w.WriteHeader(http.StatusNoContent)
566567
return
@@ -623,44 +624,20 @@ func (c *streamableServerConn) respondSSE(stream *stream, w http.ResponseWriter,
623624
}
624625
}
625626

626-
if lastIndex >= 0 {
627-
// Resume.
628-
for data, err := range c.eventStore.After(req.Context(), c.SessionID(), stream.id, lastIndex) {
629-
if err != nil {
630-
// TODO: reevaluate these status codes.
631-
// Maybe distinguish between storage errors, which are 500s, and missing
632-
// session or stream ID--can these arise from bad input?
633-
status := http.StatusInternalServerError
634-
if errors.Is(err, ErrEventsPurged) {
635-
status = http.StatusInsufficientStorage
636-
}
637-
errorf(status, "failed to read events: %v", err)
638-
return
639-
}
640-
// The iterator yields events beginning just after lastIndex, or it would have
641-
// yielded an error.
642-
if !write(data) {
643-
return
644-
}
645-
}
646-
}
647-
648627
// Repeatedly collect pending outgoing events and send them.
649628
ctx := req.Context()
650-
for msg, ok := range c.messages(ctx, stream, persistent) {
651-
if !ok {
629+
for msg, err := range c.messages(ctx, stream, persistent, lastIndex) {
630+
if err != nil {
652631
if ctx.Err() != nil && writes == 0 {
653632
// This probably doesn't matter, but respond with NoContent if the client disconnected.
654633
w.WriteHeader(http.StatusNoContent)
634+
} else if errors.Is(err, ErrEventsPurged) {
635+
errorf(http.StatusInsufficientStorage, "failed to read events: %v", err)
655636
} else {
656637
errorf(http.StatusGone, "stream terminated")
657638
}
658639
return
659640
}
660-
if err := c.eventStore.Append(req.Context(), c.SessionID(), stream.id, msg); err != nil {
661-
errorf(http.StatusInternalServerError, "storing event: %v", err.Error())
662-
return
663-
}
664641
if !write(msg) {
665642
return
666643
}
@@ -675,41 +652,48 @@ func (c *streamableServerConn) respondSSE(stream *stream, w http.ResponseWriter,
675652
// If the stream did not terminate normally, it is either because ctx was
676653
// cancelled, or the connection is closed: check the ctx.Err() to differentiate
677654
// these cases.
678-
func (c *streamableServerConn) messages(ctx context.Context, stream *stream, persistent bool) iter.Seq2[json.RawMessage, bool] {
679-
return func(yield func(json.RawMessage, bool) bool) {
655+
func (c *streamableServerConn) messages(ctx context.Context, stream *stream, persistent bool, lastIndex int) iter.Seq2[json.RawMessage, error] {
656+
return func(yield func(json.RawMessage, error) bool) {
680657
for {
681-
c.mu.Lock()
682-
outgoing := stream.outgoing
683-
stream.outgoing = nil
684-
nOutstanding := len(stream.requests)
685-
c.mu.Unlock()
686-
687-
for _, data := range outgoing {
688-
if !yield(data, true) {
658+
for data, err := range c.eventStore.After(ctx, c.SessionID(), stream.id, lastIndex) {
659+
if err != nil {
660+
// Wait for session initialization before yielding.
661+
if errors.Is(err, ErrUnknownSession) || errors.Is(err, ErrUnknownStream) {
662+
break
663+
}
664+
yield(nil, err)
689665
return
690666
}
667+
if !yield(data, nil) {
668+
return
669+
}
670+
lastIndex++
691671
}
672+
c.mu.Lock()
673+
nOutstanding := len(stream.requests)
674+
c.mu.Unlock()
692675

693676
// If all requests have been handled and replied to, we should terminate this connection.
694677
// "After the JSON-RPC response has been sent, the server SHOULD close the SSE stream."
695678
// §6.4, https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server
696679
// We only want to terminate POSTs, and GETs that are replaying. The general-purpose GET
697680
// (stream ID 0) will never have requests, and should remain open indefinitely.
698-
if nOutstanding == 0 && !persistent {
681+
if nOutstanding == 0 && !persistent && lastIndex >= int(stream.lastWriteIndex.Load()) {
699682
return
700683
}
701684

702685
select {
703686
case <-*stream.signal.Load(): // there are new outgoing messages
704687
// return to top of loop
705688
case <-c.done: // session is closed
706-
yield(nil, false)
689+
yield(nil, errors.New("session is closed"))
707690
return
708691
case <-ctx.Done():
709-
yield(nil, false)
692+
yield(nil, ctx.Err())
710693
return
711694
}
712695
}
696+
713697
}
714698
}
715699

@@ -812,9 +796,10 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e
812796
stream = c.streams[""]
813797
}
814798

815-
// TODO: if there is nothing to send these messages to (as would happen, for example, if forConn == ""
816-
// and the client never did a GET), then memory will grow without bound. Consider a mitigation.
817-
stream.outgoing = append(stream.outgoing, data)
799+
if err := c.eventStore.Append(ctx, c.SessionID(), stream.id, data); err != nil {
800+
return fmt.Errorf("error storing event: %w", err)
801+
}
802+
stream.lastWriteIndex.Add(1)
818803
if isResponse {
819804
// Once we've put the reply on the queue, it's no longer outstanding.
820805
delete(stream.requests, forRequest)

0 commit comments

Comments
 (0)