@@ -204,11 +204,11 @@ type StreamableServerTransport struct {
204204 id string
205205 opts StreamableServerTransportOptions
206206 incoming chan jsonrpc.Message // messages from the client to the server
207+ done chan struct {}
207208
208209 mu sync.Mutex
209210 // Sessions are closed exactly once.
210211 isDone bool
211- done chan struct {}
212212
213213 // Sessions can have multiple logical connections, corresponding to HTTP
214214 // requests. Additionally, logical sessions may be resumed by subsequent HTTP
@@ -242,23 +242,26 @@ type stream struct {
242242 // ID 0 is used for messages that don't correlate with an incoming request.
243243 id StreamID
244244
245- // These mutable fields are protected by the mutex of the corresponding StreamableServerTransport.
245+ // signal is a 1-buffered channel, owned by an incoming HTTP request, that signals
246+ // that there are messages available to write into the HTTP response.
247+ // In addition, the presence of a channel guarantees that at most one HTTP response
248+ // can receive messages for a logical stream. After claiming the stream, incoming
249+ // requests should read from outgoing, to ensure that no new messages are missed.
250+ //
251+ // To simplify locking, signal is an atomic. We need an atomic.Pointer, because
252+ // you can't set an atomic.Value to nil.
253+ //
254+ // Lifecycle: each channel value persists for the duration of an HTTP POST or
255+ // GET request for the given streamID.
256+ signal atomic.Pointer [chan struct {}]
257+
258+ // The following mutable fields are protected by the mutex of the containing
259+ // StreamableServerTransport.
246260
247261 // outgoing is the list of outgoing messages, enqueued by server methods that
248262 // write notifications and responses, and dequeued by streamResponse.
249263 outgoing [][]byte
250264
251- // signal is a 1-buffered channel, owned by an
252- // incoming HTTP request, that signals that there are messages available to
253- // write into the HTTP response. This guarantees that at most one HTTP
254- // response can receive messages for a logical stream. After claiming
255- // the stream, incoming requests should read from outgoing, to ensure
256- // that no new messages are missed.
257- //
258- // Lifecycle: persists for the duration of an HTTP POST or GET
259- // request for the given streamID.
260- signal chan struct {}
261-
262265 // streamRequests is the set of unanswered incoming RPCs for the stream.
263266 //
264267 // Lifecycle: requests values persist until the requests have been
@@ -274,6 +277,11 @@ func newStream(id StreamID) *stream {
274277 }
275278}
276279
280+ func signalChanPtr () * chan struct {} {
281+ c := make (chan struct {}, 1 )
282+ return & c
283+ }
284+
277285// A StreamID identifies a stream of SSE events. It is unique within the stream's
278286// [ServerSession].
279287type StreamID int64
@@ -310,19 +318,25 @@ type idContextKey struct{}
310318
311319// ServeHTTP handles a single HTTP request for the session.
312320func (t * StreamableServerTransport ) ServeHTTP (w http.ResponseWriter , req * http.Request ) {
321+ status := 0
322+ message := ""
313323 switch req .Method {
314324 case http .MethodGet :
315- t .serveGET (w , req )
325+ status , message = t .serveGET (w , req )
316326 case http .MethodPost :
317- t .servePOST (w , req )
327+ status , message = t .servePOST (w , req )
318328 default :
319329 // Should not be reached, as this is checked in StreamableHTTPHandler.ServeHTTP.
320330 w .Header ().Set ("Allow" , "GET, POST" )
321- http .Error (w , "unsupported method" , http .StatusMethodNotAllowed )
331+ status = http .StatusMethodNotAllowed
332+ message = "unsupported method"
333+ }
334+ if status != 0 && status != http .StatusOK {
335+ http .Error (w , message , status )
322336 }
323337}
324338
325- func (t * StreamableServerTransport ) serveGET (w http.ResponseWriter , req * http.Request ) {
339+ func (t * StreamableServerTransport ) serveGET (w http.ResponseWriter , req * http.Request ) ( int , string ) {
326340 // connID 0 corresponds to the default GET request.
327341 id := StreamID (0 )
328342 // By default, we haven't seen a last index. Since indices start at 0, we represent
@@ -334,49 +348,39 @@ func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Re
334348 var ok bool
335349 id , lastIdx , ok = parseEventID (eid )
336350 if ! ok {
337- http .Error (w , fmt .Sprintf ("malformed Last-Event-ID %q" , eid ), http .StatusBadRequest )
338- return
351+ return http .StatusBadRequest , fmt .Sprintf ("malformed Last-Event-ID %q" , eid )
339352 }
340353 }
341354
342355 t .mu .Lock ()
343356 stream , ok := t .streams [id ]
357+ t .mu .Unlock ()
344358 if ! ok {
345- http .Error (w , "unknown stream" , http .StatusBadRequest )
346- t .mu .Unlock ()
347- return
359+ return http .StatusBadRequest , "unknown stream"
348360 }
349- if stream .signal != nil {
350- http .Error (w , "stream ID conflicts with ongoing stream" , http .StatusBadRequest )
351- t .mu .Unlock ()
352- return
361+ if ! stream .signal .CompareAndSwap (nil , signalChanPtr ()) {
362+ // The CAS returned false, meaning that the comparison failed: stream.signal is not nil.
363+ return http .StatusBadRequest , "stream ID conflicts with ongoing stream"
353364 }
354- stream .signal = make (chan struct {}, 1 )
355- t .mu .Unlock ()
356-
357- t .streamResponse (stream , w , req , lastIdx )
365+ return t .streamResponse (stream , w , req , lastIdx )
358366}
359367
360- func (t * StreamableServerTransport ) servePOST (w http.ResponseWriter , req * http.Request ) {
368+ func (t * StreamableServerTransport ) servePOST (w http.ResponseWriter , req * http.Request ) ( int , string ) {
361369 if len (req .Header .Values ("Last-Event-ID" )) > 0 {
362- http .Error (w , "can't send Last-Event-ID for POST request" , http .StatusBadRequest )
363- return
370+ return http .StatusBadRequest , "can't send Last-Event-ID for POST request"
364371 }
365372
366373 // Read incoming messages.
367374 body , err := io .ReadAll (req .Body )
368375 if err != nil {
369- http .Error (w , "failed to read body" , http .StatusBadRequest )
370- return
376+ return http .StatusBadRequest , "failed to read body"
371377 }
372378 if len (body ) == 0 {
373- http .Error (w , "POST requires a non-empty body" , http .StatusBadRequest )
374- return
379+ return http .StatusBadRequest , "POST requires a non-empty body"
375380 }
376381 incoming , _ , err := readBatch (body )
377382 if err != nil {
378- http .Error (w , fmt .Sprintf ("malformed payload: %v" , err ), http .StatusBadRequest )
379- return
383+ return http .StatusBadRequest , fmt .Sprintf ("malformed payload: %v" , err )
380384 }
381385 requests := make (map [jsonrpc.ID ]struct {})
382386 for _ , msg := range incoming {
@@ -396,8 +400,8 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R
396400 t .requestStreams [reqID ] = stream .id
397401 stream .requests [reqID ] = struct {}{}
398402 }
399- stream .signal = make (chan struct {}, 1 )
400403 t .mu .Unlock ()
404+ stream .signal .Store (signalChanPtr ())
401405
402406 // Publish incoming messages.
403407 for _ , msg := range incoming {
@@ -407,23 +411,12 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R
407411 // TODO(rfindley): consider optimizing for a single incoming request, by
408412 // responding with application/json when there is only a single message in
409413 // the response.
410- t .streamResponse (stream , w , req , - 1 )
414+ return t .streamResponse (stream , w , req , - 1 )
411415}
412416
413417// lastIndex is the index of the last seen event if resuming, else -1.
414- func (t * StreamableServerTransport ) streamResponse (stream * stream , w http.ResponseWriter , req * http.Request , lastIndex int ) {
415- defer func () {
416- t .mu .Lock ()
417- stream .signal = nil
418- t .mu .Unlock ()
419- }()
420-
421- t .mu .Lock ()
422- // Although there is a gap in locking between when stream.signal is set and here,
423- // it cannot change, because it is changed only when non-nil, and it is only
424- // set to nil in the defer above.
425- signal := stream .signal
426- t .mu .Unlock ()
418+ func (t * StreamableServerTransport ) streamResponse (stream * stream , w http.ResponseWriter , req * http.Request , lastIndex int ) (int , string ) {
419+ defer stream .signal .Store (nil )
427420
428421 writes := 0
429422
@@ -437,7 +430,7 @@ func (t *StreamableServerTransport) streamResponse(stream *stream, w http.Respon
437430 }
438431 if _ , err := writeEvent (w , e ); err != nil {
439432 // Connection closed or broken.
440- // TODO: log when we add server-side logging.
433+ // TODO(#170) : log when we add server-side logging.
441434 return false
442435 }
443436 writes ++
@@ -460,13 +453,12 @@ func (t *StreamableServerTransport) streamResponse(stream *stream, w http.Respon
460453 if errors .Is (err , ErrEventsPurged ) {
461454 status = http .StatusInsufficientStorage
462455 }
463- http .Error (w , err .Error (), status )
464- return
456+ return status , err .Error ()
465457 }
466458 // The iterator yields events beginning just after lastIndex, or it would have
467459 // yielded an error.
468460 if ! write (data ) {
469- return
461+ return 0 , ""
470462 }
471463 }
472464 }
@@ -481,11 +473,10 @@ stream:
481473
482474 for _ , data := range outgoing {
483475 if err := t .opts .EventStore .Append (req .Context (), t .SessionID (), stream .id , data ); err != nil {
484- http .Error (w , err .Error (), http .StatusInternalServerError )
485- return
476+ return http .StatusInternalServerError , err .Error ()
486477 }
487478 if ! write (data ) {
488- return
479+ return 0 , ""
489480 }
490481 }
491482
@@ -495,22 +486,22 @@ stream:
495486 // If all requests have been handled and replied to, we should terminate this connection.
496487 // "After the JSON-RPC response has been sent, the server SHOULD close the SSE stream."
497488 // §6.4, https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server
498- // TODO(jba): why not terminate regardless of http method?
489+ // TODO(jba,findleyr ): why not terminate regardless of http method?
499490 if req .Method == http .MethodPost && nOutstanding == 0 {
500491 if writes == 0 {
501492 // Spec: If the server accepts the input, the server MUST return HTTP
502493 // status code 202 Accepted with no body.
503494 w .WriteHeader (http .StatusAccepted )
504495 }
505- return
496+ return 0 , ""
506497 }
507498
508499 select {
509- case <- signal : // there are new outgoing messages
500+ case <- * stream . signal . Load () : // there are new outgoing messages
510501 // return to top of loop
511502 case <- t .done : // session is closed
512503 if writes == 0 {
513- http .Error ( w , "session terminated" , http . StatusGone )
504+ return http .StatusGone , "session terminated"
514505 }
515506 break stream
516507 case <- req .Context ().Done ():
@@ -520,6 +511,7 @@ stream:
520511 break stream
521512 }
522513 }
514+ return 0 , ""
523515}
524516
525517// Event IDs: encode both the logical connection ID and the index, as
@@ -625,10 +617,11 @@ func (t *StreamableServerTransport) Write(ctx context.Context, msg jsonrpc.Messa
625617 delete (stream .requests , replyTo )
626618 }
627619
628- // Signal work.
629- if stream .signal != nil {
620+ // Signal streamResponse that new work is available.
621+ signalp := stream .signal .Load ()
622+ if signalp != nil {
630623 select {
631- case stream . signal <- struct {}{}:
624+ case * signalp <- struct {}{}:
632625 default :
633626 }
634627 }
0 commit comments