Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mcp/event.go
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ func (s *MemoryEventStore) debugString() string {
fmt.Fprintf(&b, "; ")
}
dl := sm[sid]
fmt.Fprintf(&b, "%s %d first=%d", sess, sid, dl.first)
fmt.Fprintf(&b, "%s %s first=%d", sess, sid, dl.first)
for _, d := range dl.data {
fmt.Fprintf(&b, " %s", d)
}
Expand Down
70 changes: 35 additions & 35 deletions mcp/event_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,21 +119,21 @@ func TestMemoryEventStoreState(t *testing.T) {
{
"appends",
func(s *MemoryEventStore) {
appendEvent(s, "S1", 1, "d1")
appendEvent(s, "S1", 2, "d2")
appendEvent(s, "S1", 1, "d3")
appendEvent(s, "S2", 8, "d4")
appendEvent(s, "S1", "1", "d1")
appendEvent(s, "S1", "2", "d2")
appendEvent(s, "S1", "1", "d3")
appendEvent(s, "S2", "8", "d4")
},
"S1 1 first=0 d1 d3; S1 2 first=0 d2; S2 8 first=0 d4",
8,
},
{
"session close",
func(s *MemoryEventStore) {
appendEvent(s, "S1", 1, "d1")
appendEvent(s, "S1", 2, "d2")
appendEvent(s, "S1", 1, "d3")
appendEvent(s, "S2", 8, "d4")
appendEvent(s, "S1", "1", "d1")
appendEvent(s, "S1", "2", "d2")
appendEvent(s, "S1", "1", "d3")
appendEvent(s, "S2", "8", "d4")
s.SessionClosed(ctx, "S1")
},
"S2 8 first=0 d4",
Expand All @@ -142,10 +142,10 @@ func TestMemoryEventStoreState(t *testing.T) {
{
"purge",
func(s *MemoryEventStore) {
appendEvent(s, "S1", 1, "d1")
appendEvent(s, "S1", 2, "d2")
appendEvent(s, "S1", 1, "d3")
appendEvent(s, "S2", 8, "d4")
appendEvent(s, "S1", "1", "d1")
appendEvent(s, "S1", "2", "d2")
appendEvent(s, "S1", "1", "d3")
appendEvent(s, "S2", "8", "d4")
// We are using 8 bytes (d1,d2, d3, d4).
// To purge 6, we remove the first of each stream, leaving only d3.
s.SetMaxBytes(2)
Expand All @@ -157,31 +157,31 @@ func TestMemoryEventStoreState(t *testing.T) {
{
"purge append",
func(s *MemoryEventStore) {
appendEvent(s, "S1", 1, "d1")
appendEvent(s, "S1", 2, "d2")
appendEvent(s, "S1", 1, "d3")
appendEvent(s, "S2", 8, "d4")
appendEvent(s, "S1", "1", "d1")
appendEvent(s, "S1", "2", "d2")
appendEvent(s, "S1", "1", "d3")
appendEvent(s, "S2", "8", "d4")
s.SetMaxBytes(2)
// Up to here, identical to the "purge" case.
// Each of these additions will result in a purge.
appendEvent(s, "S1", 2, "d5") // remove d3
appendEvent(s, "S1", 2, "d6") // remove d5
appendEvent(s, "S1", "2", "d5") // remove d3
appendEvent(s, "S1", "2", "d6") // remove d5
},
"S1 1 first=2; S1 2 first=2 d6; S2 8 first=1",
2,
},
{
"purge resize append",
func(s *MemoryEventStore) {
appendEvent(s, "S1", 1, "d1")
appendEvent(s, "S1", 2, "d2")
appendEvent(s, "S1", 1, "d3")
appendEvent(s, "S2", 8, "d4")
appendEvent(s, "S1", "1", "d1")
appendEvent(s, "S1", "2", "d2")
appendEvent(s, "S1", "1", "d3")
appendEvent(s, "S2", "8", "d4")
s.SetMaxBytes(2)
// Up to here, identical to the "purge" case.
s.SetMaxBytes(6) // make room
appendEvent(s, "S1", 2, "d5")
appendEvent(s, "S1", 2, "d6")
appendEvent(s, "S1", "2", "d5")
appendEvent(s, "S1", "2", "d6")
},
// The other streams remain, because we may add to them.
"S1 1 first=1 d3; S1 2 first=1 d5 d6; S2 8 first=1",
Expand All @@ -206,10 +206,10 @@ func TestMemoryEventStoreAfter(t *testing.T) {
ctx := context.Background()
s := NewMemoryEventStore(nil)
s.SetMaxBytes(4)
s.Append(ctx, "S1", 1, []byte("d1"))
s.Append(ctx, "S1", 1, []byte("d2"))
s.Append(ctx, "S1", 1, []byte("d3"))
s.Append(ctx, "S1", 2, []byte("d4")) // will purge d1
s.Append(ctx, "S1", "1", []byte("d1"))
s.Append(ctx, "S1", "1", []byte("d2"))
s.Append(ctx, "S1", "1", []byte("d3"))
s.Append(ctx, "S1", "2", []byte("d4")) // will purge d1
want := "S1 1 first=1 d2 d3; S1 2 first=0 d4"
if got := s.debugString(); got != want {
t.Fatalf("got state %q, want %q", got, want)
Expand All @@ -222,14 +222,14 @@ func TestMemoryEventStoreAfter(t *testing.T) {
want []string
wantErr string // if non-empty, error should contain this string
}{
{"S1", 1, 0, []string{"d2", "d3"}, ""},
{"S1", 1, 1, []string{"d3"}, ""},
{"S1", 1, 2, nil, ""},
{"S1", 2, 0, nil, ""},
{"S1", 3, 0, nil, "unknown stream ID"},
{"S2", 0, 0, nil, "unknown session ID"},
{"S1", "1", 0, []string{"d2", "d3"}, ""},
{"S1", "1", 1, []string{"d3"}, ""},
{"S1", "1", 2, nil, ""},
{"S1", "2", 0, nil, ""},
{"S1", "3", 0, nil, "unknown stream ID"},
{"S2", "0", 0, nil, "unknown session ID"},
} {
t.Run(fmt.Sprintf("%s-%d-%d", tt.sessionID, tt.streamID, tt.index), func(t *testing.T) {
t.Run(fmt.Sprintf("%s-%s-%d", tt.sessionID, tt.streamID, tt.index), func(t *testing.T) {
var got []string
for d, err := range s.After(ctx, tt.sessionID, tt.streamID, tt.index) {
if err != nil {
Expand Down
37 changes: 17 additions & 20 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ func (t *StreamableServerTransport) Connect(context.Context) (Connection, error)
//
// It is always text/event-stream, since it must carry arbitrarily many
// messages.
t.connection.streams[0] = newStream(0, false)
t.connection.streams[""] = newStream("", false)
if t.connection.eventStore == nil {
t.connection.eventStore = NewMemoryEventStore(nil)
}
Expand Down Expand Up @@ -331,7 +331,7 @@ func (c *streamableServerConn) SessionID() string {
// at any time.
type stream struct {
// id is the logical ID for the stream, unique within a session.
// ID 0 is used for messages that don't correlate with an incoming request.
// an empty string is used for messages that don't correlate with an incoming request.
id StreamID

// jsonResponse records whether this stream should respond with application/json
Expand Down Expand Up @@ -379,9 +379,9 @@ func signalChanPtr() *chan struct{} {
return &c
}

// A StreamID identifies a stream of SSE events. It is unique within the stream's
// A StreamID identifies a stream of SSE events. It is globally unique.
// [ServerSession].
type StreamID int64
type StreamID string

// We track the incoming request ID inside the handler context using
// idContextValue, so that notifications and server->client calls that occur in
Expand Down Expand Up @@ -431,7 +431,7 @@ func (t *StreamableServerTransport) ServeHTTP(w http.ResponseWriter, req *http.R
// It returns an HTTP status code and error message.
func (c *streamableServerConn) serveGET(w http.ResponseWriter, req *http.Request) {
// connID 0 corresponds to the default GET request.
id := StreamID(0)
id := StreamID("")
// By default, we haven't seen a last index. Since indices start at 0, we represent
// that by -1. This is incremented just before each event is written, in streamResponse
// around L407.
Expand Down Expand Up @@ -459,7 +459,7 @@ func (c *streamableServerConn) serveGET(w http.ResponseWriter, req *http.Request
return
}
defer stream.signal.Store(nil)
persistent := id == 0 // Only the special stream 0 is a hanging get.
persistent := id == "" // Only the special stream "" is a hanging get.
c.respondSSE(stream, w, req, lastIdx, persistent)
}

Expand Down Expand Up @@ -517,7 +517,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
// notifications or server->client requests made in the course of handling.
// Update accounting for this incoming payload.
if len(requests) > 0 {
stream = newStream(StreamID(c.lastStreamID.Add(1)), c.jsonResponse)
stream = newStream(StreamID(randText()), c.jsonResponse)
c.mu.Lock()
c.streams[stream.id] = stream
stream.requests = requests
Expand Down Expand Up @@ -716,7 +716,7 @@ func (c *streamableServerConn) messages(ctx context.Context, stream *stream, per
//
// See also [parseEventID].
func formatEventID(sid StreamID, idx int) string {
return fmt.Sprintf("%d_%d", sid, idx)
return fmt.Sprintf("%s_%d", sid, idx)
}

// parseEventID parses a Last-Event-ID value into a logical stream id and
Expand All @@ -726,15 +726,12 @@ func formatEventID(sid StreamID, idx int) string {
func parseEventID(eventID string) (sid StreamID, idx int, ok bool) {
parts := strings.Split(eventID, "_")
if len(parts) != 2 {
return 0, 0, false
return "", 0, false
}
stream, err := strconv.ParseInt(parts[0], 10, 64)
if err != nil || stream < 0 {
return 0, 0, false
}
idx, err = strconv.Atoi(parts[1])
stream := StreamID(parts[0])
idx, err := strconv.Atoi(parts[1])
if err != nil || idx < 0 {
return 0, 0, false
return "", 0, false
}
return StreamID(stream), idx, true
}
Expand Down Expand Up @@ -775,7 +772,7 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e
// Find the logical connection corresponding to this request.
//
// For messages sent outside of a request context, this is the default
// connection 0.
// connection "".
var forStream StreamID
if forRequest.IsValid() {
c.mu.Lock()
Expand All @@ -796,7 +793,7 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e

stream := c.streams[forStream]
if stream == nil {
return fmt.Errorf("no stream with ID %d", forStream)
return fmt.Errorf("no stream with ID %s", forStream)
}

// Special case a few conditions where we fall back on stream 0 (the hanging GET):
Expand All @@ -806,11 +803,11 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e
//
// TODO(rfindley): either of these, particularly the first, might be
// considered a bug in the server. Report it through a side-channel?
if len(stream.requests) == 0 && forStream != 0 || stream.jsonResponse && !isResponse {
stream = c.streams[0]
if len(stream.requests) == 0 && forStream != "" || stream.jsonResponse && !isResponse {
stream = c.streams[""]
}

// TODO: if there is nothing to send these messages to (as would happen, for example, if forConn == 0
// TODO: if there is nothing to send these messages to (as would happen, for example, if forConn == ""
// and the client never did a GET), then memory will grow without bound. Consider a mitigation.
stream.outgoing = append(stream.outgoing, data)
if isResponse {
Expand Down
18 changes: 8 additions & 10 deletions mcp/streamable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -784,22 +784,23 @@ func TestEventID(t *testing.T) {
sid StreamID
idx int
}{
{0, 0},
{0, 1},
{1, 0},
{1, 1},
{1234, 5678},
{"0", 0},
{"0", 1},
{"1", 0},
{"1", 1},
{"", 1},
{"1234", 5678},
}

for _, test := range tests {
t.Run(fmt.Sprintf("%d_%d", test.sid, test.idx), func(t *testing.T) {
t.Run(fmt.Sprintf("%s_%d", test.sid, test.idx), func(t *testing.T) {
eventID := formatEventID(test.sid, test.idx)
gotSID, gotIdx, ok := parseEventID(eventID)
if !ok {
t.Fatalf("parseEventID(%q) failed, want ok", eventID)
}
if gotSID != test.sid || gotIdx != test.idx {
t.Errorf("parseEventID(%q) = %d, %d, want %d, %d", eventID, gotSID, gotIdx, test.sid, test.idx)
t.Errorf("parseEventID(%q) = %s, %d, want %s, %d", eventID, gotSID, gotIdx, test.sid, test.idx)
}
})
}
Expand All @@ -808,10 +809,7 @@ func TestEventID(t *testing.T) {
"",
"_",
"1_",
"_1",
"a_1",
"1_a",
"-1_1",
"1_-1",
}

Expand Down