@@ -236,23 +236,26 @@ type stream struct {
236236 // ID 0 is used for messages that don't correlate with an incoming request.
237237 id StreamID
238238
239- // These mutable fields are protected by the mutex of the corresponding StreamableServerTransport.
239+ // signal is a 1-buffered channel, owned by an incoming HTTP request, that signals
240+ // that there are messages available to write into the HTTP response.
241+ // In addition, the presence of a channel guarantees that at most one HTTP response
242+ // can receive messages for a logical stream. After claiming the stream, incoming
243+ // requests should read from outgoing, to ensure that no new messages are missed.
244+ //
245+ // To simplify locking, signal is an atomic. We need an atomic.Pointer, because
246+ // you can't set an atomic.Value to nil.
247+ //
248+ // Lifecycle: each channel value persists for the duration of an HTTP POST or
249+ // GET request for the given streamID.
250+ signal atomic.Pointer [chan struct {}]
251+
252+ // The following mutable fields are protected by the mutex of the containing
253+ // StreamableServerTransport.
240254
241255 // outgoing is the list of outgoing messages, enqueued by server methods that
242256 // write notifications and responses, and dequeued by streamResponse.
243257 outgoing [][]byte
244258
245- // signal is a 1-buffered channel, owned by an
246- // incoming HTTP request, that signals that there are messages available to
247- // write into the HTTP response. This guarantees that at most one HTTP
248- // response can receive messages for a logical stream. After claiming
249- // the stream, incoming requests should read from outgoing, to ensure
250- // that no new messages are missed.
251- //
252- // Lifecycle: persists for the duration of an HTTP POST or GET
253- // request for the given streamID.
254- signal chan struct {}
255-
256259 // streamRequests is the set of unanswered incoming RPCs for the stream.
257260 //
258261 // Lifecycle: requests values persist as until the requests have been
@@ -268,6 +271,11 @@ func newStream(id StreamID) *stream {
268271 }
269272}
270273
274+ func signalChanPtr () * chan struct {} {
275+ c := make (chan struct {}, 1 )
276+ return & c
277+ }
278+
271279// A StreamID identifies a stream of SSE events. It is unique within the stream's
272280// [ServerSession].
273281type StreamID int64
@@ -340,17 +348,14 @@ func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Re
340348
341349 t .mu .Lock ()
342350 stream , ok := t .streams [id ]
351+ t .mu .Unlock ()
343352 if ! ok {
344- t .mu .Unlock ()
345353 return http .StatusBadRequest , "unknown stream"
346354 }
347- if stream .signal != nil {
348- t . mu . Unlock ()
355+ if ! stream .signal . CompareAndSwap ( nil , signalChanPtr ()) {
356+ // The CAS returned false, meaning that the comparison failed: stream.signal is not nil.
349357 return http .StatusBadRequest , "stream ID conflicts with ongoing stream"
350358 }
351- stream .signal = make (chan struct {}, 1 )
352- t .mu .Unlock ()
353-
354359 return t .streamResponse (stream , w , req , lastIdx )
355360}
356361
@@ -389,8 +394,8 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R
389394 t .requestStreams [reqID ] = stream .id
390395 stream .requests [reqID ] = struct {}{}
391396 }
392- stream .signal = make (chan struct {}, 1 )
393397 t .mu .Unlock ()
398+ stream .signal .Store (signalChanPtr ())
394399
395400 // Publish incoming messages.
396401 for _ , msg := range incoming {
@@ -405,18 +410,7 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R
405410
406411// lastIndex is the index of the last seen event if resuming, else -1.
407412func (t * StreamableServerTransport ) streamResponse (stream * stream , w http.ResponseWriter , req * http.Request , lastIndex int ) (int , string ) {
408- defer func () {
409- t .mu .Lock ()
410- stream .signal = nil
411- t .mu .Unlock ()
412- }()
413-
414- t .mu .Lock ()
415- // Although there is a gap in locking between when stream.signal is set and here,
416- // it cannot change, because it is changed only when non-nil, and it is only
417- // set to nil in the defer above.
418- signal := stream .signal
419- t .mu .Unlock ()
413+ defer stream .signal .Store (nil )
420414
421415 writes := 0
422416
@@ -497,7 +491,7 @@ stream:
497491 }
498492
499493 select {
500- case <- signal : // there are new outgoing messages
494+ case <- * stream . signal . Load () : // there are new outgoing messages
501495 // return to top of loop
502496 case <- t .done : // session is closed
503497 if writes == 0 {
@@ -617,10 +611,11 @@ func (t *StreamableServerTransport) Write(ctx context.Context, msg jsonrpc.Messa
617611 delete (stream .requests , replyTo )
618612 }
619613
620- // Signal work.
621- if stream .signal != nil {
614+ // Signal streamResponse that new work is available.
615+ signalp := stream .signal .Load ()
616+ if signalp != nil {
622617 select {
623- case stream . signal <- struct {}{}:
618+ case * signalp <- struct {}{}:
624619 default :
625620 }
626621 }
0 commit comments