Skip to content

Commit 4e85827

Browse files
committed
add EventStore.Open
1 parent 06a0ab0 commit 4e85827

File tree

2 files changed

+48
-13
lines changed

2 files changed

+48
-13
lines changed

mcp/event.go

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,11 @@ func scanEvents(r io.Reader) iter.Seq2[Event, error] {
153153
//
154154
// All of an EventStore's methods must be safe for use by multiple goroutines.
155155
type EventStore interface {
156+
// Open prepares the event store for a given session. It ensures that the
157+
// underlying data structure for the sessionID is initialized, making it
158+
// ready to store event streams.
159+
Open(_ context.Context, sessionID string, streamID StreamID) error
160+
156161
// Append appends data for an outgoing event to given stream, which is part of the
157162
// given session.
158163
Append(_ context.Context, sessionID string, _ StreamID, data []byte) error
@@ -256,11 +261,21 @@ func NewMemoryEventStore(opts *MemoryEventStoreOptions) *MemoryEventStore {
256261
}
257262
}
258263

259-
// Append implements [EventStore.Append] by recording data in memory.
260-
func (s *MemoryEventStore) Append(_ context.Context, sessionID string, streamID StreamID, data []byte) error {
264+
// Open implements [EventStore.Open]. It ensures that the underlying data
265+
// structures for the given session and stream IDs are initialized and ready
266+
// for use.
267+
func (s *MemoryEventStore) Open(_ context.Context, sessionID string, streamID StreamID) error {
261268
s.mu.Lock()
262269
defer s.mu.Unlock()
270+
s.init(sessionID, streamID)
271+
return nil
272+
}
263273

274+
// init is an internal helper function that ensures the nested map structure for a
275+
// given sessionID and streamID exists, creating it if necessary. It returns the
276+
// dataList associated with the specified IDs.
277+
// This function must be called within a locked context.
278+
func (s *MemoryEventStore) init(sessionID string, streamID StreamID) *dataList {
264279
streamMap, ok := s.store[sessionID]
265280
if !ok {
266281
streamMap = make(map[StreamID]*dataList)
@@ -271,6 +286,14 @@ func (s *MemoryEventStore) Append(_ context.Context, sessionID string, streamID
271286
dl = &dataList{}
272287
streamMap[streamID] = dl
273288
}
289+
return dl
290+
}
291+
292+
// Append implements [EventStore.Append] by recording data in memory.
293+
func (s *MemoryEventStore) Append(_ context.Context, sessionID string, streamID StreamID, data []byte) error {
294+
s.mu.Lock()
295+
defer s.mu.Unlock()
296+
dl := s.init(sessionID, streamID)
274297
// Purge before adding, so at least the current data item will be present.
275298
// (That could result in nBytes > maxBytes, but we'll live with that.)
276299
s.purge()

mcp/streamable.go

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,7 @@ func NewStreamableServerTransport(sessionID string, opts *StreamableServerTransp
401401
}
402402

403403
// Connect implements the [Transport] interface.
404-
func (t *StreamableServerTransport) Connect(context.Context) (Connection, error) {
404+
func (t *StreamableServerTransport) Connect(ctx context.Context) (Connection, error) {
405405
if t.connection != nil {
406406
return nil, fmt.Errorf("transport already connected")
407407
}
@@ -415,13 +415,17 @@ func (t *StreamableServerTransport) Connect(context.Context) (Connection, error)
415415
streams: make(map[StreamID]*stream),
416416
requestStreams: make(map[jsonrpc.ID]StreamID),
417417
}
418+
if t.connection.eventStore == nil {
419+
t.connection.eventStore = NewMemoryEventStore(nil)
420+
}
418421
// Stream 0 corresponds to the hanging 'GET'.
419422
//
420423
// It is always text/event-stream, since it must carry arbitrarily many
421424
// messages.
422-
t.connection.streams[""] = newStream("", false)
423-
if t.connection.eventStore == nil {
424-
t.connection.eventStore = NewMemoryEventStore(nil)
425+
var err error
426+
t.connection.streams[""], err = t.connection.newStream(ctx, "", false)
427+
if err != nil {
428+
return nil, err
425429
}
426430
return t.connection, nil
427431
}
@@ -508,12 +512,15 @@ type stream struct {
508512
requests map[jsonrpc.ID]struct{}
509513
}
510514

511-
func newStream(id StreamID, jsonResponse bool) *stream {
515+
func (c *streamableServerConn) newStream(ctx context.Context, id StreamID, jsonResponse bool) (*stream, error) {
516+
if err := c.eventStore.Open(ctx, c.sessionID, id); err != nil {
517+
return nil, err
518+
}
512519
return &stream{
513520
id: id,
514521
jsonResponse: jsonResponse,
515522
requests: make(map[jsonrpc.ID]struct{}),
516-
}
523+
}, nil
517524
}
518525

519526
func signalChanPtr() *chan struct{} {
@@ -664,7 +671,11 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
664671
// notifications or server->client requests made in the course of handling.
665672
// Update accounting for this incoming payload.
666673
if len(requests) > 0 {
667-
stream = newStream(StreamID(randText()), c.jsonResponse)
674+
stream, err = c.newStream(req.Context(), StreamID(randText()), c.jsonResponse)
675+
if err != nil {
676+
http.Error(w, fmt.Sprintf("storing stream: %v", err), http.StatusInternalServerError)
677+
return
678+
}
668679
c.mu.Lock()
669680
c.streams[stream.id] = stream
670681
stream.requests = requests
@@ -800,18 +811,19 @@ func (c *streamableServerConn) respondSSE(stream *stream, w http.ResponseWriter,
800811
func (c *streamableServerConn) messages(ctx context.Context, stream *stream, persistent bool, lastIndex int) iter.Seq2[json.RawMessage, error] {
801812
return func(yield func(json.RawMessage, error) bool) {
802813
for {
814+
c.mu.Lock()
815+
nOutstanding := len(stream.requests)
816+
c.mu.Unlock()
803817
for data, err := range c.eventStore.After(ctx, c.SessionID(), stream.id, lastIndex) {
804818
if err != nil {
805-
break
819+
yield(nil, err)
820+
return
806821
}
807822
if !yield(data, nil) {
808823
return
809824
}
810825
lastIndex++
811826
}
812-
c.mu.Lock()
813-
nOutstanding := len(stream.requests)
814-
c.mu.Unlock()
815827
// If all requests have been handled and replied to, we should terminate this connection.
816828
// "After the JSON-RPC response has been sent, the server SHOULD close the SSE stream."
817829
// §6.4, https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server

0 commit comments

Comments
 (0)