@@ -242,23 +242,26 @@ type stream struct {
242242 // ID 0 is used for messages that don't correlate with an incoming request.
243243 id StreamID
244244
245- // These mutable fields are protected by the mutex of the corresponding StreamableServerTransport.
245+ // signal is a 1-buffered channel, owned by an incoming HTTP request, that signals
246+ // that there are messages available to write into the HTTP response.
247+ // In addition, the presence of a channel guarantees that at most one HTTP response
248+ // can receive messages for a logical stream. After claiming the stream, incoming
249+ // requests should read from outgoing, to ensure that no new messages are missed.
250+ //
251+ // To simplify locking, signal is an atomic. We need an atomic.Pointer, because
252+ // you can't set an atomic.Value to nil.
253+ //
254+ // Lifecycle: each channel value persists for the duration of an HTTP POST or
255+ // GET request for the given streamID.
256+ signal atomic.Pointer [chan struct {}]
257+
258+ // The following mutable fields are protected by the mutex of the containing
259+ // StreamableServerTransport.
246260
247261 // outgoing is the list of outgoing messages, enqueued by server methods that
248262 // write notifications and responses, and dequeued by streamResponse.
249263 outgoing [][]byte
250264
251- // signal is a 1-buffered channel, owned by an
252- // incoming HTTP request, that signals that there are messages available to
253- // write into the HTTP response. This guarantees that at most one HTTP
254- // response can receive messages for a logical stream. After claiming
255- // the stream, incoming requests should read from outgoing, to ensure
256- // that no new messages are missed.
257- //
258- // Lifecycle: persists for the duration of an HTTP POST or GET
259- // request for the given streamID.
260- signal chan struct {}
261-
262265 // streamRequests is the set of unanswered incoming RPCs for the stream.
263266 //
264267 // Lifecycle: requests values persist until the requests have been
@@ -274,6 +277,11 @@ func newStream(id StreamID) *stream {
274277 }
275278}
276279
280+ func signalChanPtr () * chan struct {} {
281+ c := make (chan struct {}, 1 )
282+ return & c
283+ }
284+
277285// A StreamID identifies a stream of SSE events. It is unique within the stream's
278286// [ServerSession].
279287type StreamID int64
@@ -346,17 +354,14 @@ func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Re
346354
347355 t .mu .Lock ()
348356 stream , ok := t .streams [id ]
357+ t .mu .Unlock ()
349358 if ! ok {
350- t .mu .Unlock ()
351359 return http .StatusBadRequest , "unknown stream"
352360 }
353- if stream .signal != nil {
354- t . mu . Unlock ()
361+ if ! stream .signal . CompareAndSwap ( nil , signalChanPtr ()) {
362+ // The CAS returned false, meaning that the comparison failed: stream.signal is not nil.
355363 return http .StatusBadRequest , "stream ID conflicts with ongoing stream"
356364 }
357- stream .signal = make (chan struct {}, 1 )
358- t .mu .Unlock ()
359-
360365 return t .streamResponse (stream , w , req , lastIdx )
361366}
362367
@@ -395,8 +400,8 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R
395400 t .requestStreams [reqID ] = stream .id
396401 stream .requests [reqID ] = struct {}{}
397402 }
398- stream .signal = make (chan struct {}, 1 )
399403 t .mu .Unlock ()
404+ stream .signal .Store (signalChanPtr ())
400405
401406 // Publish incoming messages.
402407 for _ , msg := range incoming {
@@ -411,18 +416,7 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R
411416
412417// lastIndex is the index of the last seen event if resuming, else -1.
413418func (t * StreamableServerTransport ) streamResponse (stream * stream , w http.ResponseWriter , req * http.Request , lastIndex int ) (int , string ) {
414- defer func () {
415- t .mu .Lock ()
416- stream .signal = nil
417- t .mu .Unlock ()
418- }()
419-
420- t .mu .Lock ()
421- // Although there is a gap in locking between when stream.signal is set and here,
422- // it cannot change, because it is changed only when non-nil, and it is only
423- // set to nil in the defer above.
424- signal := stream .signal
425- t .mu .Unlock ()
419+ defer stream .signal .Store (nil )
426420
427421 writes := 0
428422
@@ -503,7 +497,7 @@ stream:
503497 }
504498
505499 select {
506- case <- signal : // there are new outgoing messages
500+ case <- * stream . signal . Load () : // there are new outgoing messages
507501 // return to top of loop
508502 case <- t .done : // session is closed
509503 if writes == 0 {
@@ -623,10 +617,11 @@ func (t *StreamableServerTransport) Write(ctx context.Context, msg jsonrpc.Messa
623617 delete (stream .requests , replyTo )
624618 }
625619
626- // Signal work.
627- if stream .signal != nil {
620+ // Signal streamResponse that new work is available.
621+ signalp := stream .signal .Load ()
622+ if signalp != nil {
628623 select {
629- case stream . signal <- struct {}{}:
624+ case * signalp <- struct {}{}:
630625 default :
631626 }
632627 }
0 commit comments