Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
if protocolVersion == "" {
protocolVersion = protocolVersion20250326
}
protocolVersion = negotiatedVersion(protocolVersion)

if isBatch && protocolVersion >= protocolVersion20250618 {
http.Error(w, fmt.Sprintf("JSON-RPC batching is not supported in %s and later (request version: %s)", protocolVersion20250618, protocolVersion), http.StatusBadRequest)
Expand Down
17 changes: 17 additions & 0 deletions mcp/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,8 @@ func (r rwc) Close() error {
//
// See [msgBatch] for more discussion of message batching.
type ioConn struct {
protocolVersion string

writeMu sync.Mutex // guards Write, which must be concurrency safe.
rwc io.ReadWriteCloser // the underlying stream

Expand Down Expand Up @@ -360,6 +362,17 @@ func newIOConn(rwc io.ReadWriteCloser) *ioConn {

func (c *ioConn) SessionID() string { return "" }

func (c *ioConn) sessionUpdated(state ServerSessionState) {
protocolVersion := ""
if state.InitializeParams != nil {
protocolVersion = state.InitializeParams.ProtocolVersion
}
if protocolVersion == "" {
protocolVersion = protocolVersion20250326
}
c.protocolVersion = negotiatedVersion(protocolVersion)
}

// addBatch records a msgBatch for an incoming batch payload.
// It returns an error if batch is malformed, containing previously seen IDs.
//
Expand Down Expand Up @@ -458,6 +471,10 @@ func (t *ioConn) Read(ctx context.Context) (jsonrpc.Message, error) {
if err != nil {
return nil, err
}
if batch && t.protocolVersion >= protocolVersion20250618 {
return nil, fmt.Errorf("JSON-RPC batching is not supported in %s and later (request version: %s)", protocolVersion20250618, t.protocolVersion)
}

t.queue = msgs[1:]

if batch {
Expand Down
37 changes: 32 additions & 5 deletions mcp/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,16 @@ func TestBatchFraming(t *testing.T) {

func TestIOConnRead(t *testing.T) {
tests := []struct {
name string
input string
want string
name string
input string
want string
protocolVersion string
}{

{
name: "valid json input",
input: `{"jsonrpc":"2.0","id":1,"method":"test","params":{}}`,
want: "",
},

{
name: "newline at the end of first valid json input",
input: `{"jsonrpc":"2.0","id":1,"method":"test","params":{}}
Expand All @@ -77,13 +76,41 @@ func TestIOConnRead(t *testing.T) {
input: `{"jsonrpc":"2.0","id":1,"method":"test","params":{}},`,
want: "invalid trailing data at the end of stream",
},
{
name: "batching unknown protocol",
input: `[{"jsonrpc":"2.0","id":1,"method":"test1"},{"jsonrpc":"2.0","id":2,"method":"test2"}]`,
want: "",
protocolVersion: "",
},
{
name: "batching old protocol",
input: `[{"jsonrpc":"2.0","id":1,"method":"test1"},{"jsonrpc":"2.0","id":2,"method":"test2"}]`,
want: "",
protocolVersion: protocolVersion20241105,
},
{
name: "batching new protocol",
input: `[{"jsonrpc":"2.0","id":1,"method":"test1"},{"jsonrpc":"2.0","id":2,"method":"test2"}]`,
want: "JSON-RPC batching is not supported in 2025-06-18 and later (request version: 2025-06-18)",
protocolVersion: protocolVersion20250618,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tr := newIOConn(rwc{
rc: io.NopCloser(strings.NewReader(tt.input)),
})
if tt.protocolVersion != "" {
tr.sessionUpdated(ServerSessionState{
InitializeParams: &InitializeParams{
ProtocolVersion: tt.protocolVersion,
},
})
}
_, err := tr.Read(context.Background())
if err == nil && tt.want != "" {
t.Errorf("ioConn.Read() got nil error but wanted %v", tt.want)
}
if err != nil && err.Error() != tt.want {
t.Errorf("ioConn.Read() = %v, want %v", err.Error(), tt.want)
}
Expand Down