Skip to content

Commit b468679

Browse files
committed
mcp: changed streamID's from int64 to random strings (#266)
- Changed StreamID's to store randomly generated strings. - Updated all tests. - Resolved conflicts
1 parent 6e03217 commit b468679

File tree

4 files changed

+61
-66
lines changed

4 files changed

+61
-66
lines changed

mcp/event.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ func (s *MemoryEventStore) debugString() string {
392392
fmt.Fprintf(&b, "; ")
393393
}
394394
dl := sm[sid]
395-
fmt.Fprintf(&b, "%s %d first=%d", sess, sid, dl.first)
395+
fmt.Fprintf(&b, "%s %s first=%d", sess, sid, dl.first)
396396
for _, d := range dl.data {
397397
fmt.Fprintf(&b, " %s", d)
398398
}

mcp/event_test.go

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -119,21 +119,21 @@ func TestMemoryEventStoreState(t *testing.T) {
119119
{
120120
"appends",
121121
func(s *MemoryEventStore) {
122-
appendEvent(s, "S1", 1, "d1")
123-
appendEvent(s, "S1", 2, "d2")
124-
appendEvent(s, "S1", 1, "d3")
125-
appendEvent(s, "S2", 8, "d4")
122+
appendEvent(s, "S1", "1", "d1")
123+
appendEvent(s, "S1", "2", "d2")
124+
appendEvent(s, "S1", "1", "d3")
125+
appendEvent(s, "S2", "8", "d4")
126126
},
127127
"S1 1 first=0 d1 d3; S1 2 first=0 d2; S2 8 first=0 d4",
128128
8,
129129
},
130130
{
131131
"session close",
132132
func(s *MemoryEventStore) {
133-
appendEvent(s, "S1", 1, "d1")
134-
appendEvent(s, "S1", 2, "d2")
135-
appendEvent(s, "S1", 1, "d3")
136-
appendEvent(s, "S2", 8, "d4")
133+
appendEvent(s, "S1", "1", "d1")
134+
appendEvent(s, "S1", "2", "d2")
135+
appendEvent(s, "S1", "1", "d3")
136+
appendEvent(s, "S2", "8", "d4")
137137
s.SessionClosed(ctx, "S1")
138138
},
139139
"S2 8 first=0 d4",
@@ -142,10 +142,10 @@ func TestMemoryEventStoreState(t *testing.T) {
142142
{
143143
"purge",
144144
func(s *MemoryEventStore) {
145-
appendEvent(s, "S1", 1, "d1")
146-
appendEvent(s, "S1", 2, "d2")
147-
appendEvent(s, "S1", 1, "d3")
148-
appendEvent(s, "S2", 8, "d4")
145+
appendEvent(s, "S1", "1", "d1")
146+
appendEvent(s, "S1", "2", "d2")
147+
appendEvent(s, "S1", "1", "d3")
148+
appendEvent(s, "S2", "8", "d4")
149149
// We are using 8 bytes (d1,d2, d3, d4).
150150
// To purge 6, we remove the first of each stream, leaving only d3.
151151
s.SetMaxBytes(2)
@@ -157,31 +157,31 @@ func TestMemoryEventStoreState(t *testing.T) {
157157
{
158158
"purge append",
159159
func(s *MemoryEventStore) {
160-
appendEvent(s, "S1", 1, "d1")
161-
appendEvent(s, "S1", 2, "d2")
162-
appendEvent(s, "S1", 1, "d3")
163-
appendEvent(s, "S2", 8, "d4")
160+
appendEvent(s, "S1", "1", "d1")
161+
appendEvent(s, "S1", "2", "d2")
162+
appendEvent(s, "S1", "1", "d3")
163+
appendEvent(s, "S2", "8", "d4")
164164
s.SetMaxBytes(2)
165165
// Up to here, identical to the "purge" case.
166166
// Each of these additions will result in a purge.
167-
appendEvent(s, "S1", 2, "d5") // remove d3
168-
appendEvent(s, "S1", 2, "d6") // remove d5
167+
appendEvent(s, "S1", "2", "d5") // remove d3
168+
appendEvent(s, "S1", "2", "d6") // remove d5
169169
},
170170
"S1 1 first=2; S1 2 first=2 d6; S2 8 first=1",
171171
2,
172172
},
173173
{
174174
"purge resize append",
175175
func(s *MemoryEventStore) {
176-
appendEvent(s, "S1", 1, "d1")
177-
appendEvent(s, "S1", 2, "d2")
178-
appendEvent(s, "S1", 1, "d3")
179-
appendEvent(s, "S2", 8, "d4")
176+
appendEvent(s, "S1", "1", "d1")
177+
appendEvent(s, "S1", "2", "d2")
178+
appendEvent(s, "S1", "1", "d3")
179+
appendEvent(s, "S2", "8", "d4")
180180
s.SetMaxBytes(2)
181181
// Up to here, identical to the "purge" case.
182182
s.SetMaxBytes(6) // make room
183-
appendEvent(s, "S1", 2, "d5")
184-
appendEvent(s, "S1", 2, "d6")
183+
appendEvent(s, "S1", "2", "d5")
184+
appendEvent(s, "S1", "2", "d6")
185185
},
186186
// The other streams remain, because we may add to them.
187187
"S1 1 first=1 d3; S1 2 first=1 d5 d6; S2 8 first=1",
@@ -206,10 +206,10 @@ func TestMemoryEventStoreAfter(t *testing.T) {
206206
ctx := context.Background()
207207
s := NewMemoryEventStore(nil)
208208
s.SetMaxBytes(4)
209-
s.Append(ctx, "S1", 1, []byte("d1"))
210-
s.Append(ctx, "S1", 1, []byte("d2"))
211-
s.Append(ctx, "S1", 1, []byte("d3"))
212-
s.Append(ctx, "S1", 2, []byte("d4")) // will purge d1
209+
s.Append(ctx, "S1", "1", []byte("d1"))
210+
s.Append(ctx, "S1", "1", []byte("d2"))
211+
s.Append(ctx, "S1", "1", []byte("d3"))
212+
s.Append(ctx, "S1", "2", []byte("d4")) // will purge d1
213213
want := "S1 1 first=1 d2 d3; S1 2 first=0 d4"
214214
if got := s.debugString(); got != want {
215215
t.Fatalf("got state %q, want %q", got, want)
@@ -222,14 +222,14 @@ func TestMemoryEventStoreAfter(t *testing.T) {
222222
want []string
223223
wantErr string // if non-empty, error should contain this string
224224
}{
225-
{"S1", 1, 0, []string{"d2", "d3"}, ""},
226-
{"S1", 1, 1, []string{"d3"}, ""},
227-
{"S1", 1, 2, nil, ""},
228-
{"S1", 2, 0, nil, ""},
229-
{"S1", 3, 0, nil, "unknown stream ID"},
230-
{"S2", 0, 0, nil, "unknown session ID"},
225+
{"S1", "1", 0, []string{"d2", "d3"}, ""},
226+
{"S1", "1", 1, []string{"d3"}, ""},
227+
{"S1", "1", 2, nil, ""},
228+
{"S1", "2", 0, nil, ""},
229+
{"S1", "3", 0, nil, "unknown stream ID"},
230+
{"S2", "0", 0, nil, "unknown session ID"},
231231
} {
232-
t.Run(fmt.Sprintf("%s-%d-%d", tt.sessionID, tt.streamID, tt.index), func(t *testing.T) {
232+
t.Run(fmt.Sprintf("%s-%s-%d", tt.sessionID, tt.streamID, tt.index), func(t *testing.T) {
233233
var got []string
234234
for d, err := range s.After(ctx, tt.sessionID, tt.streamID, tt.index) {
235235
if err != nil {

mcp/streamable.go

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ func (t *StreamableServerTransport) Connect(context.Context) (Connection, error)
273273
//
274274
// It is always text/event-stream, since it must carry arbitrarily many
275275
// messages.
276-
t.connection.streams[0] = newStream(0, false)
276+
t.connection.streams[""] = newStream("", false)
277277
if t.connection.eventStore == nil {
278278
t.connection.eventStore = NewMemoryEventStore(nil)
279279
}
@@ -331,7 +331,7 @@ func (c *streamableServerConn) SessionID() string {
331331
// at any time.
332332
type stream struct {
333333
// id is the logical ID for the stream, unique within a session.
334-
// ID 0 is used for messages that don't correlate with an incoming request.
334+
// an empty string is used for messages that don't correlate with an incoming request.
335335
id StreamID
336336

337337
// jsonResponse records whether this stream should respond with application/json
@@ -379,9 +379,9 @@ func signalChanPtr() *chan struct{} {
379379
return &c
380380
}
381381

382-
// A StreamID identifies a stream of SSE events. It is unique within the stream's
382+
// A StreamID identifies a stream of SSE events. It is globally unique.
383383
// [ServerSession].
384-
type StreamID int64
384+
type StreamID string
385385

386386
// We track the incoming request ID inside the handler context using
387387
// idContextValue, so that notifications and server->client calls that occur in
@@ -431,7 +431,7 @@ func (t *StreamableServerTransport) ServeHTTP(w http.ResponseWriter, req *http.R
431431
// It returns an HTTP status code and error message.
432432
func (c *streamableServerConn) serveGET(w http.ResponseWriter, req *http.Request) {
433433
// connID 0 corresponds to the default GET request.
434-
id := StreamID(0)
434+
id := StreamID("")
435435
// By default, we haven't seen a last index. Since indices start at 0, we represent
436436
// that by -1. This is incremented just before each event is written, in streamResponse
437437
// around L407.
@@ -459,7 +459,7 @@ func (c *streamableServerConn) serveGET(w http.ResponseWriter, req *http.Request
459459
return
460460
}
461461
defer stream.signal.Store(nil)
462-
persistent := id == 0 // Only the special stream 0 is a hanging get.
462+
persistent := id == "" // Only the special stream "" is a hanging get.
463463
c.respondSSE(stream, w, req, lastIdx, persistent)
464464
}
465465

@@ -517,7 +517,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
517517
// notifications or server->client requests made in the course of handling.
518518
// Update accounting for this incoming payload.
519519
if len(requests) > 0 {
520-
stream = newStream(StreamID(c.lastStreamID.Add(1)), c.jsonResponse)
520+
stream = newStream(StreamID(randText()), c.jsonResponse)
521521
c.mu.Lock()
522522
c.streams[stream.id] = stream
523523
stream.requests = requests
@@ -716,7 +716,7 @@ func (c *streamableServerConn) messages(ctx context.Context, stream *stream, per
716716
//
717717
// See also [parseEventID].
718718
func formatEventID(sid StreamID, idx int) string {
719-
return fmt.Sprintf("%d_%d", sid, idx)
719+
return fmt.Sprintf("%s_%d", sid, idx)
720720
}
721721

722722
// parseEventID parses a Last-Event-ID value into a logical stream id and
@@ -726,15 +726,12 @@ func formatEventID(sid StreamID, idx int) string {
726726
func parseEventID(eventID string) (sid StreamID, idx int, ok bool) {
727727
parts := strings.Split(eventID, "_")
728728
if len(parts) != 2 {
729-
return 0, 0, false
729+
return "", 0, false
730730
}
731-
stream, err := strconv.ParseInt(parts[0], 10, 64)
732-
if err != nil || stream < 0 {
733-
return 0, 0, false
734-
}
735-
idx, err = strconv.Atoi(parts[1])
731+
stream := StreamID(parts[0])
732+
idx, err := strconv.Atoi(parts[1])
736733
if err != nil || idx < 0 {
737-
return 0, 0, false
734+
return "", 0, false
738735
}
739736
return StreamID(stream), idx, true
740737
}
@@ -775,7 +772,7 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e
775772
// Find the logical connection corresponding to this request.
776773
//
777774
// For messages sent outside of a request context, this is the default
778-
// connection 0.
775+
// connection "".
779776
var forStream StreamID
780777
if forRequest.IsValid() {
781778
c.mu.Lock()
@@ -796,7 +793,7 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e
796793

797794
stream := c.streams[forStream]
798795
if stream == nil {
799-
return fmt.Errorf("no stream with ID %d", forStream)
796+
return fmt.Errorf("no stream with ID %s", forStream)
800797
}
801798

802799
// Special case a few conditions where we fall back on stream 0 (the hanging GET):
@@ -806,11 +803,11 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e
806803
//
807804
// TODO(rfindley): either of these, particularly the first, might be
808805
// considered a bug in the server. Report it through a side-channel?
809-
if len(stream.requests) == 0 && forStream != 0 || stream.jsonResponse && !isResponse {
810-
stream = c.streams[0]
806+
if len(stream.requests) == 0 && forStream != "" || stream.jsonResponse && !isResponse {
807+
stream = c.streams[""]
811808
}
812809

813-
// TODO: if there is nothing to send these messages to (as would happen, for example, if forConn == 0
810+
// TODO: if there is nothing to send these messages to (as would happen, for example, if forConn == ""
814811
// and the client never did a GET), then memory will grow without bound. Consider a mitigation.
815812
stream.outgoing = append(stream.outgoing, data)
816813
if isResponse {

mcp/streamable_test.go

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -784,22 +784,23 @@ func TestEventID(t *testing.T) {
784784
sid StreamID
785785
idx int
786786
}{
787-
{0, 0},
788-
{0, 1},
789-
{1, 0},
790-
{1, 1},
791-
{1234, 5678},
787+
{"0", 0},
788+
{"0", 1},
789+
{"1", 0},
790+
{"1", 1},
791+
{"", 1},
792+
{"1234", 5678},
792793
}
793794

794795
for _, test := range tests {
795-
t.Run(fmt.Sprintf("%d_%d", test.sid, test.idx), func(t *testing.T) {
796+
t.Run(fmt.Sprintf("%s_%d", test.sid, test.idx), func(t *testing.T) {
796797
eventID := formatEventID(test.sid, test.idx)
797798
gotSID, gotIdx, ok := parseEventID(eventID)
798799
if !ok {
799800
t.Fatalf("parseEventID(%q) failed, want ok", eventID)
800801
}
801802
if gotSID != test.sid || gotIdx != test.idx {
802-
t.Errorf("parseEventID(%q) = %d, %d, want %d, %d", eventID, gotSID, gotIdx, test.sid, test.idx)
803+
t.Errorf("parseEventID(%q) = %s, %d, want %s, %d", eventID, gotSID, gotIdx, test.sid, test.idx)
803804
}
804805
})
805806
}
@@ -808,10 +809,7 @@ func TestEventID(t *testing.T) {
808809
"",
809810
"_",
810811
"1_",
811-
"_1",
812-
"a_1",
813812
"1_a",
814-
"-1_1",
815813
"1_-1",
816814
}
817815

0 commit comments

Comments
 (0)