Skip to content

Commit b666793

Browse files
transport.go: disable stdio batching for newer protocols (#453)
This PR disables batching support for stdio by storing the negotiated protocol version and comparing it to protocolVersion20250618. Fixes: #21
1 parent bf3ff50 commit b666793

File tree

2 files changed

+49
-5
lines changed

2 files changed

+49
-5
lines changed

mcp/transport.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,8 @@ func (r rwc) Close() error {
283283
//
284284
// See [msgBatch] for more discussion of message batching.
285285
type ioConn struct {
286+
protocolVersion string // negotiated version, set during session initialization.
287+
286288
writeMu sync.Mutex // guards Write, which must be concurrency safe.
287289
rwc io.ReadWriteCloser // the underlying stream
288290

@@ -360,6 +362,17 @@ func newIOConn(rwc io.ReadWriteCloser) *ioConn {
360362

361363
func (c *ioConn) SessionID() string { return "" }
362364

365+
func (c *ioConn) sessionUpdated(state ServerSessionState) {
366+
protocolVersion := ""
367+
if state.InitializeParams != nil {
368+
protocolVersion = state.InitializeParams.ProtocolVersion
369+
}
370+
if protocolVersion == "" {
371+
protocolVersion = protocolVersion20250326
372+
}
373+
c.protocolVersion = negotiatedVersion(protocolVersion)
374+
}
375+
363376
// addBatch records a msgBatch for an incoming batch payload.
364377
// It returns an error if batch is malformed, containing previously seen IDs.
365378
//
@@ -458,6 +471,10 @@ func (t *ioConn) Read(ctx context.Context) (jsonrpc.Message, error) {
458471
if err != nil {
459472
return nil, err
460473
}
474+
if batch && t.protocolVersion >= protocolVersion20250618 {
475+
return nil, fmt.Errorf("JSON-RPC batching is not supported in %s and later (request version: %s)", protocolVersion20250618, t.protocolVersion)
476+
}
477+
461478
t.queue = msgs[1:]
462479

463480
if batch {

mcp/transport_test.go

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,17 +55,16 @@ func TestBatchFraming(t *testing.T) {
5555

5656
func TestIOConnRead(t *testing.T) {
5757
tests := []struct {
58-
name string
59-
input string
60-
want string
58+
name string
59+
input string
60+
want string
61+
protocolVersion string
6162
}{
62-
6363
{
6464
name: "valid json input",
6565
input: `{"jsonrpc":"2.0","id":1,"method":"test","params":{}}`,
6666
want: "",
6767
},
68-
6968
{
7069
name: "newline at the end of first valid json input",
7170
input: `{"jsonrpc":"2.0","id":1,"method":"test","params":{}}
@@ -77,13 +76,41 @@ func TestIOConnRead(t *testing.T) {
7776
input: `{"jsonrpc":"2.0","id":1,"method":"test","params":{}},`,
7877
want: "invalid trailing data at the end of stream",
7978
},
79+
{
80+
name: "batching unknown protocol",
81+
input: `[{"jsonrpc":"2.0","id":1,"method":"test1"},{"jsonrpc":"2.0","id":2,"method":"test2"}]`,
82+
want: "",
83+
protocolVersion: "",
84+
},
85+
{
86+
name: "batching old protocol",
87+
input: `[{"jsonrpc":"2.0","id":1,"method":"test1"},{"jsonrpc":"2.0","id":2,"method":"test2"}]`,
88+
want: "",
89+
protocolVersion: protocolVersion20241105,
90+
},
91+
{
92+
name: "batching new protocol",
93+
input: `[{"jsonrpc":"2.0","id":1,"method":"test1"},{"jsonrpc":"2.0","id":2,"method":"test2"}]`,
94+
want: "JSON-RPC batching is not supported in 2025-06-18 and later (request version: 2025-06-18)",
95+
protocolVersion: protocolVersion20250618,
96+
},
8097
}
8198
for _, tt := range tests {
8299
t.Run(tt.name, func(t *testing.T) {
83100
tr := newIOConn(rwc{
84101
rc: io.NopCloser(strings.NewReader(tt.input)),
85102
})
103+
if tt.protocolVersion != "" {
104+
tr.sessionUpdated(ServerSessionState{
105+
InitializeParams: &InitializeParams{
106+
ProtocolVersion: tt.protocolVersion,
107+
},
108+
})
109+
}
86110
_, err := tr.Read(context.Background())
111+
if err == nil && tt.want != "" {
112+
t.Errorf("ioConn.Read() got nil error but wanted %v", tt.want)
113+
}
87114
if err != nil && err.Error() != tt.want {
88115
t.Errorf("ioConn.Read() = %v, want %v", err.Error(), tt.want)
89116
}

0 commit comments

Comments
 (0)