Skip to content

Commit a5aa370

Browse files
authored
mcp: various cleanups to streamable code (#184)
- move done field outside of mutex hat - refactor handler into a function that returns `error` to avoid repeated `http.Error` calls - use an atomic for the signal channel to simplify locking
1 parent 619bc41 commit a5aa370

File tree

1 file changed

+61
-68
lines changed

1 file changed

+61
-68
lines changed

mcp/streamable.go

Lines changed: 61 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -204,11 +204,11 @@ type StreamableServerTransport struct {
204204
id string
205205
opts StreamableServerTransportOptions
206206
incoming chan jsonrpc.Message // messages from the client to the server
207+
done chan struct{}
207208

208209
mu sync.Mutex
209210
// Sessions are closed exactly once.
210211
isDone bool
211-
done chan struct{}
212212

213213
// Sessions can have multiple logical connections, corresponding to HTTP
214214
// requests. Additionally, logical sessions may be resumed by subsequent HTTP
@@ -242,23 +242,26 @@ type stream struct {
242242
// ID 0 is used for messages that don't correlate with an incoming request.
243243
id StreamID
244244

245-
// These mutable fields are protected by the mutex of the corresponding StreamableServerTransport.
245+
// signal is a 1-buffered channel, owned by an incoming HTTP request, that signals
246+
// that there are messages available to write into the HTTP response.
247+
// In addition, the presence of a channel guarantees that at most one HTTP response
248+
// can receive messages for a logical stream. After claiming the stream, incoming
249+
// requests should read from outgoing, to ensure that no new messages are missed.
250+
//
251+
// To simplify locking, signal is an atomic. We need an atomic.Pointer, because
252+
// you can't set an atomic.Value to nil.
253+
//
254+
// Lifecycle: each channel value persists for the duration of an HTTP POST or
255+
// GET request for the given streamID.
256+
signal atomic.Pointer[chan struct{}]
257+
258+
// The following mutable fields are protected by the mutex of the containing
259+
// StreamableServerTransport.
246260

247261
// outgoing is the list of outgoing messages, enqueued by server methods that
248262
// write notifications and responses, and dequeued by streamResponse.
249263
outgoing [][]byte
250264

251-
// signal is a 1-buffered channel, owned by an
252-
// incoming HTTP request, that signals that there are messages available to
253-
// write into the HTTP response. This guarantees that at most one HTTP
254-
// response can receive messages for a logical stream. After claiming
255-
// the stream, incoming requests should read from outgoing, to ensure
256-
// that no new messages are missed.
257-
//
258-
// Lifecycle: persists for the duration of an HTTP POST or GET
259-
// request for the given streamID.
260-
signal chan struct{}
261-
262265
// streamRequests is the set of unanswered incoming RPCs for the stream.
263266
//
264267
// Lifecycle: requests values persist until the requests have been
@@ -274,6 +277,11 @@ func newStream(id StreamID) *stream {
274277
}
275278
}
276279

280+
func signalChanPtr() *chan struct{} {
281+
c := make(chan struct{}, 1)
282+
return &c
283+
}
284+
277285
// A StreamID identifies a stream of SSE events. It is unique within the stream's
278286
// [ServerSession].
279287
type StreamID int64
@@ -310,19 +318,25 @@ type idContextKey struct{}
310318

311319
// ServeHTTP handles a single HTTP request for the session.
312320
func (t *StreamableServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) {
321+
status := 0
322+
message := ""
313323
switch req.Method {
314324
case http.MethodGet:
315-
t.serveGET(w, req)
325+
status, message = t.serveGET(w, req)
316326
case http.MethodPost:
317-
t.servePOST(w, req)
327+
status, message = t.servePOST(w, req)
318328
default:
319329
// Should not be reached, as this is checked in StreamableHTTPHandler.ServeHTTP.
320330
w.Header().Set("Allow", "GET, POST")
321-
http.Error(w, "unsupported method", http.StatusMethodNotAllowed)
331+
status = http.StatusMethodNotAllowed
332+
message = "unsupported method"
333+
}
334+
if status != 0 && status != http.StatusOK {
335+
http.Error(w, message, status)
322336
}
323337
}
324338

325-
func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Request) {
339+
func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Request) (int, string) {
326340
// connID 0 corresponds to the default GET request.
327341
id := StreamID(0)
328342
// By default, we haven't seen a last index. Since indices start at 0, we represent
@@ -334,49 +348,39 @@ func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Re
334348
var ok bool
335349
id, lastIdx, ok = parseEventID(eid)
336350
if !ok {
337-
http.Error(w, fmt.Sprintf("malformed Last-Event-ID %q", eid), http.StatusBadRequest)
338-
return
351+
return http.StatusBadRequest, fmt.Sprintf("malformed Last-Event-ID %q", eid)
339352
}
340353
}
341354

342355
t.mu.Lock()
343356
stream, ok := t.streams[id]
357+
t.mu.Unlock()
344358
if !ok {
345-
http.Error(w, "unknown stream", http.StatusBadRequest)
346-
t.mu.Unlock()
347-
return
359+
return http.StatusBadRequest, "unknown stream"
348360
}
349-
if stream.signal != nil {
350-
http.Error(w, "stream ID conflicts with ongoing stream", http.StatusBadRequest)
351-
t.mu.Unlock()
352-
return
361+
if !stream.signal.CompareAndSwap(nil, signalChanPtr()) {
362+
// The CAS returned false, meaning that the comparison failed: stream.signal is not nil.
363+
return http.StatusBadRequest, "stream ID conflicts with ongoing stream"
353364
}
354-
stream.signal = make(chan struct{}, 1)
355-
t.mu.Unlock()
356-
357-
t.streamResponse(stream, w, req, lastIdx)
365+
return t.streamResponse(stream, w, req, lastIdx)
358366
}
359367

360-
func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.Request) {
368+
func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.Request) (int, string) {
361369
if len(req.Header.Values("Last-Event-ID")) > 0 {
362-
http.Error(w, "can't send Last-Event-ID for POST request", http.StatusBadRequest)
363-
return
370+
return http.StatusBadRequest, "can't send Last-Event-ID for POST request"
364371
}
365372

366373
// Read incoming messages.
367374
body, err := io.ReadAll(req.Body)
368375
if err != nil {
369-
http.Error(w, "failed to read body", http.StatusBadRequest)
370-
return
376+
return http.StatusBadRequest, "failed to read body"
371377
}
372378
if len(body) == 0 {
373-
http.Error(w, "POST requires a non-empty body", http.StatusBadRequest)
374-
return
379+
return http.StatusBadRequest, "POST requires a non-empty body"
375380
}
376381
incoming, _, err := readBatch(body)
377382
if err != nil {
378-
http.Error(w, fmt.Sprintf("malformed payload: %v", err), http.StatusBadRequest)
379-
return
383+
return http.StatusBadRequest, fmt.Sprintf("malformed payload: %v", err)
380384
}
381385
requests := make(map[jsonrpc.ID]struct{})
382386
for _, msg := range incoming {
@@ -396,8 +400,8 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R
396400
t.requestStreams[reqID] = stream.id
397401
stream.requests[reqID] = struct{}{}
398402
}
399-
stream.signal = make(chan struct{}, 1)
400403
t.mu.Unlock()
404+
stream.signal.Store(signalChanPtr())
401405

402406
// Publish incoming messages.
403407
for _, msg := range incoming {
@@ -407,23 +411,12 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R
407411
// TODO(rfindley): consider optimizing for a single incoming request, by
408412
// responding with application/json when there is only a single message in
409413
// the response.
410-
t.streamResponse(stream, w, req, -1)
414+
return t.streamResponse(stream, w, req, -1)
411415
}
412416

413417
// lastIndex is the index of the last seen event if resuming, else -1.
414-
func (t *StreamableServerTransport) streamResponse(stream *stream, w http.ResponseWriter, req *http.Request, lastIndex int) {
415-
defer func() {
416-
t.mu.Lock()
417-
stream.signal = nil
418-
t.mu.Unlock()
419-
}()
420-
421-
t.mu.Lock()
422-
// Although there is a gap in locking between when stream.signal is set and here,
423-
// it cannot change, because it is changed only when non-nil, and it is only
424-
// set to nil in the defer above.
425-
signal := stream.signal
426-
t.mu.Unlock()
418+
func (t *StreamableServerTransport) streamResponse(stream *stream, w http.ResponseWriter, req *http.Request, lastIndex int) (int, string) {
419+
defer stream.signal.Store(nil)
427420

428421
writes := 0
429422

@@ -437,7 +430,7 @@ func (t *StreamableServerTransport) streamResponse(stream *stream, w http.Respon
437430
}
438431
if _, err := writeEvent(w, e); err != nil {
439432
// Connection closed or broken.
440-
// TODO: log when we add server-side logging.
433+
// TODO(#170): log when we add server-side logging.
441434
return false
442435
}
443436
writes++
@@ -460,13 +453,12 @@ func (t *StreamableServerTransport) streamResponse(stream *stream, w http.Respon
460453
if errors.Is(err, ErrEventsPurged) {
461454
status = http.StatusInsufficientStorage
462455
}
463-
http.Error(w, err.Error(), status)
464-
return
456+
return status, err.Error()
465457
}
466458
// The iterator yields events beginning just after lastIndex, or it would have
467459
// yielded an error.
468460
if !write(data) {
469-
return
461+
return 0, ""
470462
}
471463
}
472464
}
@@ -481,11 +473,10 @@ stream:
481473

482474
for _, data := range outgoing {
483475
if err := t.opts.EventStore.Append(req.Context(), t.SessionID(), stream.id, data); err != nil {
484-
http.Error(w, err.Error(), http.StatusInternalServerError)
485-
return
476+
return http.StatusInternalServerError, err.Error()
486477
}
487478
if !write(data) {
488-
return
479+
return 0, ""
489480
}
490481
}
491482

@@ -495,22 +486,22 @@ stream:
495486
// If all requests have been handled and replied to, we should terminate this connection.
496487
// "After the JSON-RPC response has been sent, the server SHOULD close the SSE stream."
497488
// §6.4, https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server
498-
// TODO(jba): why not terminate regardless of http method?
489+
// TODO(jba,findleyr): why not terminate regardless of http method?
499490
if req.Method == http.MethodPost && nOutstanding == 0 {
500491
if writes == 0 {
501492
// Spec: If the server accepts the input, the server MUST return HTTP
502493
// status code 202 Accepted with no body.
503494
w.WriteHeader(http.StatusAccepted)
504495
}
505-
return
496+
return 0, ""
506497
}
507498

508499
select {
509-
case <-signal: // there are new outgoing messages
500+
case <-*stream.signal.Load(): // there are new outgoing messages
510501
// return to top of loop
511502
case <-t.done: // session is closed
512503
if writes == 0 {
513-
http.Error(w, "session terminated", http.StatusGone)
504+
return http.StatusGone, "session terminated"
514505
}
515506
break stream
516507
case <-req.Context().Done():
@@ -520,6 +511,7 @@ stream:
520511
break stream
521512
}
522513
}
514+
return 0, ""
523515
}
524516

525517
// Event IDs: encode both the logical connection ID and the index, as
@@ -625,10 +617,11 @@ func (t *StreamableServerTransport) Write(ctx context.Context, msg jsonrpc.Messa
625617
delete(stream.requests, replyTo)
626618
}
627619

628-
// Signal work.
629-
if stream.signal != nil {
620+
// Signal streamResponse that new work is available.
621+
signalp := stream.signal.Load()
622+
if signalp != nil {
630623
select {
631-
case stream.signal <- struct{}{}:
624+
case *signalp <- struct{}{}:
632625
default:
633626
}
634627
}

0 commit comments

Comments
 (0)