Skip to content

Commit d145124

Browse files
committed
address review comments
1 parent e16bd09 commit d145124

File tree

1 file changed

+72
-75
lines changed

1 file changed

+72
-75
lines changed

mcp/streamable.go

Lines changed: 72 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,7 @@ type stream struct {
491491
// If non-nil, deliver writes data directly to the HTTP response.
492492
//
493493
// Only one HTTP response may receive messages at a given time. An active
494-
// HTTP connection acquires ownership of the stream by setting
494+
// HTTP connection acquires ownership of the stream by setting this field.
495495
deliver func(data []byte, final bool) error
496496

497497
// streamRequests is the set of unanswered incoming requests for the stream.
@@ -500,10 +500,10 @@ type stream struct {
500500
requests map[jsonrpc.ID]struct{}
501501
}
502502

503-
// done reports whether the stream is logically complete.
503+
// doneLocked reports whether the stream is logically complete.
504504
//
505505
// s.mu must be held while calling this function.
506-
func (s *stream) done() bool {
506+
func (s *stream) doneLocked() bool {
507507
return len(s.requests) == 0 && s.id != ""
508508
}
509509

@@ -585,42 +585,73 @@ func (c *streamableServerConn) serveGET(w http.ResponseWriter, req *http.Request
585585
}
586586
}
587587

588+
ctx, cancel := context.WithCancel(req.Context())
589+
defer cancel()
590+
591+
stream, done := c.acquireStream(ctx, w, streamID, &lastIdx)
592+
if stream == nil {
593+
return
594+
}
595+
// Release the stream when we're done.
596+
defer func() {
597+
stream.mu.Lock()
598+
stream.deliver = nil
599+
stream.mu.Unlock()
600+
}()
601+
602+
select {
603+
case <-ctx.Done():
604+
// request cancelled
605+
case <-done:
606+
// request complete
607+
case <-c.done:
608+
// session closed
609+
}
610+
}
611+
612+
// writeEvent writes an SSE event to w corresponding to the given stream, data, and index.
613+
// lastIdx is incremented before writing, so that it continues to point to the index of the
614+
// last event written to the stream.
615+
func (c *streamableServerConn) writeEvent(w http.ResponseWriter, stream *stream, data []byte, lastIdx *int) error {
616+
*lastIdx++
617+
e := Event{
618+
Name: "message",
619+
Data: data,
620+
}
621+
if c.eventStore != nil {
622+
e.ID = formatEventID(stream.id, *lastIdx)
623+
}
624+
if _, err := writeEvent(w, e); err != nil {
625+
return err
626+
}
627+
return nil
628+
}
629+
630+
// acquireStream acquires the stream and replays all events since lastIdx, if
631+
// any, updating lastIdx accordingly. If non-nil, the resulting stream will be
632+
// registered for receiving new messages, and the resulting done channel will
633+
// be closed when all related messages have been delivered.
634+
//
635+
// If any errors occur, they will be written to w and the resulting stream will
636+
// be nil. The resulting stream may also be nil if the stream is complete.
637+
//
638+
// Importantly, this function must hold the stream mutex until done replaying
639+
// all messages, so that no delivery or storage of new messages occurs while
640+
// the stream is still replaying.
641+
func (c *streamableServerConn) acquireStream(ctx context.Context, w http.ResponseWriter, streamID string, lastIdx *int) (*stream, chan struct{}) {
588642
c.mu.Lock()
589643
stream, ok := c.streams[streamID]
590644
c.mu.Unlock()
591645
if !ok {
592646
http.Error(w, "unknown stream", http.StatusBadRequest)
593-
return
647+
return nil, nil
594648
}
595649

596-
write := func(data []byte) error {
597-
lastIdx++
598-
e := Event{
599-
Name: "message",
600-
Data: data,
601-
}
602-
if c.eventStore != nil {
603-
e.ID = formatEventID(stream.id, lastIdx)
604-
}
605-
if _, err := writeEvent(w, e); err != nil {
606-
return err
607-
}
608-
return nil
609-
}
610-
611-
ctx, cancel := context.WithCancel(req.Context())
612-
defer cancel()
613-
614-
// Acquire the stream.
615-
//
616-
// Importantly, hold the mutex until we've replayed all messages, since we
617-
// don't want to allow delivery (or storage) of new messages until we've
618-
// replayed everything that has been stored thus far.
619650
stream.mu.Lock()
651+
defer stream.mu.Unlock()
620652
if stream.deliver != nil {
621-
stream.mu.Unlock()
622653
http.Error(w, "stream ID conflicts with ongoing stream", http.StatusConflict)
623-
return
654+
return nil, nil
624655
}
625656

626657
// Collect events to replay. Collect them all before writing, so that we
@@ -631,9 +662,8 @@ func (c *streamableServerConn) serveGET(w http.ResponseWriter, req *http.Request
631662
// messages, and registered our delivery function.
632663
var toReplay [][]byte
633664
if c.eventStore != nil {
634-
for data, err := range c.eventStore.After(ctx, c.SessionID(), stream.id, lastIdx) {
665+
for data, err := range c.eventStore.After(ctx, c.SessionID(), stream.id, *lastIdx) {
635666
if err != nil {
636-
stream.mu.Unlock()
637667
// We can't replay events, perhaps because the underlying event store
638668
// has garbage collected its storage.
639669
//
@@ -643,7 +673,7 @@ func (c *streamableServerConn) serveGET(w http.ResponseWriter, req *http.Request
643673
// 400 is not really accurate, but should at least have no side effects.
644674
// Other SDKs (typescript) do not have a mechanism for events to be purged.
645675
http.Error(w, "failed to replay events", http.StatusBadRequest)
646-
return
676+
return nil, nil
647677
}
648678
toReplay = append(toReplay, data)
649679
}
@@ -656,27 +686,21 @@ func (c *streamableServerConn) serveGET(w http.ResponseWriter, req *http.Request
656686
if stream.id == "" {
657687
// Issue #410: the standalone SSE stream is likely not to receive messages
658688
// for a long time. Ensure that headers are flushed.
659-
//
660-
// For other requests, delay the writing of the header in case we
661-
// may want to set an error status.
662-
// (see the TODO: this probably isn't worth it).
663689
w.WriteHeader(http.StatusOK)
664690
if f, ok := w.(http.Flusher); ok {
665691
f.Flush()
666692
}
667693
}
668694

669695
for _, data := range toReplay {
670-
if err := write(data); err != nil {
671-
stream.mu.Unlock()
672-
return
696+
if err := c.writeEvent(w, stream, data, lastIdx); err != nil {
697+
return nil, nil
673698
}
674699
}
675700

676-
if stream.done() {
701+
if stream.doneLocked() {
677702
// Nothing more to do.
678-
stream.mu.Unlock()
679-
return
703+
return nil, nil
680704
}
681705

682706
// Finally register a delivery function and unlock the stream, allowing the
@@ -686,29 +710,13 @@ func (c *streamableServerConn) serveGET(w http.ResponseWriter, req *http.Request
686710
if err := ctx.Err(); err != nil {
687711
return err
688712
}
689-
err := write(data)
713+
err := c.writeEvent(w, stream, data, lastIdx)
690714
if final {
691715
close(done)
692716
}
693717
return err
694718
}
695-
stream.mu.Unlock()
696-
697-
// Release the stream when we're done.
698-
defer func() {
699-
stream.mu.Lock()
700-
stream.deliver = nil
701-
stream.mu.Unlock()
702-
}()
703-
704-
select {
705-
case <-ctx.Done():
706-
// request cancelled
707-
case <-done:
708-
// request complete
709-
case <-c.done:
710-
// session closed
711-
}
719+
return stream, done
712720
}
713721

714722
// servePOST handles an incoming message, and replies with either an outgoing
@@ -854,23 +862,12 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
854862
}
855863
} else {
856864
// Write events in the order we receive them.
857-
lastIndex := 0
865+
lastIndex := -1
858866
stream.deliver = func(data []byte, final bool) error {
859867
if final {
860868
defer close(done)
861869
}
862-
e := Event{
863-
Name: "message",
864-
Data: data,
865-
}
866-
if c.eventStore != nil {
867-
e.ID = formatEventID(stream.id, lastIndex)
868-
}
869-
lastIndex++
870-
if _, err := writeEvent(w, e); err != nil {
871-
return err
872-
}
873-
return nil
870+
return c.writeEvent(w, stream, data, &lastIndex)
874871
}
875872
}
876873

@@ -1024,7 +1021,7 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e
10241021

10251022
s.mu.Lock()
10261023
defer s.mu.Unlock()
1027-
if s.done() {
1024+
if s.doneLocked() {
10281025
return fmt.Errorf("%w: write to closed stream", jsonrpc2.ErrRejected)
10291026
}
10301027
if responseTo.IsValid() {
@@ -1040,7 +1037,7 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e
10401037
}
10411038
}
10421039
if s.deliver != nil {
1043-
if err := s.deliver(data, s.done()); err != nil {
1040+
if err := s.deliver(data, s.doneLocked()); err != nil {
10441041
// TODO: report a side-channel error.
10451042
} else {
10461043
delivered = true

0 commit comments

Comments
 (0)