Skip to content

Commit 2ad24b9

Browse files
committed
mcp: changed streamID's from int64 to random strings
Changed StreamID's to store randomly generated strings. Created a defaultStreamID field on the StreamableServerTransport struct which is generated and saved on instansiation. Updated all tests. Fixes #259
1 parent 8186bf3 commit 2ad24b9

File tree

4 files changed

+82
-79
lines changed

4 files changed

+82
-79
lines changed

mcp/event.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ func (s *MemoryEventStore) debugString() string {
392392
fmt.Fprintf(&b, "; ")
393393
}
394394
dl := sm[sid]
395-
fmt.Fprintf(&b, "%s %d first=%d", sess, sid, dl.first)
395+
fmt.Fprintf(&b, "%s %s first=%d", sess, sid, dl.first)
396396
for _, d := range dl.data {
397397
fmt.Fprintf(&b, " %s", d)
398398
}

mcp/event_test.go

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -119,21 +119,21 @@ func TestMemoryEventStoreState(t *testing.T) {
119119
{
120120
"appends",
121121
func(s *MemoryEventStore) {
122-
appendEvent(s, "S1", 1, "d1")
123-
appendEvent(s, "S1", 2, "d2")
124-
appendEvent(s, "S1", 1, "d3")
125-
appendEvent(s, "S2", 8, "d4")
122+
appendEvent(s, "S1", "1", "d1")
123+
appendEvent(s, "S1", "2", "d2")
124+
appendEvent(s, "S1", "1", "d3")
125+
appendEvent(s, "S2", "8", "d4")
126126
},
127127
"S1 1 first=0 d1 d3; S1 2 first=0 d2; S2 8 first=0 d4",
128128
8,
129129
},
130130
{
131131
"session close",
132132
func(s *MemoryEventStore) {
133-
appendEvent(s, "S1", 1, "d1")
134-
appendEvent(s, "S1", 2, "d2")
135-
appendEvent(s, "S1", 1, "d3")
136-
appendEvent(s, "S2", 8, "d4")
133+
appendEvent(s, "S1", "1", "d1")
134+
appendEvent(s, "S1", "2", "d2")
135+
appendEvent(s, "S1", "1", "d3")
136+
appendEvent(s, "S2", "8", "d4")
137137
s.SessionClosed(ctx, "S1")
138138
},
139139
"S2 8 first=0 d4",
@@ -142,10 +142,10 @@ func TestMemoryEventStoreState(t *testing.T) {
142142
{
143143
"purge",
144144
func(s *MemoryEventStore) {
145-
appendEvent(s, "S1", 1, "d1")
146-
appendEvent(s, "S1", 2, "d2")
147-
appendEvent(s, "S1", 1, "d3")
148-
appendEvent(s, "S2", 8, "d4")
145+
appendEvent(s, "S1", "1", "d1")
146+
appendEvent(s, "S1", "2", "d2")
147+
appendEvent(s, "S1", "1", "d3")
148+
appendEvent(s, "S2", "8", "d4")
149149
// We are using 8 bytes (d1,d2, d3, d4).
150150
// To purge 6, we remove the first of each stream, leaving only d3.
151151
s.SetMaxBytes(2)
@@ -157,31 +157,31 @@ func TestMemoryEventStoreState(t *testing.T) {
157157
{
158158
"purge append",
159159
func(s *MemoryEventStore) {
160-
appendEvent(s, "S1", 1, "d1")
161-
appendEvent(s, "S1", 2, "d2")
162-
appendEvent(s, "S1", 1, "d3")
163-
appendEvent(s, "S2", 8, "d4")
160+
appendEvent(s, "S1", "1", "d1")
161+
appendEvent(s, "S1", "2", "d2")
162+
appendEvent(s, "S1", "1", "d3")
163+
appendEvent(s, "S2", "8", "d4")
164164
s.SetMaxBytes(2)
165165
// Up to here, identical to the "purge" case.
166166
// Each of these additions will result in a purge.
167-
appendEvent(s, "S1", 2, "d5") // remove d3
168-
appendEvent(s, "S1", 2, "d6") // remove d5
167+
appendEvent(s, "S1", "2", "d5") // remove d3
168+
appendEvent(s, "S1", "2", "d6") // remove d5
169169
},
170170
"S1 1 first=2; S1 2 first=2 d6; S2 8 first=1",
171171
2,
172172
},
173173
{
174174
"purge resize append",
175175
func(s *MemoryEventStore) {
176-
appendEvent(s, "S1", 1, "d1")
177-
appendEvent(s, "S1", 2, "d2")
178-
appendEvent(s, "S1", 1, "d3")
179-
appendEvent(s, "S2", 8, "d4")
176+
appendEvent(s, "S1", "1", "d1")
177+
appendEvent(s, "S1", "2", "d2")
178+
appendEvent(s, "S1", "1", "d3")
179+
appendEvent(s, "S2", "8", "d4")
180180
s.SetMaxBytes(2)
181181
// Up to here, identical to the "purge" case.
182182
s.SetMaxBytes(6) // make room
183-
appendEvent(s, "S1", 2, "d5")
184-
appendEvent(s, "S1", 2, "d6")
183+
appendEvent(s, "S1", "2", "d5")
184+
appendEvent(s, "S1", "2", "d6")
185185
},
186186
// The other streams remain, because we may add to them.
187187
"S1 1 first=1 d3; S1 2 first=1 d5 d6; S2 8 first=1",
@@ -206,10 +206,10 @@ func TestMemoryEventStoreAfter(t *testing.T) {
206206
ctx := context.Background()
207207
s := NewMemoryEventStore(nil)
208208
s.SetMaxBytes(4)
209-
s.Append(ctx, "S1", 1, []byte("d1"))
210-
s.Append(ctx, "S1", 1, []byte("d2"))
211-
s.Append(ctx, "S1", 1, []byte("d3"))
212-
s.Append(ctx, "S1", 2, []byte("d4")) // will purge d1
209+
s.Append(ctx, "S1", "1", []byte("d1"))
210+
s.Append(ctx, "S1", "1", []byte("d2"))
211+
s.Append(ctx, "S1", "1", []byte("d3"))
212+
s.Append(ctx, "S1", "2", []byte("d4")) // will purge d1
213213
want := "S1 1 first=1 d2 d3; S1 2 first=0 d4"
214214
if got := s.debugString(); got != want {
215215
t.Fatalf("got state %q, want %q", got, want)
@@ -222,14 +222,14 @@ func TestMemoryEventStoreAfter(t *testing.T) {
222222
want []string
223223
wantErr string // if non-empty, error should contain this string
224224
}{
225-
{"S1", 1, 0, []string{"d2", "d3"}, ""},
226-
{"S1", 1, 1, []string{"d3"}, ""},
227-
{"S1", 1, 2, nil, ""},
228-
{"S1", 2, 0, nil, ""},
229-
{"S1", 3, 0, nil, "unknown stream ID"},
230-
{"S2", 0, 0, nil, "unknown session ID"},
225+
{"S1", "1", 0, []string{"d2", "d3"}, ""},
226+
{"S1", "1", 1, []string{"d3"}, ""},
227+
{"S1", "1", 2, nil, ""},
228+
{"S1", "2", 0, nil, ""},
229+
{"S1", "3", 0, nil, "unknown stream ID"},
230+
{"S2", "0", 0, nil, "unknown session ID"},
231231
} {
232-
t.Run(fmt.Sprintf("%s-%d-%d", tt.sessionID, tt.streamID, tt.index), func(t *testing.T) {
232+
t.Run(fmt.Sprintf("%s-%s-%d", tt.sessionID, tt.streamID, tt.index), func(t *testing.T) {
233233
var got []string
234234
for d, err := range s.After(ctx, tt.sessionID, tt.streamID, tt.index) {
235235
if err != nil {

mcp/streamable.go

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -176,13 +176,14 @@ func NewStreamableServerTransport(sessionID string, opts *StreamableServerTransp
176176
opts = &StreamableServerTransportOptions{}
177177
}
178178
t := &StreamableServerTransport{
179-
sessionID: sessionID,
180-
incoming: make(chan jsonrpc.Message, 10),
181-
done: make(chan struct{}),
182-
streams: make(map[StreamID]*stream),
183-
requestStreams: make(map[jsonrpc.ID]StreamID),
184-
}
185-
t.streams[0] = newStream(0)
179+
sessionID: sessionID,
180+
incoming: make(chan jsonrpc.Message, 10),
181+
done: make(chan struct{}),
182+
streams: make(map[StreamID]*stream),
183+
requestStreams: make(map[jsonrpc.ID]StreamID),
184+
defaultStreamID: StreamID(randText()),
185+
}
186+
t.streams[t.defaultStreamID] = newStream(t.defaultStreamID)
186187
if opts != nil {
187188
t.opts = *opts
188189
}
@@ -199,12 +200,11 @@ func (t *StreamableServerTransport) SessionID() string {
199200
// A StreamableServerTransport implements the [Transport] interface for a
200201
// single session.
201202
type StreamableServerTransport struct {
202-
nextStreamID atomic.Int64 // incrementing next stream ID
203-
204-
sessionID string
205-
opts StreamableServerTransportOptions
206-
incoming chan jsonrpc.Message // messages from the client to the server
207-
done chan struct{}
203+
sessionID string
204+
defaultStreamID StreamID
205+
opts StreamableServerTransportOptions
206+
incoming chan jsonrpc.Message // messages from the client to the server
207+
done chan struct{}
208208

209209
mu sync.Mutex
210210
// Sessions are closed exactly once.
@@ -242,7 +242,7 @@ type StreamableServerTransport struct {
242242
// at any time.
243243
type stream struct {
244244
// id is the logical ID for the stream, unique within a session.
245-
// ID 0 is used for messages that don't correlate with an incoming request.
245+
// defaultStreamID is used for messages that don't correlate with an incoming request.
246246
id StreamID
247247

248248
// signal is a 1-buffered channel, owned by an incoming HTTP request, that signals
@@ -283,9 +283,10 @@ func signalChanPtr() *chan struct{} {
283283
return &c
284284
}
285285

286-
// A StreamID identifies a stream of SSE events. It is unique within the stream's
286+
// A StreamID identifies a stream of SSE events. It is a random string
287+
// 26 characters in length
287288
// [ServerSession].
288-
type StreamID int64
289+
type StreamID string
289290

290291
// Connect implements the [Transport] interface.
291292
//
@@ -338,8 +339,8 @@ func (t *StreamableServerTransport) ServeHTTP(w http.ResponseWriter, req *http.R
338339
}
339340

340341
func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Request) (int, string) {
341-
// connID 0 corresponds to the default GET request.
342-
id := StreamID(0)
342+
// connID = t.defaultStreamID corresponds to the default GET request.
343+
id := t.defaultStreamID
343344
// By default, we haven't seen a last index. Since indices start at 0, we represent
344345
// that by -1. This is incremented just before each event is written, in streamResponse
345346
// around L407.
@@ -399,7 +400,7 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R
399400
}
400401

401402
// Update accounting for this request.
402-
stream := newStream(StreamID(t.nextStreamID.Add(1)))
403+
stream := newStream(StreamID(randText()))
403404
t.mu.Lock()
404405
t.streams[stream.id] = stream
405406
if len(requests) > 0 {
@@ -533,7 +534,7 @@ stream:
533534
//
534535
// See also [parseEventID].
535536
func formatEventID(sid StreamID, idx int) string {
536-
return fmt.Sprintf("%d_%d", sid, idx)
537+
return fmt.Sprintf("%s_%d", sid, idx)
537538
}
538539

539540
// parseEventID parses a Last-Event-ID value into a logical stream id and
@@ -543,15 +544,15 @@ func formatEventID(sid StreamID, idx int) string {
543544
func parseEventID(eventID string) (sid StreamID, idx int, ok bool) {
544545
parts := strings.Split(eventID, "_")
545546
if len(parts) != 2 {
546-
return 0, 0, false
547+
return "", 0, false
547548
}
548-
stream, err := strconv.ParseInt(parts[0], 10, 64)
549-
if err != nil || stream < 0 {
550-
return 0, 0, false
549+
stream := StreamID(parts[0])
550+
if len(stream) == 0 {
551+
return "", 0, false
551552
}
552-
idx, err = strconv.Atoi(parts[1])
553+
idx, err := strconv.Atoi(parts[1])
553554
if err != nil || idx < 0 {
554-
return 0, 0, false
555+
return "", 0, false
555556
}
556557
return StreamID(stream), idx, true
557558
}
@@ -592,7 +593,7 @@ func (t *StreamableServerTransport) Write(ctx context.Context, msg jsonrpc.Messa
592593
// Find the logical connection corresponding to this request.
593594
//
594595
// For messages sent outside of a request context, this is the default
595-
// connection 0.
596+
// connection.
596597
var forConn StreamID
597598
if forRequest.IsValid() {
598599
t.mu.Lock()
@@ -612,18 +613,21 @@ func (t *StreamableServerTransport) Write(ctx context.Context, msg jsonrpc.Messa
612613
}
613614

614615
stream := t.streams[forConn]
616+
if forConn == "" {
617+
stream = t.streams[t.defaultStreamID]
618+
}
615619
if stream == nil {
616-
return fmt.Errorf("no stream with ID %d", forConn)
620+
return fmt.Errorf("no stream with ID %s", forConn)
617621
}
618-
if len(stream.requests) == 0 && forConn != 0 {
622+
if len(stream.requests) == 0 && forConn != "" {
619623
// No outstanding requests for this connection, which means it is logically
620624
// done. This is a sequencing violation from the server, so we should report
621625
// a side-channel error here. Put the message on the general queue to avoid
622626
// dropping messages.
623-
stream = t.streams[0]
627+
stream = t.streams[t.defaultStreamID]
624628
}
625629

626-
// TODO: if there is nothing to send these messages to (as would happen, for example, if forConn == 0
630+
// TODO: if there is nothing to send these messages to (as would happen, for example, if forConn == ""
627631
// and the client never did a GET), then memory will grow without bound. Consider a mitigation.
628632
stream.outgoing = append(stream.outgoing, data)
629633
if isResponse {

mcp/streamable_test.go

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -222,9 +222,10 @@ func TestServerInitiatedSSE(t *testing.T) {
222222

223223
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
224224
defer cancel()
225-
client := NewClient(testImpl, &ClientOptions{ToolListChangedHandler: func(ctx context.Context, cc *ClientSession, params *ToolListChangedParams) {
226-
notifications <- "toolListChanged"
227-
},
225+
client := NewClient(testImpl, &ClientOptions{
226+
ToolListChangedHandler: func(ctx context.Context, cc *ClientSession, params *ToolListChangedParams) {
227+
notifications <- "toolListChanged"
228+
},
228229
})
229230
clientSession, err := client.Connect(ctx, NewStreamableClientTransport(httpServer.URL, nil))
230231
if err != nil {
@@ -768,22 +769,22 @@ func TestEventID(t *testing.T) {
768769
sid StreamID
769770
idx int
770771
}{
771-
{0, 0},
772-
{0, 1},
773-
{1, 0},
774-
{1, 1},
775-
{1234, 5678},
772+
{"0", 0},
773+
{"0", 1},
774+
{"1", 0},
775+
{"1", 1},
776+
{"1234", 5678},
776777
}
777778

778779
for _, test := range tests {
779-
t.Run(fmt.Sprintf("%d_%d", test.sid, test.idx), func(t *testing.T) {
780+
t.Run(fmt.Sprintf("%s_%d", test.sid, test.idx), func(t *testing.T) {
780781
eventID := formatEventID(test.sid, test.idx)
781782
gotSID, gotIdx, ok := parseEventID(eventID)
782783
if !ok {
783784
t.Fatalf("parseEventID(%q) failed, want ok", eventID)
784785
}
785786
if gotSID != test.sid || gotIdx != test.idx {
786-
t.Errorf("parseEventID(%q) = %d, %d, want %d, %d", eventID, gotSID, gotIdx, test.sid, test.idx)
787+
t.Errorf("parseEventID(%q) = %s, %d, want %s, %d", eventID, gotSID, gotIdx, test.sid, test.idx)
787788
}
788789
})
789790
}
@@ -793,9 +794,7 @@ func TestEventID(t *testing.T) {
793794
"_",
794795
"1_",
795796
"_1",
796-
"a_1",
797797
"1_a",
798-
"-1_1",
799798
"1_-1",
800799
}
801800

0 commit comments

Comments
 (0)