Skip to content

Commit 86f71fc

Browse files
authored
mcp: add stream type (#171)
Consolidate several maps into a single struct. Simplifies the code, for the most part.
1 parent e8c6e03 commit 86f71fc

File tree

1 file changed

+86
-66
lines changed

1 file changed

+86
-66
lines changed

mcp/streamable.go

Lines changed: 86 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -179,11 +179,10 @@ func NewStreamableServerTransport(sessionID string, opts *StreamableServerTransp
179179
id: sessionID,
180180
incoming: make(chan jsonrpc.Message, 10),
181181
done: make(chan struct{}),
182-
outgoing: make(map[StreamID][][]byte),
183-
signals: make(map[StreamID]chan struct{}),
182+
streams: make(map[StreamID]*stream),
184183
requestStreams: make(map[jsonrpc.ID]StreamID),
185-
streamRequests: make(map[StreamID]map[jsonrpc.ID]struct{}),
186184
}
185+
t.streams[0] = newStream(0)
187186
if opts != nil {
188187
t.opts = *opts
189188
}
@@ -219,59 +218,66 @@ type StreamableServerTransport struct {
219218
// perform the accounting described below when incoming HTTP requests are
220219
// handled.
221220
//
222-
// The accounting is complicated. It is tempting to merge some of the maps
223-
// below, but they each have different lifecycles, as indicated by Lifecycle:
224-
// comments.
225-
//
226221
// TODO(rfindley): simplify.
227222

228-
// outgoing is the collection of outgoing messages, keyed by the logical
229-
// stream ID where they should be delivered.
223+
// streams holds the logical streams for this session, keyed by their ID.
224+
streams map[StreamID]*stream
225+
226+
// requestStreams maps incoming requests to their logical stream ID.
230227
//
231-
// streamID 0 is used for messages that don't correlate with an incoming
232-
// request.
228+
// Lifecycle: requestStreams persists for the duration of the session.
233229
//
234-
// Lifecycle: persists for the duration of the session.
235-
outgoing map[StreamID][][]byte
230+
// TODO(rfindley): clean up once requests are handled.
231+
requestStreams map[jsonrpc.ID]StreamID
232+
}
233+
234+
// A stream is a single logical stream of SSE events within a server session.
235+
// A stream begins with a client request, or with a client GET that has
236+
// no Last-Event-ID header.
237+
// A stream ends only when its session ends; we cannot determine its end otherwise,
238+
// since a client may send a GET with a Last-Event-ID that references the stream
239+
// at any time.
240+
type stream struct {
241+
// id is the logical ID for the stream, unique within a session.
242+
// ID 0 is used for messages that don't correlate with an incoming request.
243+
id StreamID
236244

237-
// signals maps a logical stream ID to a 1-buffered channel, owned by an
245+
// These mutable fields are protected by the mutex of the corresponding StreamableServerTransport.
246+
247+
// outgoing is the list of outgoing messages, enqueued by server methods that
248+
// write notifications and responses, and dequeued by streamResponse.
249+
outgoing [][]byte
250+
251+
// signal is a 1-buffered channel, owned by an
238252
// incoming HTTP request, that signals that there are messages available to
239-
// write into the HTTP response. Signals guarantees that at most one HTTP
253+
// write into the HTTP response. This guarantees that at most one HTTP
240254
// response can receive messages for a logical stream. After claiming
241255
// the stream, incoming requests should read from outgoing, to ensure
242256
// that no new messages are missed.
243257
//
244-
// Lifecycle: signals persists for the duration of an HTTP POST or GET
258+
// Lifecycle: persists for the duration of an HTTP POST or GET
245259
// request for the given streamID.
246-
signals map[StreamID]chan struct{}
260+
signal chan struct{}
247261

248-
// requestStreams maps incoming requests to their logical stream ID.
249-
//
250-
// Lifecycle: requestStreams persists for the duration of the session.
262+
// streamRequests is the set of unanswered incoming RPCs for the stream.
251263
//
252-
// TODO(rfindley): clean up once requests are handled.
253-
requestStreams map[jsonrpc.ID]StreamID
254-
255-
// streamRequests tracks the set of unanswered incoming RPCs for each logical
256-
// stream.
257-
//
258-
// When the server has responded to each request, the stream should be
259-
// closed.
260-
//
261-
// Lifecycle: streamRequests values persist as until the requests have been
264+
// Lifecycle: requests values persist until the requests have been
262265
// replied to by the server. Notably, NOT until they are sent to an HTTP
263266
// response, as delivery is not guaranteed.
264-
streamRequests map[StreamID]map[jsonrpc.ID]struct{}
267+
requests map[jsonrpc.ID]struct{}
265268
}
266269

267-
type StreamID int64
268-
269-
// a streamableMsg is an SSE event with an index into its logical stream.
270-
type streamableMsg struct {
271-
idx int
272-
event Event
270+
func newStream(id StreamID) *stream {
271+
return &stream{
272+
id: id,
273+
requests: make(map[jsonrpc.ID]struct{}),
274+
}
273275
}
274276

277+
// A StreamID identifies a stream of SSE events. It is unique within the stream's
278+
// [ServerSession].
279+
type StreamID int64
280+
275281
// Connect implements the [Transport] interface.
276282
//
277283
// TODO(rfindley): Connect should return a new object.
@@ -334,16 +340,21 @@ func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Re
334340
}
335341

336342
t.mu.Lock()
337-
if _, ok := t.signals[id]; ok {
343+
stream, ok := t.streams[id]
344+
if !ok {
345+
http.Error(w, "unknown stream", http.StatusBadRequest)
346+
t.mu.Unlock()
347+
return
348+
}
349+
if stream.signal != nil {
338350
http.Error(w, "stream ID conflicts with ongoing stream", http.StatusBadRequest)
339351
t.mu.Unlock()
340352
return
341353
}
342-
signal := make(chan struct{}, 1)
343-
t.signals[id] = signal
354+
stream.signal = make(chan struct{}, 1)
344355
t.mu.Unlock()
345356

346-
t.streamResponse(w, req, id, lastIdx, signal)
357+
t.streamResponse(stream, w, req, lastIdx)
347358
}
348359

349360
func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.Request) {
@@ -375,17 +386,17 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R
375386
}
376387

377388
// Update accounting for this request.
378-
id := StreamID(t.nextStreamID.Add(1))
379-
signal := make(chan struct{}, 1)
389+
stream := newStream(StreamID(t.nextStreamID.Add(1)))
380390
t.mu.Lock()
391+
t.streams[stream.id] = stream
381392
if len(requests) > 0 {
382-
t.streamRequests[id] = make(map[jsonrpc.ID]struct{})
393+
stream.requests = make(map[jsonrpc.ID]struct{})
383394
}
384395
for reqID := range requests {
385-
t.requestStreams[reqID] = id
386-
t.streamRequests[id][reqID] = struct{}{}
396+
t.requestStreams[reqID] = stream.id
397+
stream.requests[reqID] = struct{}{}
387398
}
388-
t.signals[id] = signal
399+
stream.signal = make(chan struct{}, 1)
389400
t.mu.Unlock()
390401

391402
// Publish incoming messages.
@@ -396,29 +407,37 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R
396407
// TODO(rfindley): consider optimizing for a single incoming request, by
397408
// responding with application/json when there is only a single message in
398409
// the response.
399-
t.streamResponse(w, req, id, -1, signal)
410+
t.streamResponse(stream, w, req, -1)
400411
}
401412

402413
// lastIndex is the index of the last seen event if resuming, else -1.
403-
func (t *StreamableServerTransport) streamResponse(w http.ResponseWriter, req *http.Request, id StreamID, lastIndex int, signal chan struct{}) {
414+
func (t *StreamableServerTransport) streamResponse(stream *stream, w http.ResponseWriter, req *http.Request, lastIndex int) {
404415
defer func() {
405416
t.mu.Lock()
406-
delete(t.signals, id)
417+
stream.signal = nil
407418
t.mu.Unlock()
408419
}()
409420

421+
t.mu.Lock()
422+
// Although there is a gap in locking between when stream.signal is set and here,
423+
// it cannot change, because it is changed only when non-nil, and it is only
424+
// set to nil in the defer above.
425+
signal := stream.signal
426+
t.mu.Unlock()
427+
410428
writes := 0
411429

412430
// write one event containing data.
413431
write := func(data []byte) bool {
414432
lastIndex++
415433
e := Event{
416434
Name: "message",
417-
ID: formatEventID(id, lastIndex),
435+
ID: formatEventID(stream.id, lastIndex),
418436
Data: data,
419437
}
420438
if _, err := writeEvent(w, e); err != nil {
421439
// Connection closed or broken.
440+
// TODO: log when we add server-side logging.
422441
return false
423442
}
424443
writes++
@@ -432,7 +451,7 @@ func (t *StreamableServerTransport) streamResponse(w http.ResponseWriter, req *h
432451

433452
if lastIndex >= 0 {
434453
// Resume.
435-
for data, err := range t.opts.EventStore.After(req.Context(), t.SessionID(), id, lastIndex) {
454+
for data, err := range t.opts.EventStore.After(req.Context(), t.SessionID(), stream.id, lastIndex) {
436455
if err != nil {
437456
// TODO: reevaluate these status codes.
438457
// Maybe distinguish between storage errors, which are 500s, and missing
@@ -456,12 +475,12 @@ stream:
456475
// Repeatedly collect pending outgoing events and send them.
457476
for {
458477
t.mu.Lock()
459-
outgoing := t.outgoing[id]
460-
t.outgoing[id] = nil
478+
outgoing := stream.outgoing
479+
stream.outgoing = nil
461480
t.mu.Unlock()
462481

463482
for _, data := range outgoing {
464-
if err := t.opts.EventStore.Append(req.Context(), t.id, id, data); err != nil {
483+
if err := t.opts.EventStore.Append(req.Context(), t.SessionID(), stream.id, data); err != nil {
465484
http.Error(w, err.Error(), http.StatusInternalServerError)
466485
return
467486
}
@@ -471,7 +490,7 @@ stream:
471490
}
472491

473492
t.mu.Lock()
474-
nOutstanding := len(t.streamRequests[id])
493+
nOutstanding := len(stream.requests)
475494
t.mu.Unlock()
476495
// If all requests have been handled and replied to, we should terminate this connection.
477496
// "After the JSON-RPC response has been sent, the server SHOULD close the SSE stream."
@@ -585,30 +604,31 @@ func (t *StreamableServerTransport) Write(ctx context.Context, msg jsonrpc.Messa
585604
t.mu.Lock()
586605
defer t.mu.Unlock()
587606
if t.isDone {
588-
return fmt.Errorf("session is closed") // TODO: should this be EOF?
607+
return errors.New("session is closed")
589608
}
590609

591-
if _, ok := t.streamRequests[forConn]; !ok && forConn != 0 {
610+
stream := t.streams[forConn]
611+
if stream == nil {
612+
return fmt.Errorf("no stream with ID %d", forConn)
613+
}
614+
if len(stream.requests) == 0 && forConn != 0 {
592615
// No outstanding requests for this connection, which means it is logically
593616
// done. This is a sequencing violation from the server, so we should report
594617
// a side-channel error here. Put the message on the general queue to avoid
595618
// dropping messages.
596-
forConn = 0
619+
stream = t.streams[0]
597620
}
598621

599-
t.outgoing[forConn] = append(t.outgoing[forConn], data)
622+
stream.outgoing = append(stream.outgoing, data)
600623
if replyTo.IsValid() {
601624
// Once we've put the reply on the queue, it's no longer outstanding.
602-
delete(t.streamRequests[forConn], replyTo)
603-
if len(t.streamRequests[forConn]) == 0 {
604-
delete(t.streamRequests, forConn)
605-
}
625+
delete(stream.requests, replyTo)
606626
}
607627

608628
// Signal work.
609-
if c, ok := t.signals[forConn]; ok {
629+
if stream.signal != nil {
610630
select {
611-
case c <- struct{}{}:
631+
case stream.signal <- struct{}{}:
612632
default:
613633
}
614634
}

0 commit comments

Comments
 (0)