From c21ea6595bf42345a2274e75d1854b9514cba73d Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Mon, 21 Jul 2025 16:16:26 -0400 Subject: [PATCH 1/2] mcp: incorporate EventStore into StreamableServerTransport Use the EventStore interface to implement resumption on the server side. --- mcp/event.go | 64 ++++++------------ mcp/event_test.go | 30 +++------ mcp/server.go | 2 + mcp/streamable.go | 161 +++++++++++++++++++++++++++------------------- 4 files changed, 124 insertions(+), 133 deletions(-) diff --git a/mcp/event.go b/mcp/event.go index fbbe1941..9092da76 100644 --- a/mcp/event.go +++ b/mcp/event.go @@ -153,10 +153,9 @@ func scanEvents(r io.Reader) iter.Seq2[Event, error] { // // All of an EventStore's methods must be safe for use by multiple goroutines. type EventStore interface { - // AppendEvent appends data for an outgoing event to given stream, which is part of the - // given session. It returns the index of the event in the stream, suitable for constructing - // an event ID to send to the client. - AppendEvent(_ context.Context, sessionID string, _ StreamID, data []byte) (int, error) + // Append appends data for an outgoing event to given stream, which is part of the + // given session. + Append(_ context.Context, sessionID string, _ StreamID, data []byte) error // After returns an iterator over the data for the given session and stream, beginning // just after the given index. @@ -165,16 +164,15 @@ type EventStore interface { // dropped; it must not return partial results. After(_ context.Context, sessionID string, _ StreamID, index int) iter.Seq2[[]byte, error] - // StreamClosed informs the store that the given stream is finished. - // A store cannot rely on this method being called for cleanup. It should institute - // additional mechanisms, such as timeouts, to reclaim storage. - StreamClosed(_ context.Context, sessionID string, streamID StreamID) error - // SessionClosed informs the store that the given session is finished, along // with all of its streams. // A store cannot rely on this method being called for cleanup. It should institute // additional mechanisms, such as timeouts, to reclaim storage. + // SessionClosed(_ context.Context, sessionID string) error + + // There is no StreamClosed method. A server doesn't know when a stream is finished, because + // the client can always send a GET with a Last-Event-ID referring to the stream. } // A dataList is a list of []byte. @@ -210,15 +208,6 @@ func (dl *dataList) removeFirst() int { return r } -// lastIndex returns the index of the last data item in dl. -// It panics if there are none. -func (dl *dataList) lastIndex() int { - if len(dl.data) == 0 { - panic("empty dataList") - } - return dl.first + len(dl.data) - 1 -} - // A MemoryEventStore is an [EventStore] backed by memory. type MemoryEventStore struct { mu sync.Mutex @@ -267,9 +256,8 @@ func NewMemoryEventStore(opts *MemoryEventStoreOptions) *MemoryEventStore { } } -// AppendEvent implements [EventStore.AppendEvent] by recording data -// in memory. -func (s *MemoryEventStore) AppendEvent(_ context.Context, sessionID string, streamID StreamID, data []byte) (int, error) { +// Append implements [EventStore.Append] by recording data in memory. +func (s *MemoryEventStore) Append(_ context.Context, sessionID string, streamID StreamID, data []byte) error { s.mu.Lock() defer s.mu.Unlock() @@ -288,9 +276,13 @@ func (s *MemoryEventStore) AppendEvent(_ context.Context, sessionID string, stre s.purge() dl.appendData(data) s.nBytes += len(data) - return dl.lastIndex(), nil + return nil } +// ErrEventsPurged is the error that [EventStore.After] should return if the event just after the +// index is no longer available. +var ErrEventsPurged = errors.New("data purged") + // After implements [EventStore.After]. func (s *MemoryEventStore) After(_ context.Context, sessionID string, streamID StreamID, index int) iter.Seq2[[]byte, error] { // Return the data items to yield. @@ -306,10 +298,12 @@ func (s *MemoryEventStore) After(_ context.Context, sessionID string, streamID S if !ok { return nil, fmt.Errorf("MemoryEventStore.After: unknown stream ID %v in session %q", streamID, sessionID) } - if dl.first > index { - return nil, fmt.Errorf("MemoryEventStore.After: data purged at index %d, stream ID %v, session %q", index, streamID, sessionID) + start := index + 1 + if dl.first > start { + return nil, fmt.Errorf("MemoryEventStore.After: index %d, stream ID %v, session %q: %w", + index, streamID, sessionID, ErrEventsPurged) } - return slices.Clone(dl.data[index-dl.first:]), nil + return slices.Clone(dl.data[start-dl.first:]), nil } return func(yield func([]byte, error) bool) { @@ -326,26 +320,6 @@ func (s *MemoryEventStore) After(_ context.Context, sessionID string, streamID S } } -// StreamClosed implements [EventStore.StreamClosed]. -func (s *MemoryEventStore) StreamClosed(_ context.Context, sessionID string, streamID StreamID) error { - if sessionID == "" { - panic("empty sessionID") - } - - s.mu.Lock() - defer s.mu.Unlock() - - sm := s.store[sessionID] - dl := sm[streamID] - s.nBytes -= dl.size - delete(sm, streamID) - if len(sm) == 0 { - delete(s.store, sessionID) - } - s.validate() - return nil -} - // SessionClosed implements [EventStore.SessionClosed]. func (s *MemoryEventStore) SessionClosed(_ context.Context, sessionID string) error { s.mu.Lock() diff --git a/mcp/event_test.go b/mcp/event_test.go index 9b01555f..147a947a 100644 --- a/mcp/event_test.go +++ b/mcp/event_test.go @@ -105,7 +105,7 @@ func TestMemoryEventStoreState(t *testing.T) { ctx := context.Background() appendEvent := func(s *MemoryEventStore, sess string, str StreamID, data string) { - if _, err := s.AppendEvent(ctx, sess, str, []byte(data)); err != nil { + if err := s.Append(ctx, sess, str, []byte(data)); err != nil { t.Fatal(err) } } @@ -127,18 +127,6 @@ func TestMemoryEventStoreState(t *testing.T) { "S1 1 first=0 d1 d3; S1 2 first=0 d2; S2 8 first=0 d4", 8, }, - { - "stream close", - func(s *MemoryEventStore) { - appendEvent(s, "S1", 1, "d1") - appendEvent(s, "S1", 2, "d2") - appendEvent(s, "S1", 1, "d3") - appendEvent(s, "S2", 8, "d4") - s.StreamClosed(ctx, "S1", 1) - }, - "S1 2 first=0 d2; S2 8 first=0 d4", - 4, - }, { "session close", func(s *MemoryEventStore) { @@ -218,10 +206,10 @@ func TestMemoryEventStoreAfter(t *testing.T) { ctx := context.Background() s := NewMemoryEventStore(nil) s.SetMaxBytes(4) - s.AppendEvent(ctx, "S1", 1, []byte("d1")) - s.AppendEvent(ctx, "S1", 1, []byte("d2")) - s.AppendEvent(ctx, "S1", 1, []byte("d3")) - s.AppendEvent(ctx, "S1", 2, []byte("d4")) // will purge d1 + s.Append(ctx, "S1", 1, []byte("d1")) + s.Append(ctx, "S1", 1, []byte("d2")) + s.Append(ctx, "S1", 1, []byte("d3")) + s.Append(ctx, "S1", 2, []byte("d4")) // will purge d1 want := "S1 1 first=1 d2 d3; S1 2 first=0 d4" if got := s.debugString(); got != want { t.Fatalf("got state %q, want %q", got, want) @@ -234,10 +222,10 @@ func TestMemoryEventStoreAfter(t *testing.T) { want []string wantErr string // if non-empty, error should contain this string }{ - {"S1", 1, 0, nil, "purge"}, - {"S1", 1, 1, []string{"d2", "d3"}, ""}, - {"S1", 1, 2, []string{"d3"}, ""}, - {"S1", 2, 0, []string{"d4"}, ""}, + {"S1", 1, 0, []string{"d2", "d3"}, ""}, + {"S1", 1, 1, []string{"d3"}, ""}, + {"S1", 1, 2, nil, ""}, + {"S1", 2, 0, nil, ""}, {"S1", 3, 0, nil, "unknown stream ID"}, {"S2", 0, 0, nil, "unknown session ID"}, } { diff --git a/mcp/server.go b/mcp/server.go index 6b287ad7..e0f691dc 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -85,9 +85,11 @@ func NewServer(impl *Implementation, opts *ServerOptions) *Server { if opts.PageSize < 0 { panic(fmt.Errorf("invalid page size %d", opts.PageSize)) } + // TODO(jba): don't modify opts, modify Server.opts. if opts.PageSize == 0 { opts.PageSize = DefaultPageSize } + return &Server{ impl: impl, opts: *opts, diff --git a/mcp/streamable.go b/mcp/streamable.go index f7c6ed63..2d3606e8 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -7,6 +7,7 @@ package mcp import ( "bytes" "context" + "errors" "fmt" "io" "math" @@ -132,7 +133,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque } if session == nil { - s := NewStreamableServerTransport(randText()) + s := NewStreamableServerTransport(randText(), nil) server := h.getServer(req) // Pass req.Context() here, to allow middleware to add context values. // The context is detached in the jsonrpc2 library when handling the @@ -150,27 +151,40 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque session.ServeHTTP(w, req) } +type StreamableServerTransportOptions struct { + // Storage for events, to enable stream resumption. + // If nil, a [MemoryEventStore] with the default maximum size will be used. + EventStore EventStore +} + // NewStreamableServerTransport returns a new [StreamableServerTransport] with -// the given session ID. +// the given session ID and options. // The session ID must be globally unique, that is, different from any other // session ID anywhere, past and future. (We recommend using a crypto random number -// generator to produce one, as in [crypto/rand.Text].) +// generator to produce one, as with [crypto/rand.Text].) // // A StreamableServerTransport implements the server-side of the streamable // transport. -// -// TODO(rfindley): consider adding options here, to configure event storage -// policy. -func NewStreamableServerTransport(sessionID string) *StreamableServerTransport { - return &StreamableServerTransport{ - id: sessionID, - incoming: make(chan jsonrpc.Message, 10), - done: make(chan struct{}), - outgoingMessages: make(map[StreamID][]*streamableMsg), - signals: make(map[StreamID]chan struct{}), - requestStreams: make(map[jsonrpc.ID]StreamID), - streamRequests: make(map[StreamID]map[jsonrpc.ID]struct{}), +func NewStreamableServerTransport(sessionID string, opts *StreamableServerTransportOptions) *StreamableServerTransport { + if opts == nil { + opts = &StreamableServerTransportOptions{} + } + t := &StreamableServerTransport{ + id: sessionID, + incoming: make(chan jsonrpc.Message, 10), + done: make(chan struct{}), + outgoing: make(map[StreamID][][]byte), + signals: make(map[StreamID]chan struct{}), + requestStreams: make(map[jsonrpc.ID]StreamID), + streamRequests: make(map[StreamID]map[jsonrpc.ID]struct{}), } + if opts != nil { + t.opts = *opts + } + if t.opts.EventStore == nil { + t.opts.EventStore = NewMemoryEventStore(nil) + } + return t } func (t *StreamableServerTransport) SessionID() string { @@ -183,10 +197,10 @@ type StreamableServerTransport struct { nextStreamID atomic.Int64 // incrementing next stream ID id string + opts StreamableServerTransportOptions incoming chan jsonrpc.Message // messages from the client to the server mu sync.Mutex - // Sessions are closed exactly once. isDone bool done chan struct{} @@ -205,17 +219,14 @@ type StreamableServerTransport struct { // // TODO(rfindley): simplify. - // outgoingMessages is the collection of outgoingMessages messages, keyed by the logical + // outgoing is the collection of outgoing messages, keyed by the logical // stream ID where they should be delivered. // // streamID 0 is used for messages that don't correlate with an incoming // request. // - // Lifecycle: outgoingMessages persists for the duration of the session. - // - // TODO(rfindley): garbage collect this data. For now, we save all outgoingMessages - // messages for the lifespan of the transport. - outgoingMessages map[StreamID][]*streamableMsg + // Lifecycle: persists for the duration of the session. + outgoing map[StreamID][][]byte // signals maps a logical stream ID to a 1-buffered channel, owned by an // incoming HTTP request, that signals that there are messages available to @@ -301,16 +312,15 @@ func (t *StreamableServerTransport) ServeHTTP(w http.ResponseWriter, req *http.R func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Request) { // connID 0 corresponds to the default GET request. - id, nextIdx := StreamID(0), 0 + id, lastIdx := StreamID(0), -1 if len(req.Header.Values("Last-Event-ID")) > 0 { eid := req.Header.Get("Last-Event-ID") var ok bool - id, nextIdx, ok = parseEventID(eid) + id, lastIdx, ok = parseEventID(eid) if !ok { http.Error(w, fmt.Sprintf("malformed Last-Event-ID %q", eid), http.StatusBadRequest) return } - nextIdx++ } t.mu.Lock() @@ -323,7 +333,7 @@ func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Re t.signals[id] = signal t.mu.Unlock() - t.streamResponse(w, req, id, nextIdx, signal) + t.streamResponse(w, req, id, lastIdx, signal) } func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.Request) { @@ -376,26 +386,33 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R // TODO(rfindley): consider optimizing for a single incoming request, by // responding with application/json when there is only a single message in // the response. - t.streamResponse(w, req, id, 0, signal) + t.streamResponse(w, req, id, -1, signal) } -func (t *StreamableServerTransport) streamResponse(w http.ResponseWriter, req *http.Request, id StreamID, nextIndex int, signal chan struct{}) { +// lastIndex is the index of the last seen event if resuming, else -1. +func (t *StreamableServerTransport) streamResponse(w http.ResponseWriter, req *http.Request, id StreamID, lastIndex int, signal chan struct{}) { defer func() { t.mu.Lock() delete(t.signals, id) t.mu.Unlock() }() - // Stream resumption: adjust outgoing index based on what the user says - // they've received. - if nextIndex > 0 { - t.mu.Lock() - // Clamp nextIndex to outgoing messages. - outgoing := t.outgoingMessages[id] - if nextIndex > len(outgoing) { - nextIndex = len(outgoing) + writes := 0 + + // write one event containing data. + write := func(data []byte) bool { + lastIndex++ + e := Event{ + Name: "message", + ID: formatEventID(id, lastIndex), + Data: data, } - t.mu.Unlock() + if _, err := writeEvent(w, e); err != nil { + // Connection closed or broken. + return false + } + writes++ + return true } w.Header().Set(sessionIDHeader, t.id) @@ -403,37 +420,53 @@ func (t *StreamableServerTransport) streamResponse(w http.ResponseWriter, req *h w.Header().Set("Cache-Control", "no-cache, no-transform") w.Header().Set("Connection", "keep-alive") - writes := 0 + if lastIndex >= 0 { + // Resume. + for data, err := range t.opts.EventStore.After(req.Context(), t.SessionID(), id, lastIndex) { + if err != nil { + // TODO: reevaluate these status codes. + // Maybe distinguish between storage errors, which are 500s, and missing + // session or stream ID--can these arise from bad input? + status := http.StatusInternalServerError + if errors.Is(err, ErrEventsPurged) { + status = http.StatusInsufficientStorage + } + http.Error(w, err.Error(), status) + return + } + // The iterator yields events beginning just after lastIndex, or it would have + // yielded an error. + if !write(data) { + return + } + } + } + stream: + // Repeatedly collect pending outgoing events and send them. for { - // Send outgoing messages t.mu.Lock() - outgoing := t.outgoingMessages[id][nextIndex:] + outgoing := t.outgoing[id] + t.outgoing[id] = nil t.mu.Unlock() - for _, msg := range outgoing { - if _, err := writeEvent(w, msg.event); err != nil { - // Connection closed or broken. + for _, data := range outgoing { + if err := t.opts.EventStore.Append(req.Context(), t.id, id, data); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if !write(data) { return } - writes++ - nextIndex++ } t.mu.Lock() nOutstanding := len(t.streamRequests[id]) - nOutgoing := len(t.outgoingMessages[id]) t.mu.Unlock() - // If all requests have been handled and replied to, we can terminate this - // connection. However, in the case of a sequencing violation from the server - // (a send on the request context after the request has been handled), we - // loop until we've written all messages. - // - // TODO(rfindley): should we instead refuse to send messages after the last - // response? Decide, write a test, and change the behavior. - if nextIndex < nOutgoing { - continue // more to send - } + // If all requests have been handled and replied to, we should terminate this connection. + // "After the JSON-RPC response has been sent, the server SHOULD close the SSE stream." + // ยง6.4, https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server + // TODO(jba): why not terminate regardless of http method? if req.Method == http.MethodPost && nOutstanding == 0 { if writes == 0 { // Spec: If the server accepts the input, the server MUST return HTTP @@ -444,8 +477,9 @@ stream: } select { - case <-signal: - case <-t.done: + case <-signal: // there are new outgoing messages + // return to top of loop + case <-t.done: // session is closed if writes == 0 { http.Error(w, "session terminated", http.StatusGone) } @@ -552,15 +586,7 @@ func (t *StreamableServerTransport) Write(ctx context.Context, msg jsonrpc.Messa forConn = 0 } - idx := len(t.outgoingMessages[forConn]) - t.outgoingMessages[forConn] = append(t.outgoingMessages[forConn], &streamableMsg{ - idx: idx, - event: Event{ - Name: "message", - ID: formatEventID(forConn, idx), - Data: data, - }, - }) + t.outgoing[forConn] = append(t.outgoing[forConn], data) if replyTo.IsValid() { // Once we've put the reply on the queue, it's no longer outstanding. delete(t.streamRequests[forConn], replyTo) @@ -586,6 +612,7 @@ func (t *StreamableServerTransport) Close() error { if !t.isDone { t.isDone = true close(t.done) + return t.opts.EventStore.SessionClosed(context.TODO(), t.id) } return nil } From 4b168b6b029ca3b537a855cf1f6baccd67a51d49 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 23 Jul 2025 13:16:29 -0400 Subject: [PATCH 2/2] add doc --- mcp/streamable.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index 2d3606e8..7dba4504 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -312,7 +312,11 @@ func (t *StreamableServerTransport) ServeHTTP(w http.ResponseWriter, req *http.R func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Request) { // connID 0 corresponds to the default GET request. - id, lastIdx := StreamID(0), -1 + id := StreamID(0) + // By default, we haven't seen a last index. Since indices start at 0, we represent + // that by -1. This is incremented just before each event is written, in streamResponse + // around L407. + lastIdx := -1 if len(req.Header.Values("Last-Event-ID")) > 0 { eid := req.Header.Get("Last-Event-ID") var ok bool