Skip to content

Commit bad73bb

Browse files
committed
mcp: various cleanups to streamable code
- Move the done field above mu: it does not need the mutex to be held. - Use an atomic for signal, simplifying locking. - Return from servePOST and serveGET instead of calling http.Error. Each cleanup is in a separate commit.
1 parent bbfcdd1 commit bad73bb

File tree

1 file changed

+30
-35
lines changed

1 file changed

+30
-35
lines changed

mcp/streamable.go

Lines changed: 30 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -242,23 +242,26 @@ type stream struct {
242242
// ID 0 is used for messages that don't correlate with an incoming request.
243243
id StreamID
244244

245-
// These mutable fields are protected by the mutex of the corresponding StreamableServerTransport.
245+
// signal is a 1-buffered channel, owned by an incoming HTTP request, that signals
246+
// that there are messages available to write into the HTTP response.
247+
// In addition, the presence of a channel guarantees that at most one HTTP response
248+
// can receive messages for a logical stream. After claiming the stream, incoming
249+
// requests should read from outgoing, to ensure that no new messages are missed.
250+
//
251+
// To simplify locking, signal is an atomic. We need an atomic.Pointer, because
252+
// you can't set an atomic.Value to nil.
253+
//
254+
// Lifecycle: each channel value persists for the duration of an HTTP POST or
255+
// GET request for the given streamID.
256+
signal atomic.Pointer[chan struct{}]
257+
258+
// The following mutable fields are protected by the mutex of the containing
259+
// StreamableServerTransport.
246260

247261
// outgoing is the list of outgoing messages, enqueued by server methods that
248262
// write notifications and responses, and dequeued by streamResponse.
249263
outgoing [][]byte
250264

251-
// signal is a 1-buffered channel, owned by an
252-
// incoming HTTP request, that signals that there are messages available to
253-
// write into the HTTP response. This guarantees that at most one HTTP
254-
// response can receive messages for a logical stream. After claiming
255-
// the stream, incoming requests should read from outgoing, to ensure
256-
// that no new messages are missed.
257-
//
258-
// Lifecycle: persists for the duration of an HTTP POST or GET
259-
// request for the given streamID.
260-
signal chan struct{}
261-
262265
// streamRequests is the set of unanswered incoming RPCs for the stream.
263266
//
264267
// Lifecycle: requests values persist until the requests have been
@@ -274,6 +277,11 @@ func newStream(id StreamID) *stream {
274277
}
275278
}
276279

280+
func signalChanPtr() *chan struct{} {
281+
c := make(chan struct{}, 1)
282+
return &c
283+
}
284+
277285
// A StreamID identifies a stream of SSE events. It is unique within the stream's
278286
// [ServerSession].
279287
type StreamID int64
@@ -346,17 +354,14 @@ func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Re
346354

347355
t.mu.Lock()
348356
stream, ok := t.streams[id]
357+
t.mu.Unlock()
349358
if !ok {
350-
t.mu.Unlock()
351359
return http.StatusBadRequest, "unknown stream"
352360
}
353-
if stream.signal != nil {
354-
t.mu.Unlock()
361+
if !stream.signal.CompareAndSwap(nil, signalChanPtr()) {
362+
// The CAS returned false, meaning that the comparison failed: stream.signal is not nil.
355363
return http.StatusBadRequest, "stream ID conflicts with ongoing stream"
356364
}
357-
stream.signal = make(chan struct{}, 1)
358-
t.mu.Unlock()
359-
360365
return t.streamResponse(stream, w, req, lastIdx)
361366
}
362367

@@ -395,8 +400,8 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R
395400
t.requestStreams[reqID] = stream.id
396401
stream.requests[reqID] = struct{}{}
397402
}
398-
stream.signal = make(chan struct{}, 1)
399403
t.mu.Unlock()
404+
stream.signal.Store(signalChanPtr())
400405

401406
// Publish incoming messages.
402407
for _, msg := range incoming {
@@ -411,18 +416,7 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R
411416

412417
// lastIndex is the index of the last seen event if resuming, else -1.
413418
func (t *StreamableServerTransport) streamResponse(stream *stream, w http.ResponseWriter, req *http.Request, lastIndex int) (int, string) {
414-
defer func() {
415-
t.mu.Lock()
416-
stream.signal = nil
417-
t.mu.Unlock()
418-
}()
419-
420-
t.mu.Lock()
421-
// Although there is a gap in locking between when stream.signal is set and here,
422-
// it cannot change, because it is changed only when non-nil, and it is only
423-
// set to nil in the defer above.
424-
signal := stream.signal
425-
t.mu.Unlock()
419+
defer stream.signal.Store(nil)
426420

427421
writes := 0
428422

@@ -503,7 +497,7 @@ stream:
503497
}
504498

505499
select {
506-
case <-signal: // there are new outgoing messages
500+
case <-*stream.signal.Load(): // there are new outgoing messages
507501
// return to top of loop
508502
case <-t.done: // session is closed
509503
if writes == 0 {
@@ -623,10 +617,11 @@ func (t *StreamableServerTransport) Write(ctx context.Context, msg jsonrpc.Messa
623617
delete(stream.requests, replyTo)
624618
}
625619

626-
// Signal work.
627-
if stream.signal != nil {
620+
// Signal streamResponse that new work is available.
621+
signalp := stream.signal.Load()
622+
if signalp != nil {
628623
select {
629-
case stream.signal <- struct{}{}:
624+
case *signalp <- struct{}{}:
630625
default:
631626
}
632627
}

0 commit comments

Comments
 (0)