diff --git a/jsonrpc/jsonrpc.go b/jsonrpc/jsonrpc.go new file mode 100644 index 00000000..f175e597 --- /dev/null +++ b/jsonrpc/jsonrpc.go @@ -0,0 +1,20 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// Package jsonrpc exposes part of a JSON-RPC v2 implementation +// for use by mcp transport authors. +package jsonrpc + +import "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + +type ( + // ID is a JSON-RPC request ID. + ID = jsonrpc2.ID + // Message is a JSON-RPC message. + Message = jsonrpc2.Message + // Request is a JSON-RPC request. + Request = jsonrpc2.Request + // Response is a JSON-RPC response. + Response = jsonrpc2.Response +) diff --git a/mcp/client.go b/mcp/client.go index 512be2cb..40d3c792 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -13,6 +13,7 @@ import ( "time" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) // A Client is an MCP client, which may be connected to an MCP server @@ -301,7 +302,7 @@ func (cs *ClientSession) receivingMethodInfos() map[string]methodInfo { return clientMethodInfos } -func (cs *ClientSession) handle(ctx context.Context, req *JSONRPCRequest) (any, error) { +func (cs *ClientSession) handle(ctx context.Context, req *jsonrpc.Request) (any, error) { return handleReceive(ctx, cs, req) } diff --git a/mcp/server.go b/mcp/server.go index de14ca06..f9b76539 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -21,6 +21,7 @@ import ( "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/internal/util" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) const DefaultPageSize = 1000 @@ -610,7 +611,7 @@ func (ss *ServerSession) receivingMethodHandler() methodHandler { func (ss *ServerSession) getConn() *jsonrpc2.Connection { return ss.conn } // handle invokes the method described by the given JSON RPC request. -func (ss *ServerSession) handle(ctx context.Context, req *JSONRPCRequest) (any, error) { +func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, error) { ss.mu.Lock() initialized := ss.initialized ss.mu.Unlock() diff --git a/mcp/shared.go b/mcp/shared.go index 8a38777e..fef20946 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -20,6 +20,7 @@ import ( "time" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) // latestProtocolVersion is the latest protocol version that this version of the SDK supports. @@ -121,7 +122,7 @@ func defaultReceivingMethodHandler[S Session](ctx context.Context, session S, me return info.handleMethod.(MethodHandler[S])(ctx, session, method, params) } -func handleReceive[S Session](ctx context.Context, session S, req *JSONRPCRequest) (Result, error) { +func handleReceive[S Session](ctx context.Context, session S, req *jsonrpc.Request) (Result, error) { info, ok := session.receivingMethodInfos()[req.Method] if !ok { return nil, jsonrpc2.ErrNotHandled diff --git a/mcp/sse.go b/mcp/sse.go index d1b52599..f0d7b34c 100644 --- a/mcp/sse.go +++ b/mcp/sse.go @@ -18,6 +18,7 @@ import ( "sync" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) // This file implements support for SSE (HTTP with server-sent events) @@ -111,7 +112,7 @@ func NewSSEHandler(getServer func(request *http.Request) *Server) *SSEHandler { // - Close terminates the hanging GET. type SSEServerTransport struct { endpoint string - incoming chan JSONRPCMessage // queue of incoming messages; never closed + incoming chan jsonrpc.Message // queue of incoming messages; never closed // We must guard both pushes to the incoming queue and writes to the response // writer, because incoming POST requests are arbitrarily concurrent and we @@ -138,7 +139,7 @@ func NewSSEServerTransport(endpoint string, w http.ResponseWriter) *SSEServerTra return &SSEServerTransport{ endpoint: endpoint, w: w, - incoming: make(chan JSONRPCMessage, 100), + incoming: make(chan jsonrpc.Message, 100), done: make(chan struct{}), } } @@ -267,7 +268,7 @@ type sseServerConn struct { func (s sseServerConn) SessionID() string { return "" } // Read implements jsonrpc2.Reader. -func (s sseServerConn) Read(ctx context.Context) (JSONRPCMessage, error) { +func (s sseServerConn) Read(ctx context.Context) (jsonrpc.Message, error) { select { case <-ctx.Done(): return nil, ctx.Err() @@ -279,7 +280,7 @@ func (s sseServerConn) Read(ctx context.Context) (JSONRPCMessage, error) { } // Write implements jsonrpc2.Writer. -func (s sseServerConn) Write(ctx context.Context, msg JSONRPCMessage) error { +func (s sseServerConn) Write(ctx context.Context, msg jsonrpc.Message) error { if ctx.Err() != nil { return ctx.Err() } @@ -532,7 +533,7 @@ func (c *sseClientConn) isDone() bool { return c.closed } -func (c *sseClientConn) Read(ctx context.Context) (JSONRPCMessage, error) { +func (c *sseClientConn) Read(ctx context.Context) (jsonrpc.Message, error) { select { case <-ctx.Done(): return nil, ctx.Err() @@ -553,7 +554,7 @@ func (c *sseClientConn) Read(ctx context.Context) (JSONRPCMessage, error) { } } -func (c *sseClientConn) Write(ctx context.Context, msg JSONRPCMessage) error { +func (c *sseClientConn) Write(ctx context.Context, msg jsonrpc.Message) error { data, err := jsonrpc2.EncodeMessage(msg) if err != nil { return err diff --git a/mcp/streamable.go b/mcp/streamable.go index 52208948..11d70a38 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -16,6 +16,7 @@ import ( "sync/atomic" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) const ( @@ -157,12 +158,12 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque func NewStreamableServerTransport(sessionID string) *StreamableServerTransport { return &StreamableServerTransport{ id: sessionID, - incoming: make(chan JSONRPCMessage, 10), + incoming: make(chan jsonrpc.Message, 10), done: make(chan struct{}), outgoingMessages: make(map[streamID][]*streamableMsg), signals: make(map[streamID]chan struct{}), - requestStreams: make(map[JSONRPCID]streamID), - streamRequests: make(map[streamID]map[JSONRPCID]struct{}), + requestStreams: make(map[jsonrpc.ID]streamID), + streamRequests: make(map[streamID]map[jsonrpc.ID]struct{}), } } @@ -176,7 +177,7 @@ type StreamableServerTransport struct { nextStreamID atomic.Int64 // incrementing next stream ID id string - incoming chan JSONRPCMessage // messages from the client to the server + incoming chan jsonrpc.Message // messages from the client to the server mu sync.Mutex @@ -226,7 +227,7 @@ type StreamableServerTransport struct { // Lifecycle: requestStreams persists for the duration of the session. // // TODO(rfindley): clean up once requests are handled. - requestStreams map[JSONRPCID]streamID + requestStreams map[jsonrpc.ID]streamID // streamRequests tracks the set of unanswered incoming RPCs for each logical // stream. @@ -237,7 +238,7 @@ type StreamableServerTransport struct { // Lifecycle: streamRequests values persist as until the requests have been // replied to by the server. Notably, NOT until they are sent to an HTTP // response, as delivery is not guaranteed. - streamRequests map[streamID]map[JSONRPCID]struct{} + streamRequests map[streamID]map[jsonrpc.ID]struct{} } type streamID int64 @@ -271,7 +272,7 @@ func (s *StreamableServerTransport) Connect(context.Context) (Connection, error) // 2. Expose a 'HandlerTransport' interface that allows transports to provide // a handler middleware, so that we don't hard-code this behavior in // ServerSession.handle. -// 3. Add a `func ForRequest(context.Context) JSONRPCID` accessor that lets +// 3. Add a `func ForRequest(context.Context) jsonrpc.ID` accessor that lets // any transport access the incoming request ID. // // For now, by giving only the StreamableServerTransport access to the request @@ -340,9 +341,9 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R http.Error(w, fmt.Sprintf("malformed payload: %v", err), http.StatusBadRequest) return } - requests := make(map[JSONRPCID]struct{}) + requests := make(map[jsonrpc.ID]struct{}) for _, msg := range incoming { - if req, ok := msg.(*JSONRPCRequest); ok && req.ID.IsValid() { + if req, ok := msg.(*jsonrpc.Request); ok && req.ID.IsValid() { requests[req.ID] = struct{}{} } } @@ -352,7 +353,7 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R signal := make(chan struct{}, 1) t.mu.Lock() if len(requests) > 0 { - t.streamRequests[id] = make(map[JSONRPCID]struct{}) + t.streamRequests[id] = make(map[jsonrpc.ID]struct{}) } for reqID := range requests { t.requestStreams[reqID] = id @@ -484,7 +485,7 @@ func parseEventID(eventID string) (sid streamID, idx int, ok bool) { } // Read implements the [Connection] interface. -func (t *StreamableServerTransport) Read(ctx context.Context) (JSONRPCMessage, error) { +func (t *StreamableServerTransport) Read(ctx context.Context) (jsonrpc.Message, error) { select { case <-ctx.Done(): return nil, ctx.Err() @@ -499,10 +500,10 @@ func (t *StreamableServerTransport) Read(ctx context.Context) (JSONRPCMessage, e } // Write implements the [Connection] interface. -func (t *StreamableServerTransport) Write(ctx context.Context, msg JSONRPCMessage) error { +func (t *StreamableServerTransport) Write(ctx context.Context, msg jsonrpc.Message) error { // Find the incoming request that this write relates to, if any. - var forRequest, replyTo JSONRPCID - if resp, ok := msg.(*JSONRPCResponse); ok { + var forRequest, replyTo jsonrpc.ID + if resp, ok := msg.(*jsonrpc.Response); ok { // If the message is a response, it relates to its request (of course). forRequest = resp.ID replyTo = resp.ID @@ -511,7 +512,7 @@ func (t *StreamableServerTransport) Write(ctx context.Context, msg JSONRPCMessag // ongoing request. This may not be the case if the request way made with // an unrelated context. if v := ctx.Value(idContextKey{}); v != nil { - forRequest = v.(JSONRPCID) + forRequest = v.(jsonrpc.ID) } } @@ -661,7 +662,7 @@ func (c *streamableClientConn) SessionID() string { } // Read implements the [Connection] interface. -func (s *streamableClientConn) Read(ctx context.Context) (JSONRPCMessage, error) { +func (s *streamableClientConn) Read(ctx context.Context) (jsonrpc.Message, error) { select { case <-ctx.Done(): return nil, ctx.Err() @@ -673,7 +674,7 @@ func (s *streamableClientConn) Read(ctx context.Context) (JSONRPCMessage, error) } // Write implements the [Connection] interface. -func (s *streamableClientConn) Write(ctx context.Context, msg JSONRPCMessage) error { +func (s *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) error { s.mu.Lock() if s.err != nil { s.mu.Unlock() @@ -709,7 +710,7 @@ func (s *streamableClientConn) Write(ctx context.Context, msg JSONRPCMessage) er return nil } -func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string, msg JSONRPCMessage) (string, error) { +func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string, msg jsonrpc.Message) (string, error) { data, err := jsonrpc2.EncodeMessage(msg) if err != nil { return "", err diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 3329caea..412d2e1d 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -22,6 +22,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) func TestStreamableTransports(t *testing.T) { @@ -126,16 +127,16 @@ func TestStreamableServerTransport(t *testing.T) { // Redundant with OnRequest: all OnRequest steps are asynchronous. Async bool - Method string // HTTP request method - Send []JSONRPCMessage // messages to send - CloseAfter int // if nonzero, close after receiving this many messages - StatusCode int // expected status code - Recv []JSONRPCMessage // expected messages to receive + Method string // HTTP request method + Send []jsonrpc.Message // messages to send + CloseAfter int // if nonzero, close after receiving this many messages + StatusCode int // expected status code + Recv []jsonrpc.Message // expected messages to receive } // JSON-RPC message constructors. - req := func(id int64, method string, params any) *JSONRPCRequest { - r := &JSONRPCRequest{ + req := func(id int64, method string, params any) *jsonrpc.Request { + r := &jsonrpc.Request{ Method: method, Params: mustMarshal(t, params), } @@ -144,8 +145,8 @@ func TestStreamableServerTransport(t *testing.T) { } return r } - resp := func(id int64, result any, err error) *JSONRPCResponse { - return &JSONRPCResponse{ + resp := func(id int64, result any, err error) *jsonrpc.Response { + return &jsonrpc.Response{ ID: jsonrpc2.Int64ID(id), Result: mustMarshal(t, result), Error: err, @@ -168,13 +169,13 @@ func TestStreamableServerTransport(t *testing.T) { initializedMsg := req(0, "initialized", &InitializedParams{}) initialize := step{ Method: "POST", - Send: []JSONRPCMessage{initReq}, + Send: []jsonrpc.Message{initReq}, StatusCode: http.StatusOK, - Recv: []JSONRPCMessage{initResp}, + Recv: []jsonrpc.Message{initResp}, } initialized := step{ Method: "POST", - Send: []JSONRPCMessage{initializedMsg}, + Send: []jsonrpc.Message{initializedMsg}, StatusCode: http.StatusAccepted, } @@ -190,9 +191,9 @@ func TestStreamableServerTransport(t *testing.T) { initialized, { Method: "POST", - Send: []JSONRPCMessage{req(2, "tools/call", &CallToolParams{Name: "tool"})}, + Send: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})}, StatusCode: http.StatusOK, - Recv: []JSONRPCMessage{resp(2, &CallToolResult{}, nil)}, + Recv: []jsonrpc.Message{resp(2, &CallToolResult{}, nil)}, }, }, }, @@ -209,11 +210,11 @@ func TestStreamableServerTransport(t *testing.T) { initialized, { Method: "POST", - Send: []JSONRPCMessage{ + Send: []jsonrpc.Message{ req(2, "tools/call", &CallToolParams{Name: "tool"}), }, StatusCode: http.StatusOK, - Recv: []JSONRPCMessage{ + Recv: []jsonrpc.Message{ req(0, "notifications/progress", &ProgressNotificationParams{}), resp(2, &CallToolResult{}, nil), }, @@ -234,18 +235,18 @@ func TestStreamableServerTransport(t *testing.T) { { Method: "POST", OnRequest: 1, - Send: []JSONRPCMessage{ + Send: []jsonrpc.Message{ resp(1, &ListRootsResult{}, nil), }, StatusCode: http.StatusAccepted, }, { Method: "POST", - Send: []JSONRPCMessage{ + Send: []jsonrpc.Message{ req(2, "tools/call", &CallToolParams{Name: "tool"}), }, StatusCode: http.StatusOK, - Recv: []JSONRPCMessage{ + Recv: []jsonrpc.Message{ req(1, "roots/list", &ListRootsParams{}), resp(2, &CallToolResult{}, nil), }, @@ -275,7 +276,7 @@ func TestStreamableServerTransport(t *testing.T) { { Method: "POST", OnRequest: 1, - Send: []JSONRPCMessage{ + Send: []jsonrpc.Message{ resp(1, &ListRootsResult{}, nil), }, StatusCode: http.StatusAccepted, @@ -285,18 +286,18 @@ func TestStreamableServerTransport(t *testing.T) { Async: true, StatusCode: http.StatusOK, CloseAfter: 2, - Recv: []JSONRPCMessage{ + Recv: []jsonrpc.Message{ req(0, "notifications/progress", &ProgressNotificationParams{}), req(1, "roots/list", &ListRootsParams{}), }, }, { Method: "POST", - Send: []JSONRPCMessage{ + Send: []jsonrpc.Message{ req(2, "tools/call", &CallToolParams{Name: "tool"}), }, StatusCode: http.StatusOK, - Recv: []JSONRPCMessage{ + Recv: []jsonrpc.Message{ resp(2, &CallToolResult{}, nil), }, }, @@ -315,9 +316,9 @@ func TestStreamableServerTransport(t *testing.T) { }, { Method: "POST", - Send: []JSONRPCMessage{req(2, "tools/call", &CallToolParams{Name: "tool"})}, + Send: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})}, StatusCode: http.StatusOK, - Recv: []JSONRPCMessage{resp(2, nil, &jsonrpc2.WireError{ + Recv: []jsonrpc.Message{resp(2, nil, &jsonrpc2.WireError{ Message: `method "tools/call" is invalid during session initialization`, })}, }, @@ -344,7 +345,7 @@ func TestStreamableServerTransport(t *testing.T) { httpServer := httptest.NewServer(handler) defer httpServer.Close() - // blocks records request blocks by JSONRPC ID. + // blocks records request blocks by jsonrpc. ID. // // When an OnRequest step is encountered, it waits on the corresponding // block. When a request with that ID is received, the block is closed. @@ -382,8 +383,8 @@ func TestStreamableServerTransport(t *testing.T) { // Collect messages received during this request, unblock other steps // when requests are received. - var got []JSONRPCMessage - out := make(chan JSONRPCMessage) + var got []jsonrpc.Message + out := make(chan jsonrpc.Message) // Cancel the step if we encounter a request that isn't going to be // handled. ctx, cancel := context.WithCancel(context.Background()) @@ -394,7 +395,7 @@ func TestStreamableServerTransport(t *testing.T) { defer wg.Done() for m := range out { - if req, ok := m.(*JSONRPCRequest); ok && req.ID.IsValid() { + if req, ok := m.(*jsonrpc.Request); ok && req.ID.IsValid() { // Encountered a server->client request. We should have a // response queued. Otherwise, we may deadlock. mu.Lock() @@ -427,7 +428,7 @@ func TestStreamableServerTransport(t *testing.T) { } wg.Wait() - transform := cmpopts.AcyclicTransformer("jsonrpcid", func(id JSONRPCID) any { return id.Raw() }) + transform := cmpopts.AcyclicTransformer("jsonrpcid", func(id jsonrpc.ID) any { return id.Raw() }) if diff := cmp.Diff(step.Recv, got, transform); diff != "" { t.Errorf("received unexpected messages (-want +got):\n%s", diff) } @@ -469,7 +470,7 @@ func TestStreamableServerTransport(t *testing.T) { // Returns the sessionID and http status code from the response. If an error is // returned, sessionID and status code may still be set if the error occurs // after the response headers have been received. -func streamingRequest(ctx context.Context, serverURL, sessionID, method string, in []JSONRPCMessage, out chan<- JSONRPCMessage) (string, int, error) { +func streamingRequest(ctx context.Context, serverURL, sessionID, method string, in []jsonrpc.Message, out chan<- jsonrpc.Message) (string, int, error) { defer close(out) var body []byte diff --git a/mcp/transport.go b/mcp/transport.go index f0b81650..a7de5061 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -16,6 +16,7 @@ import ( "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/internal/xcontext" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) // ErrConnectionClosed is returned when sending a message to a connection that @@ -34,21 +35,10 @@ type Transport interface { Connect(ctx context.Context) (Connection, error) } -type ( - // JSONRPCID is a JSON-RPC request ID. - JSONRPCID = jsonrpc2.ID - // JSONRPCMessage is a JSON-RPC message. - JSONRPCMessage = jsonrpc2.Message - // JSONRPCRequest is a JSON-RPC request. - JSONRPCRequest = jsonrpc2.Request - // JSONRPCResponse is a JSON-RPC response. - JSONRPCResponse = jsonrpc2.Response -) - // A Connection is a logical bidirectional JSON-RPC connection. type Connection interface { - Read(context.Context) (JSONRPCMessage, error) - Write(context.Context, JSONRPCMessage) error + Read(context.Context) (jsonrpc.Message, error) + Write(context.Context, jsonrpc.Message) error Close() error // may be called concurrently by both peers SessionID() string } @@ -100,7 +90,7 @@ type binder[T handler] interface { } type handler interface { - handle(ctx context.Context, req *JSONRPCRequest) (any, error) + handle(ctx context.Context, req *jsonrpc.Request) (any, error) setConn(Connection) } @@ -143,7 +133,7 @@ type canceller struct { } // Preempt implements jsonrpc2.Preempter. -func (c *canceller) Preempt(ctx context.Context, req *JSONRPCRequest) (result any, err error) { +func (c *canceller) Preempt(ctx context.Context, req *jsonrpc.Request) (result any, err error) { if req.Method == "notifications/cancelled" { var params CancelledParams if err := json.Unmarshal(req.Params, ¶ms); err != nil { @@ -212,7 +202,7 @@ type loggingConn struct { func (c *loggingConn) SessionID() string { return c.delegate.SessionID() } // loggingReader is a stream middleware that logs incoming messages. -func (s *loggingConn) Read(ctx context.Context) (JSONRPCMessage, error) { +func (s *loggingConn) Read(ctx context.Context) (jsonrpc.Message, error) { msg, err := s.delegate.Read(ctx) if err != nil { fmt.Fprintf(s.w, "read error: %v", err) @@ -227,7 +217,7 @@ func (s *loggingConn) Read(ctx context.Context) (JSONRPCMessage, error) { } // loggingWriter is a stream middleware that logs outgoing messages. -func (s *loggingConn) Write(ctx context.Context, msg JSONRPCMessage) error { +func (s *loggingConn) Write(ctx context.Context, msg jsonrpc.Message) error { err := s.delegate.Write(ctx, msg) if err != nil { fmt.Fprintf(s.w, "write error: %v", err) @@ -265,7 +255,7 @@ func (r rwc) Close() error { } // An ioConn is a transport that delimits messages with newlines across -// a bidirectional stream, and supports JSONRPC2 message batching. +// a bidirectional stream, and supports jsonrpc.2 message batching. // // See https://github.com/ndjson/ndjson-spec for discussion of newline // delimited JSON. @@ -277,11 +267,11 @@ type ioConn struct { // If outgoiBatch has a positive capacity, it will be used to batch requests // and notifications before sending. - outgoingBatch []JSONRPCMessage + outgoingBatch []jsonrpc.Message // Unread messages in the last batch. Since reads are serialized, there is no // need to guard here. - queue []JSONRPCMessage + queue []jsonrpc.Message // batches correlate incoming requests to the batch in which they arrived. // Since writes may be concurrent to reads, we need to guard this with a mutex. @@ -325,7 +315,7 @@ func (t *ioConn) addBatch(batch *msgBatch) error { // The second result reports whether resp was part of a batch. If this is true, // the first result is nil if the batch is still incomplete, or the full set of // batch responses if resp completed the batch. -func (t *ioConn) updateBatch(resp *JSONRPCResponse) ([]*JSONRPCResponse, bool) { +func (t *ioConn) updateBatch(resp *jsonrpc.Response) ([]*jsonrpc.Response, bool) { t.batchMu.Lock() defer t.batchMu.Unlock() @@ -345,9 +335,9 @@ func (t *ioConn) updateBatch(resp *JSONRPCResponse) ([]*JSONRPCResponse, bool) { return nil, false } -// A msgBatch records information about an incoming batch of JSONRPC2 calls. +// A msgBatch records information about an incoming batch of jsonrpc.2 calls. // -// The JSONRPC2 spec (https://www.jsonrpc.org/specification#batch) says: +// The jsonrpc.2 spec (https://www.jsonrpc.org/specification#batch) says: // // "The Server should respond with an Array containing the corresponding // Response objects, after all of the batch Request objects have been @@ -360,14 +350,14 @@ func (t *ioConn) updateBatch(resp *JSONRPCResponse) ([]*JSONRPCResponse, bool) { // When there are no unresolved calls, the response payload is sent. type msgBatch struct { unresolved map[jsonrpc2.ID]int - responses []*JSONRPCResponse + responses []*jsonrpc.Response } -func (t *ioConn) Read(ctx context.Context) (JSONRPCMessage, error) { +func (t *ioConn) Read(ctx context.Context) (jsonrpc.Message, error) { return t.read(ctx, t.in) } -func (t *ioConn) read(ctx context.Context, in *json.Decoder) (JSONRPCMessage, error) { +func (t *ioConn) read(ctx context.Context, in *json.Decoder) (jsonrpc.Message, error) { select { case <-ctx.Done(): return nil, ctx.Err() @@ -392,7 +382,7 @@ func (t *ioConn) read(ctx context.Context, in *json.Decoder) (JSONRPCMessage, er if batch { var respBatch *msgBatch // track incoming requests in the batch for _, msg := range msgs { - if req, ok := msg.(*JSONRPCRequest); ok { + if req, ok := msg.(*jsonrpc.Request); ok { if respBatch == nil { respBatch = &msgBatch{ unresolved: make(map[jsonrpc2.ID]int), @@ -417,7 +407,7 @@ func (t *ioConn) read(ctx context.Context, in *json.Decoder) (JSONRPCMessage, er // readBatch reads batch data, which may be either a single JSON-RPC message, // or an array of JSON-RPC messages. -func readBatch(data []byte) (msgs []JSONRPCMessage, isBatch bool, _ error) { +func readBatch(data []byte) (msgs []jsonrpc.Message, isBatch bool, _ error) { // Try to read an array of messages first. var rawBatch []json.RawMessage if err := json.Unmarshal(data, &rawBatch); err == nil { @@ -435,10 +425,10 @@ func readBatch(data []byte) (msgs []JSONRPCMessage, isBatch bool, _ error) { } // Try again with a single message. msg, err := jsonrpc2.DecodeMessage(data) - return []JSONRPCMessage{msg}, false, err + return []jsonrpc.Message{msg}, false, err } -func (t *ioConn) Write(ctx context.Context, msg JSONRPCMessage) error { +func (t *ioConn) Write(ctx context.Context, msg jsonrpc.Message) error { select { case <-ctx.Done(): return ctx.Err() @@ -449,7 +439,7 @@ func (t *ioConn) Write(ctx context.Context, msg JSONRPCMessage) error { // check that first. Otherwise, it is a request or notification, and we may // want to collect it into a batch before sending, if we're configured to use // outgoing batches. - if resp, ok := msg.(*JSONRPCResponse); ok { + if resp, ok := msg.(*jsonrpc.Response); ok { if batch, ok := t.updateBatch(resp); ok { if len(batch) > 0 { data, err := marshalMessages(batch) @@ -489,7 +479,7 @@ func (t *ioConn) Close() error { return t.rwc.Close() } -func marshalMessages[T JSONRPCMessage](msgs []T) ([]byte, error) { +func marshalMessages[T jsonrpc.Message](msgs []T) ([]byte, error) { var rawMsgs []json.RawMessage for _, msg := range msgs { raw, err := jsonrpc2.EncodeMessage(msg) diff --git a/mcp/transport_test.go b/mcp/transport_test.go index db18a352..c63b84ee 100644 --- a/mcp/transport_test.go +++ b/mcp/transport_test.go @@ -10,6 +10,7 @@ import ( "testing" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) func TestBatchFraming(t *testing.T) { @@ -22,10 +23,10 @@ func TestBatchFraming(t *testing.T) { r, w := io.Pipe() tport := newIOConn(rwc{r, w}) - tport.outgoingBatch = make([]JSONRPCMessage, 0, 2) + tport.outgoingBatch = make([]jsonrpc.Message, 0, 2) // Read the two messages into a channel, for easy testing later. - read := make(chan JSONRPCMessage) + read := make(chan jsonrpc.Message) go func() { for range 2 { msg, _ := tport.Read(ctx) @@ -34,7 +35,7 @@ func TestBatchFraming(t *testing.T) { }() // The first write should not yet be observed by the reader. - tport.Write(ctx, &JSONRPCRequest{ID: jsonrpc2.Int64ID(1), Method: "test"}) + tport.Write(ctx, &jsonrpc.Request{ID: jsonrpc2.Int64ID(1), Method: "test"}) select { case got := <-read: t.Fatalf("after one write, got message %v", got) @@ -42,10 +43,10 @@ func TestBatchFraming(t *testing.T) { } // ...but the second write causes both messages to be observed. - tport.Write(ctx, &JSONRPCRequest{ID: jsonrpc2.Int64ID(2), Method: "test"}) + tport.Write(ctx, &jsonrpc.Request{ID: jsonrpc2.Int64ID(2), Method: "test"}) for _, want := range []int64{1, 2} { got := <-read - if got := got.(*JSONRPCRequest).ID.Raw(); got != want { + if got := got.(*jsonrpc.Request).ID.Raw(); got != want { t.Errorf("got message #%d, want #%d", got, want) } }