Skip to content

Commit 0a8fe40

Browse files
2bitburritofindleyr
authored andcommitted
mcp: changed streamID's from int64 to random strings (#266)
- Changed StreamID's to store randomly generated strings. - Updated all tests. - Resolved conflicts
1 parent 54c9981 commit 0a8fe40

File tree

4 files changed

+61
-66
lines changed

4 files changed

+61
-66
lines changed

mcp/event.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ func (s *MemoryEventStore) debugString() string {
392392
fmt.Fprintf(&b, "; ")
393393
}
394394
dl := sm[sid]
395-
fmt.Fprintf(&b, "%s %d first=%d", sess, sid, dl.first)
395+
fmt.Fprintf(&b, "%s %s first=%d", sess, sid, dl.first)
396396
for _, d := range dl.data {
397397
fmt.Fprintf(&b, " %s", d)
398398
}

mcp/event_test.go

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -119,21 +119,21 @@ func TestMemoryEventStoreState(t *testing.T) {
119119
{
120120
"appends",
121121
func(s *MemoryEventStore) {
122-
appendEvent(s, "S1", 1, "d1")
123-
appendEvent(s, "S1", 2, "d2")
124-
appendEvent(s, "S1", 1, "d3")
125-
appendEvent(s, "S2", 8, "d4")
122+
appendEvent(s, "S1", "1", "d1")
123+
appendEvent(s, "S1", "2", "d2")
124+
appendEvent(s, "S1", "1", "d3")
125+
appendEvent(s, "S2", "8", "d4")
126126
},
127127
"S1 1 first=0 d1 d3; S1 2 first=0 d2; S2 8 first=0 d4",
128128
8,
129129
},
130130
{
131131
"session close",
132132
func(s *MemoryEventStore) {
133-
appendEvent(s, "S1", 1, "d1")
134-
appendEvent(s, "S1", 2, "d2")
135-
appendEvent(s, "S1", 1, "d3")
136-
appendEvent(s, "S2", 8, "d4")
133+
appendEvent(s, "S1", "1", "d1")
134+
appendEvent(s, "S1", "2", "d2")
135+
appendEvent(s, "S1", "1", "d3")
136+
appendEvent(s, "S2", "8", "d4")
137137
s.SessionClosed(ctx, "S1")
138138
},
139139
"S2 8 first=0 d4",
@@ -142,10 +142,10 @@ func TestMemoryEventStoreState(t *testing.T) {
142142
{
143143
"purge",
144144
func(s *MemoryEventStore) {
145-
appendEvent(s, "S1", 1, "d1")
146-
appendEvent(s, "S1", 2, "d2")
147-
appendEvent(s, "S1", 1, "d3")
148-
appendEvent(s, "S2", 8, "d4")
145+
appendEvent(s, "S1", "1", "d1")
146+
appendEvent(s, "S1", "2", "d2")
147+
appendEvent(s, "S1", "1", "d3")
148+
appendEvent(s, "S2", "8", "d4")
149149
// We are using 8 bytes (d1,d2, d3, d4).
150150
// To purge 6, we remove the first of each stream, leaving only d3.
151151
s.SetMaxBytes(2)
@@ -157,31 +157,31 @@ func TestMemoryEventStoreState(t *testing.T) {
157157
{
158158
"purge append",
159159
func(s *MemoryEventStore) {
160-
appendEvent(s, "S1", 1, "d1")
161-
appendEvent(s, "S1", 2, "d2")
162-
appendEvent(s, "S1", 1, "d3")
163-
appendEvent(s, "S2", 8, "d4")
160+
appendEvent(s, "S1", "1", "d1")
161+
appendEvent(s, "S1", "2", "d2")
162+
appendEvent(s, "S1", "1", "d3")
163+
appendEvent(s, "S2", "8", "d4")
164164
s.SetMaxBytes(2)
165165
// Up to here, identical to the "purge" case.
166166
// Each of these additions will result in a purge.
167-
appendEvent(s, "S1", 2, "d5") // remove d3
168-
appendEvent(s, "S1", 2, "d6") // remove d5
167+
appendEvent(s, "S1", "2", "d5") // remove d3
168+
appendEvent(s, "S1", "2", "d6") // remove d5
169169
},
170170
"S1 1 first=2; S1 2 first=2 d6; S2 8 first=1",
171171
2,
172172
},
173173
{
174174
"purge resize append",
175175
func(s *MemoryEventStore) {
176-
appendEvent(s, "S1", 1, "d1")
177-
appendEvent(s, "S1", 2, "d2")
178-
appendEvent(s, "S1", 1, "d3")
179-
appendEvent(s, "S2", 8, "d4")
176+
appendEvent(s, "S1", "1", "d1")
177+
appendEvent(s, "S1", "2", "d2")
178+
appendEvent(s, "S1", "1", "d3")
179+
appendEvent(s, "S2", "8", "d4")
180180
s.SetMaxBytes(2)
181181
// Up to here, identical to the "purge" case.
182182
s.SetMaxBytes(6) // make room
183-
appendEvent(s, "S1", 2, "d5")
184-
appendEvent(s, "S1", 2, "d6")
183+
appendEvent(s, "S1", "2", "d5")
184+
appendEvent(s, "S1", "2", "d6")
185185
},
186186
// The other streams remain, because we may add to them.
187187
"S1 1 first=1 d3; S1 2 first=1 d5 d6; S2 8 first=1",
@@ -206,10 +206,10 @@ func TestMemoryEventStoreAfter(t *testing.T) {
206206
ctx := context.Background()
207207
s := NewMemoryEventStore(nil)
208208
s.SetMaxBytes(4)
209-
s.Append(ctx, "S1", 1, []byte("d1"))
210-
s.Append(ctx, "S1", 1, []byte("d2"))
211-
s.Append(ctx, "S1", 1, []byte("d3"))
212-
s.Append(ctx, "S1", 2, []byte("d4")) // will purge d1
209+
s.Append(ctx, "S1", "1", []byte("d1"))
210+
s.Append(ctx, "S1", "1", []byte("d2"))
211+
s.Append(ctx, "S1", "1", []byte("d3"))
212+
s.Append(ctx, "S1", "2", []byte("d4")) // will purge d1
213213
want := "S1 1 first=1 d2 d3; S1 2 first=0 d4"
214214
if got := s.debugString(); got != want {
215215
t.Fatalf("got state %q, want %q", got, want)
@@ -222,14 +222,14 @@ func TestMemoryEventStoreAfter(t *testing.T) {
222222
want []string
223223
wantErr string // if non-empty, error should contain this string
224224
}{
225-
{"S1", 1, 0, []string{"d2", "d3"}, ""},
226-
{"S1", 1, 1, []string{"d3"}, ""},
227-
{"S1", 1, 2, nil, ""},
228-
{"S1", 2, 0, nil, ""},
229-
{"S1", 3, 0, nil, "unknown stream ID"},
230-
{"S2", 0, 0, nil, "unknown session ID"},
225+
{"S1", "1", 0, []string{"d2", "d3"}, ""},
226+
{"S1", "1", 1, []string{"d3"}, ""},
227+
{"S1", "1", 2, nil, ""},
228+
{"S1", "2", 0, nil, ""},
229+
{"S1", "3", 0, nil, "unknown stream ID"},
230+
{"S2", "0", 0, nil, "unknown session ID"},
231231
} {
232-
t.Run(fmt.Sprintf("%s-%d-%d", tt.sessionID, tt.streamID, tt.index), func(t *testing.T) {
232+
t.Run(fmt.Sprintf("%s-%s-%d", tt.sessionID, tt.streamID, tt.index), func(t *testing.T) {
233233
var got []string
234234
for d, err := range s.After(ctx, tt.sessionID, tt.streamID, tt.index) {
235235
if err != nil {

mcp/streamable.go

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ func (t *StreamableServerTransport) Connect(context.Context) (Connection, error)
276276
//
277277
// It is always text/event-stream, since it must carry arbitrarily many
278278
// messages.
279-
t.connection.streams[0] = newStream(0, false)
279+
t.connection.streams[""] = newStream("", false)
280280
if t.connection.eventStore == nil {
281281
t.connection.eventStore = NewMemoryEventStore(nil)
282282
}
@@ -334,7 +334,7 @@ func (c *streamableServerConn) SessionID() string {
334334
// at any time.
335335
type stream struct {
336336
// id is the logical ID for the stream, unique within a session.
337-
// ID 0 is used for messages that don't correlate with an incoming request.
337+
// an empty string is used for messages that don't correlate with an incoming request.
338338
id StreamID
339339

340340
// jsonResponse records whether this stream should respond with application/json
@@ -382,9 +382,9 @@ func signalChanPtr() *chan struct{} {
382382
return &c
383383
}
384384

385-
// A StreamID identifies a stream of SSE events. It is unique within the stream's
385+
// A StreamID identifies a stream of SSE events. It is globally unique.
386386
// [ServerSession].
387-
type StreamID int64
387+
type StreamID string
388388

389389
// We track the incoming request ID inside the handler context using
390390
// idContextValue, so that notifications and server->client calls that occur in
@@ -434,7 +434,7 @@ func (t *StreamableServerTransport) ServeHTTP(w http.ResponseWriter, req *http.R
434434
// It returns an HTTP status code and error message.
435435
func (c *streamableServerConn) serveGET(w http.ResponseWriter, req *http.Request) {
436436
// connID 0 corresponds to the default GET request.
437-
id := StreamID(0)
437+
id := StreamID("")
438438
// By default, we haven't seen a last index. Since indices start at 0, we represent
439439
// that by -1. This is incremented just before each event is written, in streamResponse
440440
// around L407.
@@ -462,7 +462,7 @@ func (c *streamableServerConn) serveGET(w http.ResponseWriter, req *http.Request
462462
return
463463
}
464464
defer stream.signal.Store(nil)
465-
persistent := id == 0 // Only the special stream 0 is a hanging get.
465+
persistent := id == "" // Only the special stream "" is a hanging get.
466466
c.respondSSE(stream, w, req, lastIdx, persistent)
467467
}
468468

@@ -520,7 +520,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
520520
// notifications or server->client requests made in the course of handling.
521521
// Update accounting for this incoming payload.
522522
if len(requests) > 0 {
523-
stream = newStream(StreamID(c.lastStreamID.Add(1)), c.jsonResponse)
523+
stream = newStream(StreamID(randText()), c.jsonResponse)
524524
c.mu.Lock()
525525
c.streams[stream.id] = stream
526526
stream.requests = requests
@@ -719,7 +719,7 @@ func (c *streamableServerConn) messages(ctx context.Context, stream *stream, per
719719
//
720720
// See also [parseEventID].
721721
func formatEventID(sid StreamID, idx int) string {
722-
return fmt.Sprintf("%d_%d", sid, idx)
722+
return fmt.Sprintf("%s_%d", sid, idx)
723723
}
724724

725725
// parseEventID parses a Last-Event-ID value into a logical stream id and
@@ -729,15 +729,12 @@ func formatEventID(sid StreamID, idx int) string {
729729
func parseEventID(eventID string) (sid StreamID, idx int, ok bool) {
730730
parts := strings.Split(eventID, "_")
731731
if len(parts) != 2 {
732-
return 0, 0, false
732+
return "", 0, false
733733
}
734-
stream, err := strconv.ParseInt(parts[0], 10, 64)
735-
if err != nil || stream < 0 {
736-
return 0, 0, false
737-
}
738-
idx, err = strconv.Atoi(parts[1])
734+
stream := StreamID(parts[0])
735+
idx, err := strconv.Atoi(parts[1])
739736
if err != nil || idx < 0 {
740-
return 0, 0, false
737+
return "", 0, false
741738
}
742739
return StreamID(stream), idx, true
743740
}
@@ -778,7 +775,7 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e
778775
// Find the logical connection corresponding to this request.
779776
//
780777
// For messages sent outside of a request context, this is the default
781-
// connection 0.
778+
// connection "".
782779
var forStream StreamID
783780
if forRequest.IsValid() {
784781
c.mu.Lock()
@@ -799,7 +796,7 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e
799796

800797
stream := c.streams[forStream]
801798
if stream == nil {
802-
return fmt.Errorf("no stream with ID %d", forStream)
799+
return fmt.Errorf("no stream with ID %s", forStream)
803800
}
804801

805802
// Special case a few conditions where we fall back on stream 0 (the hanging GET):
@@ -809,11 +806,11 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e
809806
//
810807
// TODO(rfindley): either of these, particularly the first, might be
811808
// considered a bug in the server. Report it through a side-channel?
812-
if len(stream.requests) == 0 && forStream != 0 || stream.jsonResponse && !isResponse {
813-
stream = c.streams[0]
809+
if len(stream.requests) == 0 && forStream != "" || stream.jsonResponse && !isResponse {
810+
stream = c.streams[""]
814811
}
815812

816-
// TODO: if there is nothing to send these messages to (as would happen, for example, if forConn == 0
813+
// TODO: if there is nothing to send these messages to (as would happen, for example, if forConn == ""
817814
// and the client never did a GET), then memory will grow without bound. Consider a mitigation.
818815
stream.outgoing = append(stream.outgoing, data)
819816
if isResponse {

mcp/streamable_test.go

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -888,22 +888,23 @@ func TestEventID(t *testing.T) {
888888
sid StreamID
889889
idx int
890890
}{
891-
{0, 0},
892-
{0, 1},
893-
{1, 0},
894-
{1, 1},
895-
{1234, 5678},
891+
{"0", 0},
892+
{"0", 1},
893+
{"1", 0},
894+
{"1", 1},
895+
{"", 1},
896+
{"1234", 5678},
896897
}
897898

898899
for _, test := range tests {
899-
t.Run(fmt.Sprintf("%d_%d", test.sid, test.idx), func(t *testing.T) {
900+
t.Run(fmt.Sprintf("%s_%d", test.sid, test.idx), func(t *testing.T) {
900901
eventID := formatEventID(test.sid, test.idx)
901902
gotSID, gotIdx, ok := parseEventID(eventID)
902903
if !ok {
903904
t.Fatalf("parseEventID(%q) failed, want ok", eventID)
904905
}
905906
if gotSID != test.sid || gotIdx != test.idx {
906-
t.Errorf("parseEventID(%q) = %d, %d, want %d, %d", eventID, gotSID, gotIdx, test.sid, test.idx)
907+
t.Errorf("parseEventID(%q) = %s, %d, want %s, %d", eventID, gotSID, gotIdx, test.sid, test.idx)
907908
}
908909
})
909910
}
@@ -912,10 +913,7 @@ func TestEventID(t *testing.T) {
912913
"",
913914
"_",
914915
"1_",
915-
"_1",
916-
"a_1",
917916
"1_a",
918-
"-1_1",
919917
"1_-1",
920918
}
921919

0 commit comments

Comments
 (0)