@@ -491,7 +491,7 @@ type stream struct {
491491 // If non-nil, deliver writes data directly to the HTTP response.
492492 //
493493 // Only one HTTP response may receive messages at a given time. An active
494- // HTTP connection acquires ownership of the stream by setting
494+ // HTTP connection acquires ownership of the stream by setting this field.
495495 deliver func (data []byte , final bool ) error
496496
497497 // streamRequests is the set of unanswered incoming requests for the stream.
@@ -500,10 +500,10 @@ type stream struct {
500500 requests map [jsonrpc.ID ]struct {}
501501}
502502
503- // done reports whether the stream is logically complete.
503+ // doneLocked reports whether the stream is logically complete.
504504//
505505// s.mu must be held while calling this function.
506- func (s * stream ) done () bool {
506+ func (s * stream ) doneLocked () bool {
507507 return len (s .requests ) == 0 && s .id != ""
508508}
509509
@@ -585,42 +585,73 @@ func (c *streamableServerConn) serveGET(w http.ResponseWriter, req *http.Request
585585 }
586586 }
587587
588+ ctx , cancel := context .WithCancel (req .Context ())
589+ defer cancel ()
590+
591+ stream , done := c .acquireStream (ctx , w , streamID , & lastIdx )
592+ if stream == nil {
593+ return
594+ }
595+ // Release the stream when we're done.
596+ defer func () {
597+ stream .mu .Lock ()
598+ stream .deliver = nil
599+ stream .mu .Unlock ()
600+ }()
601+
602+ select {
603+ case <- ctx .Done ():
604+ // request cancelled
605+ case <- done :
606+ // request complete
607+ case <- c .done :
608+ // session closed
609+ }
610+ }
611+
612+ // writeEvent writes an SSE event to w corresponding to the given stream, data, and index.
613+ // lastIdx is incremented before writing, so that it continues to point to the index of the
614+ // last event written to the stream.
615+ func (c * streamableServerConn ) writeEvent (w http.ResponseWriter , stream * stream , data []byte , lastIdx * int ) error {
616+ * lastIdx ++
617+ e := Event {
618+ Name : "message" ,
619+ Data : data ,
620+ }
621+ if c .eventStore != nil {
622+ e .ID = formatEventID (stream .id , * lastIdx )
623+ }
624+ if _ , err := writeEvent (w , e ); err != nil {
625+ return err
626+ }
627+ return nil
628+ }
629+
630+ // acquireStream acquires the stream and replays all events since lastIdx, if
631+ // any, updating lastIdx accordingly. If non-nil, the resulting stream will be
632+ // registered for receiving new messages, and the resulting done channel will
633+ // be closed when all related messages have been delivered.
634+ //
635+ // If any errors occur, they will be written to w and the resulting stream will
636+ // be nil. The resulting stream may also be nil if the stream is complete.
637+ //
638+ // Importantly, this function must hold the stream mutex until done replaying
639+ // all messages, so that no delivery or storage of new messages occurs while
640+ // the stream is still replaying.
641+ func (c * streamableServerConn ) acquireStream (ctx context.Context , w http.ResponseWriter , streamID string , lastIdx * int ) (* stream , chan struct {}) {
588642 c .mu .Lock ()
589643 stream , ok := c .streams [streamID ]
590644 c .mu .Unlock ()
591645 if ! ok {
592646 http .Error (w , "unknown stream" , http .StatusBadRequest )
593- return
647+ return nil , nil
594648 }
595649
596- write := func (data []byte ) error {
597- lastIdx ++
598- e := Event {
599- Name : "message" ,
600- Data : data ,
601- }
602- if c .eventStore != nil {
603- e .ID = formatEventID (stream .id , lastIdx )
604- }
605- if _ , err := writeEvent (w , e ); err != nil {
606- return err
607- }
608- return nil
609- }
610-
611- ctx , cancel := context .WithCancel (req .Context ())
612- defer cancel ()
613-
614- // Acquire the stream.
615- //
616- // Importantly, hold the mutex until we've replayed all messages, since we
617- // don't want to allow delivery (or storage) of new messages until we've
618- // replayed everything that has been stored thus far.
619650 stream .mu .Lock ()
651+ defer stream .mu .Unlock ()
620652 if stream .deliver != nil {
621- stream .mu .Unlock ()
622653 http .Error (w , "stream ID conflicts with ongoing stream" , http .StatusConflict )
623- return
654+ return nil , nil
624655 }
625656
626657 // Collect events to replay. Collect them all before writing, so that we
@@ -631,9 +662,8 @@ func (c *streamableServerConn) serveGET(w http.ResponseWriter, req *http.Request
631662 // messages, and registered our delivery function.
632663 var toReplay [][]byte
633664 if c .eventStore != nil {
634- for data , err := range c .eventStore .After (ctx , c .SessionID (), stream .id , lastIdx ) {
665+ for data , err := range c .eventStore .After (ctx , c .SessionID (), stream .id , * lastIdx ) {
635666 if err != nil {
636- stream .mu .Unlock ()
637667 // We can't replay events, perhaps because the underlying event store
638668 // has garbage collected its storage.
639669 //
@@ -643,7 +673,7 @@ func (c *streamableServerConn) serveGET(w http.ResponseWriter, req *http.Request
643673 // 400 is not really accurate, but should at least have no side effects.
644674 // Other SDKs (typescript) do not have a mechanism for events to be purged.
645675 http .Error (w , "failed to replay events" , http .StatusBadRequest )
646- return
676+ return nil , nil
647677 }
648678 toReplay = append (toReplay , data )
649679 }
@@ -656,27 +686,21 @@ func (c *streamableServerConn) serveGET(w http.ResponseWriter, req *http.Request
656686 if stream .id == "" {
657687 // Issue #410: the standalone SSE stream is likely not to receive messages
658688 // for a long time. Ensure that headers are flushed.
659- //
660- // For other requests, delay the writing of the header in case we
661- // may want to set an error status.
662- // (see the TODO: this probably isn't worth it).
663689 w .WriteHeader (http .StatusOK )
664690 if f , ok := w .(http.Flusher ); ok {
665691 f .Flush ()
666692 }
667693 }
668694
669695 for _ , data := range toReplay {
670- if err := write (data ); err != nil {
671- stream .mu .Unlock ()
672- return
696+ if err := c .writeEvent (w , stream , data , lastIdx ); err != nil {
697+ return nil , nil
673698 }
674699 }
675700
676- if stream .done () {
701+ if stream .doneLocked () {
677702 // Nothing more to do.
678- stream .mu .Unlock ()
679- return
703+ return nil , nil
680704 }
681705
682706 // Finally register a delivery function and unlock the stream, allowing the
@@ -686,29 +710,13 @@ func (c *streamableServerConn) serveGET(w http.ResponseWriter, req *http.Request
686710 if err := ctx .Err (); err != nil {
687711 return err
688712 }
689- err := write ( data )
713+ err := c . writeEvent ( w , stream , data , lastIdx )
690714 if final {
691715 close (done )
692716 }
693717 return err
694718 }
695- stream .mu .Unlock ()
696-
697- // Release the stream when we're done.
698- defer func () {
699- stream .mu .Lock ()
700- stream .deliver = nil
701- stream .mu .Unlock ()
702- }()
703-
704- select {
705- case <- ctx .Done ():
706- // request cancelled
707- case <- done :
708- // request complete
709- case <- c .done :
710- // session closed
711- }
719+ return stream , done
712720}
713721
714722// servePOST handles an incoming message, and replies with either an outgoing
@@ -854,23 +862,12 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
854862 }
855863 } else {
856864 // Write events in the order we receive them.
857- lastIndex := 0
865+ lastIndex := - 1
858866 stream .deliver = func (data []byte , final bool ) error {
859867 if final {
860868 defer close (done )
861869 }
862- e := Event {
863- Name : "message" ,
864- Data : data ,
865- }
866- if c .eventStore != nil {
867- e .ID = formatEventID (stream .id , lastIndex )
868- }
869- lastIndex ++
870- if _ , err := writeEvent (w , e ); err != nil {
871- return err
872- }
873- return nil
870+ return c .writeEvent (w , stream , data , & lastIndex )
874871 }
875872 }
876873
@@ -1024,7 +1021,7 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e
10241021
10251022 s .mu .Lock ()
10261023 defer s .mu .Unlock ()
1027- if s .done () {
1024+ if s .doneLocked () {
10281025 return fmt .Errorf ("%w: write to closed stream" , jsonrpc2 .ErrRejected )
10291026 }
10301027 if responseTo .IsValid () {
@@ -1040,7 +1037,7 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e
10401037 }
10411038 }
10421039 if s .deliver != nil {
1043- if err := s .deliver (data , s .done ()); err != nil {
1040+ if err := s .deliver (data , s .doneLocked ()); err != nil {
10441041 // TODO: report a side-channel error.
10451042 } else {
10461043 delivered = true
0 commit comments