From b468679362a4291baa70662b0d539944c42a35d2 Mon Sep 17 00:00:00 2001 From: Hugh Palmer Date: Fri, 8 Aug 2025 11:51:53 +0200 Subject: [PATCH] mcp: changed streamID's from int64 to random strings (#266) - Changed StreamID's to store randomly generated strings. - Updated all tests. - Resolved conflicts --- mcp/event.go | 2 +- mcp/event_test.go | 70 +++++++++++++++++++++--------------------- mcp/streamable.go | 37 ++++++++++------------ mcp/streamable_test.go | 18 +++++------ 4 files changed, 61 insertions(+), 66 deletions(-) diff --git a/mcp/event.go b/mcp/event.go index 9092da76..f4f4eeea 100644 --- a/mcp/event.go +++ b/mcp/event.go @@ -392,7 +392,7 @@ func (s *MemoryEventStore) debugString() string { fmt.Fprintf(&b, "; ") } dl := sm[sid] - fmt.Fprintf(&b, "%s %d first=%d", sess, sid, dl.first) + fmt.Fprintf(&b, "%s %s first=%d", sess, sid, dl.first) for _, d := range dl.data { fmt.Fprintf(&b, " %s", d) } diff --git a/mcp/event_test.go b/mcp/event_test.go index 147a947a..601e8300 100644 --- a/mcp/event_test.go +++ b/mcp/event_test.go @@ -119,10 +119,10 @@ func TestMemoryEventStoreState(t *testing.T) { { "appends", func(s *MemoryEventStore) { - appendEvent(s, "S1", 1, "d1") - appendEvent(s, "S1", 2, "d2") - appendEvent(s, "S1", 1, "d3") - appendEvent(s, "S2", 8, "d4") + appendEvent(s, "S1", "1", "d1") + appendEvent(s, "S1", "2", "d2") + appendEvent(s, "S1", "1", "d3") + appendEvent(s, "S2", "8", "d4") }, "S1 1 first=0 d1 d3; S1 2 first=0 d2; S2 8 first=0 d4", 8, @@ -130,10 +130,10 @@ func TestMemoryEventStoreState(t *testing.T) { { "session close", func(s *MemoryEventStore) { - appendEvent(s, "S1", 1, "d1") - appendEvent(s, "S1", 2, "d2") - appendEvent(s, "S1", 1, "d3") - appendEvent(s, "S2", 8, "d4") + appendEvent(s, "S1", "1", "d1") + appendEvent(s, "S1", "2", "d2") + appendEvent(s, "S1", "1", "d3") + appendEvent(s, "S2", "8", "d4") s.SessionClosed(ctx, "S1") }, "S2 8 first=0 d4", @@ -142,10 +142,10 @@ func TestMemoryEventStoreState(t *testing.T) { { "purge", func(s *MemoryEventStore) { - appendEvent(s, "S1", 1, "d1") - appendEvent(s, "S1", 2, "d2") - appendEvent(s, "S1", 1, "d3") - appendEvent(s, "S2", 8, "d4") + appendEvent(s, "S1", "1", "d1") + appendEvent(s, "S1", "2", "d2") + appendEvent(s, "S1", "1", "d3") + appendEvent(s, "S2", "8", "d4") // We are using 8 bytes (d1,d2, d3, d4). // To purge 6, we remove the first of each stream, leaving only d3. s.SetMaxBytes(2) @@ -157,15 +157,15 @@ func TestMemoryEventStoreState(t *testing.T) { { "purge append", func(s *MemoryEventStore) { - appendEvent(s, "S1", 1, "d1") - appendEvent(s, "S1", 2, "d2") - appendEvent(s, "S1", 1, "d3") - appendEvent(s, "S2", 8, "d4") + appendEvent(s, "S1", "1", "d1") + appendEvent(s, "S1", "2", "d2") + appendEvent(s, "S1", "1", "d3") + appendEvent(s, "S2", "8", "d4") s.SetMaxBytes(2) // Up to here, identical to the "purge" case. // Each of these additions will result in a purge. - appendEvent(s, "S1", 2, "d5") // remove d3 - appendEvent(s, "S1", 2, "d6") // remove d5 + appendEvent(s, "S1", "2", "d5") // remove d3 + appendEvent(s, "S1", "2", "d6") // remove d5 }, "S1 1 first=2; S1 2 first=2 d6; S2 8 first=1", 2, @@ -173,15 +173,15 @@ func TestMemoryEventStoreState(t *testing.T) { { "purge resize append", func(s *MemoryEventStore) { - appendEvent(s, "S1", 1, "d1") - appendEvent(s, "S1", 2, "d2") - appendEvent(s, "S1", 1, "d3") - appendEvent(s, "S2", 8, "d4") + appendEvent(s, "S1", "1", "d1") + appendEvent(s, "S1", "2", "d2") + appendEvent(s, "S1", "1", "d3") + appendEvent(s, "S2", "8", "d4") s.SetMaxBytes(2) // Up to here, identical to the "purge" case. s.SetMaxBytes(6) // make room - appendEvent(s, "S1", 2, "d5") - appendEvent(s, "S1", 2, "d6") + appendEvent(s, "S1", "2", "d5") + appendEvent(s, "S1", "2", "d6") }, // The other streams remain, because we may add to them. "S1 1 first=1 d3; S1 2 first=1 d5 d6; S2 8 first=1", @@ -206,10 +206,10 @@ func TestMemoryEventStoreAfter(t *testing.T) { ctx := context.Background() s := NewMemoryEventStore(nil) s.SetMaxBytes(4) - s.Append(ctx, "S1", 1, []byte("d1")) - s.Append(ctx, "S1", 1, []byte("d2")) - s.Append(ctx, "S1", 1, []byte("d3")) - s.Append(ctx, "S1", 2, []byte("d4")) // will purge d1 + s.Append(ctx, "S1", "1", []byte("d1")) + s.Append(ctx, "S1", "1", []byte("d2")) + s.Append(ctx, "S1", "1", []byte("d3")) + s.Append(ctx, "S1", "2", []byte("d4")) // will purge d1 want := "S1 1 first=1 d2 d3; S1 2 first=0 d4" if got := s.debugString(); got != want { t.Fatalf("got state %q, want %q", got, want) @@ -222,14 +222,14 @@ func TestMemoryEventStoreAfter(t *testing.T) { want []string wantErr string // if non-empty, error should contain this string }{ - {"S1", 1, 0, []string{"d2", "d3"}, ""}, - {"S1", 1, 1, []string{"d3"}, ""}, - {"S1", 1, 2, nil, ""}, - {"S1", 2, 0, nil, ""}, - {"S1", 3, 0, nil, "unknown stream ID"}, - {"S2", 0, 0, nil, "unknown session ID"}, + {"S1", "1", 0, []string{"d2", "d3"}, ""}, + {"S1", "1", 1, []string{"d3"}, ""}, + {"S1", "1", 2, nil, ""}, + {"S1", "2", 0, nil, ""}, + {"S1", "3", 0, nil, "unknown stream ID"}, + {"S2", "0", 0, nil, "unknown session ID"}, } { - t.Run(fmt.Sprintf("%s-%d-%d", tt.sessionID, tt.streamID, tt.index), func(t *testing.T) { + t.Run(fmt.Sprintf("%s-%s-%d", tt.sessionID, tt.streamID, tt.index), func(t *testing.T) { var got []string for d, err := range s.After(ctx, tt.sessionID, tt.streamID, tt.index) { if err != nil { diff --git a/mcp/streamable.go b/mcp/streamable.go index 048b99aa..80c973e9 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -273,7 +273,7 @@ func (t *StreamableServerTransport) Connect(context.Context) (Connection, error) // // It is always text/event-stream, since it must carry arbitrarily many // messages. - t.connection.streams[0] = newStream(0, false) + t.connection.streams[""] = newStream("", false) if t.connection.eventStore == nil { t.connection.eventStore = NewMemoryEventStore(nil) } @@ -331,7 +331,7 @@ func (c *streamableServerConn) SessionID() string { // at any time. type stream struct { // id is the logical ID for the stream, unique within a session. - // ID 0 is used for messages that don't correlate with an incoming request. + // an empty string is used for messages that don't correlate with an incoming request. id StreamID // jsonResponse records whether this stream should respond with application/json @@ -379,9 +379,9 @@ func signalChanPtr() *chan struct{} { return &c } -// A StreamID identifies a stream of SSE events. It is unique within the stream's +// A StreamID identifies a stream of SSE events. It is globally unique. // [ServerSession]. -type StreamID int64 +type StreamID string // We track the incoming request ID inside the handler context using // idContextValue, so that notifications and server->client calls that occur in @@ -431,7 +431,7 @@ func (t *StreamableServerTransport) ServeHTTP(w http.ResponseWriter, req *http.R // It returns an HTTP status code and error message. func (c *streamableServerConn) serveGET(w http.ResponseWriter, req *http.Request) { // connID 0 corresponds to the default GET request. - id := StreamID(0) + id := StreamID("") // By default, we haven't seen a last index. Since indices start at 0, we represent // that by -1. This is incremented just before each event is written, in streamResponse // around L407. @@ -459,7 +459,7 @@ func (c *streamableServerConn) serveGET(w http.ResponseWriter, req *http.Request return } defer stream.signal.Store(nil) - persistent := id == 0 // Only the special stream 0 is a hanging get. + persistent := id == "" // Only the special stream "" is a hanging get. c.respondSSE(stream, w, req, lastIdx, persistent) } @@ -517,7 +517,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques // notifications or server->client requests made in the course of handling. // Update accounting for this incoming payload. if len(requests) > 0 { - stream = newStream(StreamID(c.lastStreamID.Add(1)), c.jsonResponse) + stream = newStream(StreamID(randText()), c.jsonResponse) c.mu.Lock() c.streams[stream.id] = stream stream.requests = requests @@ -716,7 +716,7 @@ func (c *streamableServerConn) messages(ctx context.Context, stream *stream, per // // See also [parseEventID]. func formatEventID(sid StreamID, idx int) string { - return fmt.Sprintf("%d_%d", sid, idx) + return fmt.Sprintf("%s_%d", sid, idx) } // parseEventID parses a Last-Event-ID value into a logical stream id and @@ -726,15 +726,12 @@ func formatEventID(sid StreamID, idx int) string { func parseEventID(eventID string) (sid StreamID, idx int, ok bool) { parts := strings.Split(eventID, "_") if len(parts) != 2 { - return 0, 0, false + return "", 0, false } - stream, err := strconv.ParseInt(parts[0], 10, 64) - if err != nil || stream < 0 { - return 0, 0, false - } - idx, err = strconv.Atoi(parts[1]) + stream := StreamID(parts[0]) + idx, err := strconv.Atoi(parts[1]) if err != nil || idx < 0 { - return 0, 0, false + return "", 0, false } return StreamID(stream), idx, true } @@ -775,7 +772,7 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e // Find the logical connection corresponding to this request. // // For messages sent outside of a request context, this is the default - // connection 0. + // connection "". var forStream StreamID if forRequest.IsValid() { c.mu.Lock() @@ -796,7 +793,7 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e stream := c.streams[forStream] if stream == nil { - return fmt.Errorf("no stream with ID %d", forStream) + return fmt.Errorf("no stream with ID %s", forStream) } // Special case a few conditions where we fall back on stream 0 (the hanging GET): @@ -806,11 +803,11 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e // // TODO(rfindley): either of these, particularly the first, might be // considered a bug in the server. Report it through a side-channel? - if len(stream.requests) == 0 && forStream != 0 || stream.jsonResponse && !isResponse { - stream = c.streams[0] + if len(stream.requests) == 0 && forStream != "" || stream.jsonResponse && !isResponse { + stream = c.streams[""] } - // TODO: if there is nothing to send these messages to (as would happen, for example, if forConn == 0 + // TODO: if there is nothing to send these messages to (as would happen, for example, if forConn == "" // and the client never did a GET), then memory will grow without bound. Consider a mitigation. stream.outgoing = append(stream.outgoing, data) if isResponse { diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index fd1dc3e4..1941313e 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -784,22 +784,23 @@ func TestEventID(t *testing.T) { sid StreamID idx int }{ - {0, 0}, - {0, 1}, - {1, 0}, - {1, 1}, - {1234, 5678}, + {"0", 0}, + {"0", 1}, + {"1", 0}, + {"1", 1}, + {"", 1}, + {"1234", 5678}, } for _, test := range tests { - t.Run(fmt.Sprintf("%d_%d", test.sid, test.idx), func(t *testing.T) { + t.Run(fmt.Sprintf("%s_%d", test.sid, test.idx), func(t *testing.T) { eventID := formatEventID(test.sid, test.idx) gotSID, gotIdx, ok := parseEventID(eventID) if !ok { t.Fatalf("parseEventID(%q) failed, want ok", eventID) } if gotSID != test.sid || gotIdx != test.idx { - t.Errorf("parseEventID(%q) = %d, %d, want %d, %d", eventID, gotSID, gotIdx, test.sid, test.idx) + t.Errorf("parseEventID(%q) = %s, %d, want %s, %d", eventID, gotSID, gotIdx, test.sid, test.idx) } }) } @@ -808,10 +809,7 @@ func TestEventID(t *testing.T) { "", "_", "1_", - "_1", - "a_1", "1_a", - "-1_1", "1_-1", }