Skip to content

Commit 615f0e6

Browse files
committed
mcp: add stream type
Consolidate several maps into a single struct. Simplifies the code, for the most part.
1 parent 8dd9a81 commit 615f0e6

File tree

1 file changed

+86
-66
lines changed

1 file changed

+86
-66
lines changed

mcp/streamable.go

Lines changed: 86 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -173,11 +173,10 @@ func NewStreamableServerTransport(sessionID string, opts *StreamableServerTransp
173173
id: sessionID,
174174
incoming: make(chan jsonrpc.Message, 10),
175175
done: make(chan struct{}),
176-
outgoing: make(map[StreamID][][]byte),
177-
signals: make(map[StreamID]chan struct{}),
176+
streams: make(map[StreamID]*stream),
178177
requestStreams: make(map[jsonrpc.ID]StreamID),
179-
streamRequests: make(map[StreamID]map[jsonrpc.ID]struct{}),
180178
}
179+
t.streams[0] = newStream(0)
181180
if opts != nil {
182181
t.opts = *opts
183182
}
@@ -213,59 +212,66 @@ type StreamableServerTransport struct {
213212
// perform the accounting described below when incoming HTTP requests are
214213
// handled.
215214
//
216-
// The accounting is complicated. It is tempting to merge some of the maps
217-
// below, but they each have different lifecycles, as indicated by Lifecycle:
218-
// comments.
219-
//
220215
// TODO(rfindley): simplify.
221216

222-
// outgoing is the collection of outgoing messages, keyed by the logical
223-
// stream ID where they should be delivered.
217+
// streams holds the logical streams for this session, keyed by their ID.
218+
streams map[StreamID]*stream
219+
220+
// requestStreams maps incoming requests to their logical stream ID.
224221
//
225-
// streamID 0 is used for messages that don't correlate with an incoming
226-
// request.
222+
// Lifecycle: requestStreams persists for the duration of the session.
227223
//
228-
// Lifecycle: persists for the duration of the session.
229-
outgoing map[StreamID][][]byte
224+
// TODO(rfindley): clean up once requests are handled.
225+
requestStreams map[jsonrpc.ID]StreamID
226+
}
227+
228+
// A stream is a single logical stream of SSE events within a server session.
229+
// A stream begins with a client request, or with a client GET that has
230+
// no Last-Event-ID header.
231+
// A stream ends only when its session ends; we cannot determine its end otherwise,
232+
// since a client may send a GET with a Last-Event-ID that references the stream
233+
// at any time.
234+
type stream struct {
235+
// id is the logical ID for the stream, unique within a session.
236+
// ID 0 is used for messages that don't correlate with an incoming request.
237+
id StreamID
230238

231-
// signals maps a logical stream ID to a 1-buffered channel, owned by an
239+
// These mutable fields are protected by the mutex of the corresponding StreamableServerTransport.
240+
241+
// outgoing is the list of outgoing messages, enqueued by server methods that
242+
// write notifications and responses, and dequeued by streamResponse.
243+
outgoing [][]byte
244+
245+
// signal is a 1-buffered channel, owned by an
232246
// incoming HTTP request, that signals that there are messages available to
233-
// write into the HTTP response. Signals guarantees that at most one HTTP
247+
// write into the HTTP response. This guarantees that at most one HTTP
234248
// response can receive messages for a logical stream. After claiming
235249
// the stream, incoming requests should read from outgoing, to ensure
236250
// that no new messages are missed.
237251
//
238-
// Lifecycle: signals persists for the duration of an HTTP POST or GET
252+
// Lifecycle: persists for the duration of an HTTP POST or GET
239253
// request for the given streamID.
240-
signals map[StreamID]chan struct{}
254+
signal chan struct{}
241255

242-
// requestStreams maps incoming requests to their logical stream ID.
243-
//
244-
// Lifecycle: requestStreams persists for the duration of the session.
256+
// streamRequests is the set of unanswered incoming RPCs for the stream.
245257
//
246-
// TODO(rfindley): clean up once requests are handled.
247-
requestStreams map[jsonrpc.ID]StreamID
248-
249-
// streamRequests tracks the set of unanswered incoming RPCs for each logical
250-
// stream.
251-
//
252-
// When the server has responded to each request, the stream should be
253-
// closed.
254-
//
255-
// Lifecycle: streamRequests values persist as until the requests have been
258+
// Lifecycle: requests values persist as until the requests have been
256259
// replied to by the server. Notably, NOT until they are sent to an HTTP
257260
// response, as delivery is not guaranteed.
258-
streamRequests map[StreamID]map[jsonrpc.ID]struct{}
261+
requests map[jsonrpc.ID]struct{}
259262
}
260263

261-
type StreamID int64
262-
263-
// a streamableMsg is an SSE event with an index into its logical stream.
264-
type streamableMsg struct {
265-
idx int
266-
event Event
264+
func newStream(id StreamID) *stream {
265+
return &stream{
266+
id: id,
267+
requests: make(map[jsonrpc.ID]struct{}),
268+
}
267269
}
268270

271+
// A StreamID identifies a stream of SSE events. It is unique within the stream's
272+
// [ServerSession].
273+
type StreamID int64
274+
269275
// Connect implements the [Transport] interface.
270276
//
271277
// TODO(rfindley): Connect should return a new object.
@@ -328,16 +334,21 @@ func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Re
328334
}
329335

330336
t.mu.Lock()
331-
if _, ok := t.signals[id]; ok {
337+
stream, ok := t.streams[id]
338+
if !ok {
339+
http.Error(w, "unknown stream", http.StatusBadRequest)
340+
t.mu.Unlock()
341+
return
342+
}
343+
if stream.signal != nil {
332344
http.Error(w, "stream ID conflicts with ongoing stream", http.StatusBadRequest)
333345
t.mu.Unlock()
334346
return
335347
}
336-
signal := make(chan struct{}, 1)
337-
t.signals[id] = signal
348+
stream.signal = make(chan struct{}, 1)
338349
t.mu.Unlock()
339350

340-
t.streamResponse(w, req, id, lastIdx, signal)
351+
t.streamResponse(stream, w, req, lastIdx)
341352
}
342353

343354
func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.Request) {
@@ -369,17 +380,17 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R
369380
}
370381

371382
// Update accounting for this request.
372-
id := StreamID(t.nextStreamID.Add(1))
373-
signal := make(chan struct{}, 1)
383+
stream := newStream(StreamID(t.nextStreamID.Add(1)))
374384
t.mu.Lock()
385+
t.streams[stream.id] = stream
375386
if len(requests) > 0 {
376-
t.streamRequests[id] = make(map[jsonrpc.ID]struct{})
387+
stream.requests = make(map[jsonrpc.ID]struct{})
377388
}
378389
for reqID := range requests {
379-
t.requestStreams[reqID] = id
380-
t.streamRequests[id][reqID] = struct{}{}
390+
t.requestStreams[reqID] = stream.id
391+
stream.requests[reqID] = struct{}{}
381392
}
382-
t.signals[id] = signal
393+
stream.signal = make(chan struct{}, 1)
383394
t.mu.Unlock()
384395

385396
// Publish incoming messages.
@@ -390,29 +401,37 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R
390401
// TODO(rfindley): consider optimizing for a single incoming request, by
391402
// responding with application/json when there is only a single message in
392403
// the response.
393-
t.streamResponse(w, req, id, -1, signal)
404+
t.streamResponse(stream, w, req, -1)
394405
}
395406

396407
// lastIndex is the index of the last seen event if resuming, else -1.
397-
func (t *StreamableServerTransport) streamResponse(w http.ResponseWriter, req *http.Request, id StreamID, lastIndex int, signal chan struct{}) {
408+
func (t *StreamableServerTransport) streamResponse(stream *stream, w http.ResponseWriter, req *http.Request, lastIndex int) {
398409
defer func() {
399410
t.mu.Lock()
400-
delete(t.signals, id)
411+
stream.signal = nil
401412
t.mu.Unlock()
402413
}()
403414

415+
t.mu.Lock()
416+
// Although there is a gap in locking between when stream.signal is set and here,
417+
// it cannot change, because it is changed only when non-nil, and it is only
418+
// set to nil in the defer above.
419+
signal := stream.signal
420+
t.mu.Unlock()
421+
404422
writes := 0
405423

406424
// write one event containing data.
407425
write := func(data []byte) bool {
408426
lastIndex++
409427
e := Event{
410428
Name: "message",
411-
ID: formatEventID(id, lastIndex),
429+
ID: formatEventID(stream.id, lastIndex),
412430
Data: data,
413431
}
414432
if _, err := writeEvent(w, e); err != nil {
415433
// Connection closed or broken.
434+
// TODO: log when we add server-side logging.
416435
return false
417436
}
418437
writes++
@@ -426,7 +445,7 @@ func (t *StreamableServerTransport) streamResponse(w http.ResponseWriter, req *h
426445

427446
if lastIndex >= 0 {
428447
// Resume.
429-
for data, err := range t.opts.EventStore.After(req.Context(), t.SessionID(), id, lastIndex) {
448+
for data, err := range t.opts.EventStore.After(req.Context(), t.SessionID(), stream.id, lastIndex) {
430449
if err != nil {
431450
// TODO: reevaluate these status codes.
432451
// Maybe distinguish between storage errors, which are 500s, and missing
@@ -450,12 +469,12 @@ stream:
450469
// Repeatedly collect pending outgoing events and send them.
451470
for {
452471
t.mu.Lock()
453-
outgoing := t.outgoing[id]
454-
t.outgoing[id] = nil
472+
outgoing := stream.outgoing
473+
stream.outgoing = nil
455474
t.mu.Unlock()
456475

457476
for _, data := range outgoing {
458-
if err := t.opts.EventStore.Append(req.Context(), t.id, id, data); err != nil {
477+
if err := t.opts.EventStore.Append(req.Context(), t.SessionID(), stream.id, data); err != nil {
459478
http.Error(w, err.Error(), http.StatusInternalServerError)
460479
return
461480
}
@@ -465,7 +484,7 @@ stream:
465484
}
466485

467486
t.mu.Lock()
468-
nOutstanding := len(t.streamRequests[id])
487+
nOutstanding := len(stream.requests)
469488
t.mu.Unlock()
470489
// If all requests have been handled and replied to, we should terminate this connection.
471490
// "After the JSON-RPC response has been sent, the server SHOULD close the SSE stream."
@@ -579,30 +598,31 @@ func (t *StreamableServerTransport) Write(ctx context.Context, msg jsonrpc.Messa
579598
t.mu.Lock()
580599
defer t.mu.Unlock()
581600
if t.isDone {
582-
return fmt.Errorf("session is closed") // TODO: should this be EOF?
601+
return errors.New("session is closed") // TODO: should this be EOF?
583602
}
584603

585-
if _, ok := t.streamRequests[forConn]; !ok && forConn != 0 {
604+
stream := t.streams[forConn]
605+
if stream == nil {
606+
return fmt.Errorf("no stream with ID %d", forConn)
607+
}
608+
if len(stream.requests) == 0 && forConn != 0 {
586609
// No outstanding requests for this connection, which means it is logically
587610
// done. This is a sequencing violation from the server, so we should report
588611
// a side-channel error here. Put the message on the general queue to avoid
589612
// dropping messages.
590-
forConn = 0
613+
stream = t.streams[0]
591614
}
592615

593-
t.outgoing[forConn] = append(t.outgoing[forConn], data)
616+
stream.outgoing = append(stream.outgoing, data)
594617
if replyTo.IsValid() {
595618
// Once we've put the reply on the queue, it's no longer outstanding.
596-
delete(t.streamRequests[forConn], replyTo)
597-
if len(t.streamRequests[forConn]) == 0 {
598-
delete(t.streamRequests, forConn)
599-
}
619+
delete(stream.requests, replyTo)
600620
}
601621

602622
// Signal work.
603-
if c, ok := t.signals[forConn]; ok {
623+
if stream.signal != nil {
604624
select {
605-
case c <- struct{}{}:
625+
case stream.signal <- struct{}{}:
606626
default:
607627
}
608628
}

0 commit comments

Comments
 (0)