diff --git a/mcp/transport.go b/mcp/transport.go index 608247cd..024863de 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -283,6 +283,8 @@ func (r rwc) Close() error { // // See [msgBatch] for more discussion of message batching. type ioConn struct { + protocolVersion string // negotiated version, set during session initialization. + writeMu sync.Mutex // guards Write, which must be concurrency safe. rwc io.ReadWriteCloser // the underlying stream @@ -360,6 +362,17 @@ func newIOConn(rwc io.ReadWriteCloser) *ioConn { func (c *ioConn) SessionID() string { return "" } +func (c *ioConn) sessionUpdated(state ServerSessionState) { + protocolVersion := "" + if state.InitializeParams != nil { + protocolVersion = state.InitializeParams.ProtocolVersion + } + if protocolVersion == "" { + protocolVersion = protocolVersion20250326 + } + c.protocolVersion = negotiatedVersion(protocolVersion) +} + // addBatch records a msgBatch for an incoming batch payload. // It returns an error if batch is malformed, containing previously seen IDs. // @@ -458,6 +471,10 @@ func (t *ioConn) Read(ctx context.Context) (jsonrpc.Message, error) { if err != nil { return nil, err } + if batch && t.protocolVersion >= protocolVersion20250618 { + return nil, fmt.Errorf("JSON-RPC batching is not supported in %s and later (request version: %s)", protocolVersion20250618, t.protocolVersion) + } + t.queue = msgs[1:] if batch { diff --git a/mcp/transport_test.go b/mcp/transport_test.go index 18a326e8..d40ce10f 100644 --- a/mcp/transport_test.go +++ b/mcp/transport_test.go @@ -55,17 +55,16 @@ func TestBatchFraming(t *testing.T) { func TestIOConnRead(t *testing.T) { tests := []struct { - name string - input string - want string + name string + input string + want string + protocolVersion string }{ - { name: "valid json input", input: `{"jsonrpc":"2.0","id":1,"method":"test","params":{}}`, want: "", }, - { name: "newline at the end of first valid json input", input: `{"jsonrpc":"2.0","id":1,"method":"test","params":{}} @@ -77,13 +76,41 @@ func TestIOConnRead(t *testing.T) { input: `{"jsonrpc":"2.0","id":1,"method":"test","params":{}},`, want: "invalid trailing data at the end of stream", }, + { + name: "batching unknown protocol", + input: `[{"jsonrpc":"2.0","id":1,"method":"test1"},{"jsonrpc":"2.0","id":2,"method":"test2"}]`, + want: "", + protocolVersion: "", + }, + { + name: "batching old protocol", + input: `[{"jsonrpc":"2.0","id":1,"method":"test1"},{"jsonrpc":"2.0","id":2,"method":"test2"}]`, + want: "", + protocolVersion: protocolVersion20241105, + }, + { + name: "batching new protocol", + input: `[{"jsonrpc":"2.0","id":1,"method":"test1"},{"jsonrpc":"2.0","id":2,"method":"test2"}]`, + want: "JSON-RPC batching is not supported in 2025-06-18 and later (request version: 2025-06-18)", + protocolVersion: protocolVersion20250618, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tr := newIOConn(rwc{ rc: io.NopCloser(strings.NewReader(tt.input)), }) + if tt.protocolVersion != "" { + tr.sessionUpdated(ServerSessionState{ + InitializeParams: &InitializeParams{ + ProtocolVersion: tt.protocolVersion, + }, + }) + } _, err := tr.Read(context.Background()) + if err == nil && tt.want != "" { + t.Errorf("ioConn.Read() got nil error but wanted %v", tt.want) + } if err != nil && err.Error() != tt.want { t.Errorf("ioConn.Read() = %v, want %v", err.Error(), tt.want) }