Skip to content

Commit 82965b0

Browse files
committed
mcp: various cleanups to streamable code
- Move the done field above mu: it does not need the mutex to be held. - Use an atomic for signal, simplifying locking. - Return from servePOST and serveGET instead of calling http.Error. Each cleanup is in a separate commit.
1 parent 5ac2016 commit 82965b0

File tree

1 file changed

+30
-35
lines changed

1 file changed

+30
-35
lines changed

mcp/streamable.go

Lines changed: 30 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -236,23 +236,26 @@ type stream struct {
236236
// ID 0 is used for messages that don't correlate with an incoming request.
237237
id StreamID
238238

239-
// These mutable fields are protected by the mutex of the corresponding StreamableServerTransport.
239+
// signal is a 1-buffered channel, owned by an incoming HTTP request, that signals
240+
// that there are messages available to write into the HTTP response.
241+
// In addition, the presence of a channel guarantees that at most one HTTP response
242+
// can receive messages for a logical stream. After claiming the stream, incoming
243+
// requests should read from outgoing, to ensure that no new messages are missed.
244+
//
245+
// To simplify locking, signal is an atomic. We need an atomic.Pointer, because
246+
// you can't set an atomic.Value to nil.
247+
//
248+
// Lifecycle: each channel value persists for the duration of an HTTP POST or
249+
// GET request for the given streamID.
250+
signal atomic.Pointer[chan struct{}]
251+
252+
// The following mutable fields are protected by the mutex of the containing
253+
// StreamableServerTransport.
240254

241255
// outgoing is the list of outgoing messages, enqueued by server methods that
242256
// write notifications and responses, and dequeued by streamResponse.
243257
outgoing [][]byte
244258

245-
// signal is a 1-buffered channel, owned by an
246-
// incoming HTTP request, that signals that there are messages available to
247-
// write into the HTTP response. This guarantees that at most one HTTP
248-
// response can receive messages for a logical stream. After claiming
249-
// the stream, incoming requests should read from outgoing, to ensure
250-
// that no new messages are missed.
251-
//
252-
// Lifecycle: persists for the duration of an HTTP POST or GET
253-
// request for the given streamID.
254-
signal chan struct{}
255-
256259
// streamRequests is the set of unanswered incoming RPCs for the stream.
257260
//
258261
// Lifecycle: requests values persist as until the requests have been
@@ -268,6 +271,11 @@ func newStream(id StreamID) *stream {
268271
}
269272
}
270273

274+
func signalChanPtr() *chan struct{} {
275+
c := make(chan struct{}, 1)
276+
return &c
277+
}
278+
271279
// A StreamID identifies a stream of SSE events. It is unique within the stream's
272280
// [ServerSession].
273281
type StreamID int64
@@ -340,17 +348,14 @@ func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Re
340348

341349
t.mu.Lock()
342350
stream, ok := t.streams[id]
351+
t.mu.Unlock()
343352
if !ok {
344-
t.mu.Unlock()
345353
return http.StatusBadRequest, "unknown stream"
346354
}
347-
if stream.signal != nil {
348-
t.mu.Unlock()
355+
if !stream.signal.CompareAndSwap(nil, signalChanPtr()) {
356+
// The CAS returned false, meaning that the comparison failed: stream.signal is not nil.
349357
return http.StatusBadRequest, "stream ID conflicts with ongoing stream"
350358
}
351-
stream.signal = make(chan struct{}, 1)
352-
t.mu.Unlock()
353-
354359
return t.streamResponse(stream, w, req, lastIdx)
355360
}
356361

@@ -389,8 +394,8 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R
389394
t.requestStreams[reqID] = stream.id
390395
stream.requests[reqID] = struct{}{}
391396
}
392-
stream.signal = make(chan struct{}, 1)
393397
t.mu.Unlock()
398+
stream.signal.Store(signalChanPtr())
394399

395400
// Publish incoming messages.
396401
for _, msg := range incoming {
@@ -405,18 +410,7 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R
405410

406411
// lastIndex is the index of the last seen event if resuming, else -1.
407412
func (t *StreamableServerTransport) streamResponse(stream *stream, w http.ResponseWriter, req *http.Request, lastIndex int) (int, string) {
408-
defer func() {
409-
t.mu.Lock()
410-
stream.signal = nil
411-
t.mu.Unlock()
412-
}()
413-
414-
t.mu.Lock()
415-
// Although there is a gap in locking between when stream.signal is set and here,
416-
// it cannot change, because it is changed only when non-nil, and it is only
417-
// set to nil in the defer above.
418-
signal := stream.signal
419-
t.mu.Unlock()
413+
defer stream.signal.Store(nil)
420414

421415
writes := 0
422416

@@ -497,7 +491,7 @@ stream:
497491
}
498492

499493
select {
500-
case <-signal: // there are new outgoing messages
494+
case <-*stream.signal.Load(): // there are new outgoing messages
501495
// return to top of loop
502496
case <-t.done: // session is closed
503497
if writes == 0 {
@@ -617,10 +611,11 @@ func (t *StreamableServerTransport) Write(ctx context.Context, msg jsonrpc.Messa
617611
delete(stream.requests, replyTo)
618612
}
619613

620-
// Signal work.
621-
if stream.signal != nil {
614+
// Signal streamResponse that new work is available.
615+
signalp := stream.signal.Load()
616+
if signalp != nil {
622617
select {
623-
case stream.signal <- struct{}{}:
618+
case *signalp <- struct{}{}:
624619
default:
625620
}
626621
}

0 commit comments

Comments
 (0)