diff --git a/internal/jsonrpc2/conn.go b/internal/jsonrpc2/conn.go index 6bacfa7e..49902b00 100644 --- a/internal/jsonrpc2/conn.go +++ b/internal/jsonrpc2/conn.go @@ -739,7 +739,23 @@ func (c *Connection) processResult(from any, req *incomingRequest, result any, e // write is used by all things that write outgoing messages, including replies. // it makes sure that writes are atomic func (c *Connection) write(ctx context.Context, msg Message) error { - err := c.writer.Write(ctx, msg) + var err error + // Fail writes immediately if the connection is shutting down. + // + // TODO(rfindley): should we allow cancellation notifications through? It + // could be the case that writes can still succeed. + c.updateInFlight(func(s *inFlightState) { + err = s.shuttingDown(ErrServerClosing) + }) + if err == nil { + err = c.writer.Write(ctx, msg) + } + + // For rejected requests, we don't set the writeErr (which would break the + // connection). They can just be returned to the caller. + if errors.Is(err, ErrRejected) { + return err + } if err != nil && ctx.Err() == nil { // The call to Write failed, and since ctx.Err() is nil we can't attribute diff --git a/internal/jsonrpc2/wire.go b/internal/jsonrpc2/wire.go index b143dcd3..8be2872e 100644 --- a/internal/jsonrpc2/wire.go +++ b/internal/jsonrpc2/wire.go @@ -37,6 +37,17 @@ var ( ErrServerClosing = NewError(-32004, "server is closing") // ErrClientClosing is a dummy error returned for calls initiated while the client is closing. ErrClientClosing = NewError(-32003, "client is closing") + + // The following errors have special semantics for MCP transports + + // ErrRejected may be wrapped to return errors from calls to Writer.Write + // that signal that the request was rejected by the transport layer as + // invalid. + // + // Such failures do not indicate that the connection is broken, but rather + // should be returned to the caller to indicate that the specific request is + // invalid in the current context. + ErrRejected = NewError(-32004, "rejected by transport") ) const wireVersion = "2.0" diff --git a/mcp/server.go b/mcp/server.go index c3fbd9e3..88021336 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -355,6 +355,8 @@ func (s *Server) callTool(ctx context.Context, req *ServerRequest[*CallToolParam if !ok { return nil, fmt.Errorf("%s: unknown tool %q", jsonrpc2.ErrInvalidParams, req.Params.Name) } + // TODO: if handler returns nil content, it will serialize as null. + // Add a test and fix. return st.handler(ctx, req) } diff --git a/mcp/streamable.go b/mcp/streamable.go index 526ee515..572fe5de 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -38,18 +38,34 @@ type StreamableHTTPHandler struct { getServer func(*http.Request) *Server opts StreamableHTTPOptions - mu sync.Mutex + mu sync.Mutex + // TODO: we should store the ServerSession along with the transport, because + // we need to cancel keepalive requests when closing the transport. transports map[string]*StreamableServerTransport // keyed by IDs (from Mcp-Session-Id header) } // StreamableHTTPOptions configures the StreamableHTTPHandler. type StreamableHTTPOptions struct { // GetSessionID provides the next session ID to use for an incoming request. + // If nil, a default randomly generated ID will be used. + // + // Session IDs should be globally unique across the scope of the server, + // which may span multiple processes in the case of distributed servers. // - // If GetSessionID returns an empty string, the session is 'stateless', - // meaning it is not persisted and no session validation is performed. + // As a special case, if GetSessionID returns the empty string, the + // Mcp-Session-Id header will not be set. GetSessionID func() string + // Stateless controls whether the session is 'stateless'. + // + // A stateless server does not validate the Mcp-Session-Id header, and uses a + // temporary session with default initialization parameters. Any + // server->client request is rejected immediately as there's no way for the + // client to respond. Server->Client notifications may reach the client if + // they are made in the context of an incoming request, as described in the + // documentation for [StreamableServerTransport]. + Stateless bool + // TODO: support session retention (?) // jsonResponse is forwarded to StreamableServerTransport.jsonResponse. @@ -118,36 +134,40 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque return } + sessionID := req.Header.Get(sessionIDHeader) var transport *StreamableServerTransport - if id := req.Header.Get(sessionIDHeader); id != "" { + if sessionID != "" { h.mu.Lock() - transport = h.transports[id] + transport = h.transports[sessionID] h.mu.Unlock() - if transport == nil { + if transport == nil && !h.opts.Stateless { + // Unless we're in 'stateless' mode, which doesn't perform any Session-ID + // validation, we require that the session ID matches a known session. + // + // In stateless mode, a temporary transport is be created below. http.Error(w, "session not found", http.StatusNotFound) return } } - // TODO(rfindley): simplify the locking so that each request has only one - // critical section. if req.Method == http.MethodDelete { - if transport == nil { - // => Mcp-Session-Id was not set; else we'd have returned NotFound above. + if sessionID == "" { http.Error(w, "DELETE requires an Mcp-Session-Id header", http.StatusBadRequest) return } - h.mu.Lock() - delete(h.transports, transport.SessionID) - h.mu.Unlock() - transport.connection.Close() + if transport != nil { // transport may be nil in stateless mode + h.mu.Lock() + delete(h.transports, transport.SessionID) + h.mu.Unlock() + transport.connection.Close() + } w.WriteHeader(http.StatusNoContent) return } switch req.Method { case http.MethodPost, http.MethodGet: - if req.Method == http.MethodGet && transport == nil { + if req.Method == http.MethodGet && sessionID == "" { http.Error(w, "GET requires an active session", http.StatusMethodNotAllowed) return } @@ -164,37 +184,83 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque http.Error(w, "no server available", http.StatusBadRequest) return } - sessionID := h.opts.GetSessionID() - s := &StreamableServerTransport{SessionID: sessionID, jsonResponse: h.opts.jsonResponse} + if sessionID == "" { + // In stateless mode, sessionID may be nonempty even if there's no + // existing transport. + sessionID = h.opts.GetSessionID() + } + transport = &StreamableServerTransport{ + SessionID: sessionID, + Stateless: h.opts.Stateless, + jsonResponse: h.opts.jsonResponse, + } // To support stateless mode, we initialize the session with a default // state, so that it doesn't reject subsequent requests. var connectOpts *ServerSessionOptions - if sessionID == "" { + if h.opts.Stateless { + // Peek at the body to see if it is initialize or initialized. + // We want those to be handled as usual. + var hasInitialize, hasInitialized bool + { + // TODO: verify that this allows protocol version negotiation for + // stateless servers. + body, err := io.ReadAll(req.Body) + if err != nil { + http.Error(w, "failed to read body", http.StatusInternalServerError) + return + } + req.Body.Close() + + // Reset the body so that it can be read later. + req.Body = io.NopCloser(bytes.NewBuffer(body)) + + msgs, _, err := readBatch(body) + if err == nil { + for _, msg := range msgs { + if req, ok := msg.(*jsonrpc.Request); ok { + switch req.Method { + case methodInitialize: + hasInitialize = true + case notificationInitialized: + hasInitialized = true + } + } + } + } + } + + // If we don't have InitializeParams or InitializedParams in the request, + // set the initial state to a default value. + state := new(ServerSessionState) + if !hasInitialize { + state.InitializeParams = new(InitializeParams) + } + if !hasInitialized { + state.InitializedParams = new(InitializedParams) + } connectOpts = &ServerSessionOptions{ - State: &ServerSessionState{ - InitializeParams: new(InitializeParams), - InitializedParams: new(InitializedParams), - }, + State: state, } } + // Pass req.Context() here, to allow middleware to add context values. // The context is detached in the jsonrpc2 library when handling the // long-running stream. - ss, err := server.Connect(req.Context(), s, connectOpts) + ss, err := server.Connect(req.Context(), transport, connectOpts) if err != nil { http.Error(w, "failed connection", http.StatusInternalServerError) return } - if sessionID == "" { + if h.opts.Stateless { // Stateless mode: close the session when the request exits. defer ss.Close() // close the fake session after handling the request } else { + // Otherwise, save the transport so that it can be reused h.mu.Lock() - h.transports[s.SessionID] = s + h.transports[transport.SessionID] = transport h.mu.Unlock() } - transport = s } transport.ServeHTTP(w, req) @@ -212,9 +278,22 @@ type StreamableServerTransportOptions struct { // A StreamableServerTransport implements the server side of the MCP streamable // transport. // -// Each StreamableServerTransport may be connected (via [Server.Connect]) at +// Each StreamableServerTransport must be connected (via [Server.Connect]) at // most once, since [StreamableServerTransport.ServeHTTP] serves messages to // the connected session. +// +// Reads from the streamable server connection receive messages from http POST +// requests from the client. Writes to the streamable server connection are +// sent either to the hanging POST response, or to the hanging GET, according +// to the following rules: +// - JSON-RPC responses to incoming requests are always routed to the +// appropriate HTTP response. +// - Requests or notifications made with a context.Context value derived from +// an incoming request handler, are routed to the HTTP response +// corresponding to that request, unless it has already terminated, in +// which case they are routed to the hanging GET. +// - Requests or notifications made with a detached context.Context value are +// routed to the hanging GET. type StreamableServerTransport struct { // SessionID is the ID of this session. // @@ -225,6 +304,13 @@ type StreamableServerTransport struct { // generator to produce one, as with [crypto/rand.Text].) SessionID string + // Stateless controls whether the eventstore is 'Stateless'. Server sessions + // connected to a stateless transport are disallowed from making outgoing + // requests. + // + // See also [StreamableHTTPOptions.Stateless]. + Stateless bool + // Storage for events, to enable stream resumption. // If nil, a [MemoryEventStore] with the default maximum size will be used. EventStore EventStore @@ -265,6 +351,7 @@ func (t *StreamableServerTransport) Connect(context.Context) (Connection, error) } t.connection = &streamableServerConn{ sessionID: t.SessionID, + stateless: t.Stateless, eventStore: t.EventStore, jsonResponse: t.jsonResponse, incoming: make(chan jsonrpc.Message, 10), @@ -285,6 +372,7 @@ func (t *StreamableServerTransport) Connect(context.Context) (Connection, error) type streamableServerConn struct { sessionID string + stateless bool jsonResponse bool eventStore EventStore @@ -755,6 +843,10 @@ func (c *streamableServerConn) Read(ctx context.Context) (jsonrpc.Message, error // Write implements the [Connection] interface. func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) error { + if req, ok := msg.(*jsonrpc.Request); ok && req.ID.IsValid() && (c.stateless || c.sessionID == "") { + // Requests aren't possible with stateless servers, or when there's no session ID. + return fmt.Errorf("%w: stateless servers cannot make requests", jsonrpc2.ErrRejected) + } // Find the incoming request that this write relates to, if any. var forRequest jsonrpc.ID isResponse := false @@ -1152,9 +1244,18 @@ func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent c.fail(err) return } - - // Reconnection was successful. Continue the loop with the new response. resp = newResp + if resp.StatusCode == http.StatusMethodNotAllowed && persistent { + // The server doesn't support the hanging GET. + resp.Body.Close() + return + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + resp.Body.Close() + c.fail(fmt.Errorf("failed to reconnect: %v", http.StatusText(resp.StatusCode))) + return + } + // Reconnection was successful. Continue the loop with the new response. } } @@ -1222,13 +1323,6 @@ func (c *streamableClientConn) reconnect(lastEventID string) (*http.Response, er finalErr = err // Store the error and try again. continue } - - if !isResumable(resp) { - // The server indicated we should not continue. - resp.Body.Close() - return nil, fmt.Errorf("reconnection failed with unresumable status: %s", resp.Status) - } - return resp, nil } } @@ -1239,16 +1333,6 @@ func (c *streamableClientConn) reconnect(lastEventID string) (*http.Response, er return nil, fmt.Errorf("connection failed after %d attempts", c.maxRetries) } -// isResumable checks if an HTTP response indicates a valid SSE stream that can be processed. -func isResumable(resp *http.Response) bool { - // Per the spec, a 405 response means the server doesn't support SSE streams at this endpoint. - if resp.StatusCode == http.StatusMethodNotAllowed { - return false - } - - return strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") -} - // Close implements the [Connection] interface. func (c *streamableClientConn) Close() error { c.closeOnce.Do(func() { @@ -1288,8 +1372,11 @@ func (c *streamableClientConn) establishSSE(lastEventID string) (*http.Response, // calculateReconnectDelay calculates a delay using exponential backoff with full jitter. func calculateReconnectDelay(attempt int) time.Duration { + if attempt == 0 { + return 0 + } // Calculate the exponential backoff using the grow factor. - backoffDuration := time.Duration(float64(reconnectInitialDelay) * math.Pow(reconnectGrowFactor, float64(attempt))) + backoffDuration := time.Duration(float64(reconnectInitialDelay) * math.Pow(reconnectGrowFactor, float64(attempt-1))) // Cap the backoffDuration at maxDelay. backoffDuration = min(backoffDuration, reconnectMaxDelay) diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 52c47998..8334bc0d 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -382,30 +382,31 @@ func readNotifications(t *testing.T, ctx context.Context, notifications chan str } } +// JSON-RPC message constructors. +func req(id int64, method string, params any) *jsonrpc.Request { + r := &jsonrpc.Request{ + Method: method, + Params: mustMarshal(params), + } + if id > 0 { + r.ID = jsonrpc2.Int64ID(id) + } + return r +} + +func resp(id int64, result any, err error) *jsonrpc.Response { + return &jsonrpc.Response{ + ID: jsonrpc2.Int64ID(id), + Result: mustMarshal(result), + Error: err, + } +} + func TestStreamableServerTransport(t *testing.T) { // This test checks detailed behavior of the streamable server transport, by // faking the behavior of a streamable client using a sequence of HTTP // requests. - // JSON-RPC message constructors. - req := func(id int64, method string, params any) *jsonrpc.Request { - r := &jsonrpc.Request{ - Method: method, - Params: mustMarshal(t, params), - } - if id > 0 { - r.ID = jsonrpc2.Int64ID(id) - } - return r - } - resp := func(id int64, result any, err error) *jsonrpc.Response { - return &jsonrpc.Response{ - ID: jsonrpc2.Int64ID(id), - Result: mustMarshal(t, result), - Error: err, - } - } - // Predefined steps, to avoid repetition below. initReq := req(1, methodInitialize, &InitializeParams{}) initResp := resp(1, &InitializeResult{ @@ -422,21 +423,23 @@ func TestStreamableServerTransport(t *testing.T) { messages: []jsonrpc.Message{initReq}, wantStatusCode: http.StatusOK, wantMessages: []jsonrpc.Message{initResp}, + wantSessionID: true, } initialized := streamableRequest{ method: "POST", messages: []jsonrpc.Message{initializedMsg}, wantStatusCode: http.StatusAccepted, + wantSessionID: false, // TODO: should this be true? } tests := []struct { - name string - tool func(*testing.T, context.Context, *ServerSession) - steps []streamableRequest + name string + tool func(*testing.T, context.Context, *ServerSession) + requests []streamableRequest // http requests }{ { name: "basic", - steps: []streamableRequest{ + requests: []streamableRequest{ initialize, initialized, { @@ -444,12 +447,13 @@ func TestStreamableServerTransport(t *testing.T) { messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusOK, wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{}, nil)}, + wantSessionID: true, }, }, }, { name: "accept headers", - steps: []streamableRequest{ + requests: []streamableRequest{ initialize, initialized, // Test various accept headers. @@ -458,12 +462,14 @@ func TestStreamableServerTransport(t *testing.T) { accept: []string{"text/plain", "application/*"}, messages: []jsonrpc.Message{req(3, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusBadRequest, // missing text/event-stream + wantSessionID: false, }, { method: "POST", accept: []string{"text/event-stream"}, messages: []jsonrpc.Message{req(3, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusBadRequest, // missing application/json + wantSessionID: false, }, { method: "POST", @@ -471,6 +477,7 @@ func TestStreamableServerTransport(t *testing.T) { messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusOK, wantMessages: []jsonrpc.Message{resp(4, &CallToolResult{}, nil)}, + wantSessionID: true, }, { method: "POST", @@ -478,6 +485,7 @@ func TestStreamableServerTransport(t *testing.T) { messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusOK, wantMessages: []jsonrpc.Message{resp(4, &CallToolResult{}, nil)}, + wantSessionID: true, }, }, }, @@ -489,7 +497,7 @@ func TestStreamableServerTransport(t *testing.T) { t.Errorf("Notify failed: %v", err) } }, - steps: []streamableRequest{ + requests: []streamableRequest{ initialize, initialized, { @@ -502,6 +510,7 @@ func TestStreamableServerTransport(t *testing.T) { req(0, "notifications/progress", &ProgressNotificationParams{}), resp(2, &CallToolResult{}, nil), }, + wantSessionID: true, }, }, }, @@ -513,7 +522,7 @@ func TestStreamableServerTransport(t *testing.T) { t.Errorf("Call failed: %v", err) } }, - steps: []streamableRequest{ + requests: []streamableRequest{ initialize, initialized, { @@ -523,6 +532,7 @@ func TestStreamableServerTransport(t *testing.T) { resp(1, &ListRootsResult{}, nil), }, wantStatusCode: http.StatusAccepted, + wantSessionID: false, }, { method: "POST", @@ -534,6 +544,7 @@ func TestStreamableServerTransport(t *testing.T) { req(1, "roots/list", &ListRootsParams{}), resp(2, &CallToolResult{}, nil), }, + wantSessionID: true, }, }, }, @@ -554,7 +565,7 @@ func TestStreamableServerTransport(t *testing.T) { t.Errorf("Notify failed: %v", err) } }, - steps: []streamableRequest{ + requests: []streamableRequest{ initialize, initialized, { @@ -564,6 +575,7 @@ func TestStreamableServerTransport(t *testing.T) { resp(1, &ListRootsResult{}, nil), }, wantStatusCode: http.StatusAccepted, + wantSessionID: false, }, { method: "GET", @@ -574,6 +586,7 @@ func TestStreamableServerTransport(t *testing.T) { req(0, "notifications/progress", &ProgressNotificationParams{}), req(1, "roots/list", &ListRootsParams{}), }, + wantSessionID: true, }, { method: "POST", @@ -584,12 +597,13 @@ func TestStreamableServerTransport(t *testing.T) { wantMessages: []jsonrpc.Message{ resp(2, &CallToolResult{}, nil), }, + wantSessionID: true, }, }, }, { name: "errors", - steps: []streamableRequest{ + requests: []streamableRequest{ { method: "PUT", wantStatusCode: http.StatusMethodNotAllowed, @@ -615,6 +629,7 @@ func TestStreamableServerTransport(t *testing.T) { wantMessages: []jsonrpc.Message{resp(2, nil, &jsonrpc2.WireError{ Message: `method "tools/call" is invalid during session initialization`, })}, + wantSessionID: true, // TODO: this is probably wrong; we don't have a valid session }, }, }, @@ -636,118 +651,127 @@ func TestStreamableServerTransport(t *testing.T) { handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) defer handler.closeAll() - httpServer := httptest.NewServer(handler) - defer httpServer.Close() + testStreamableHandler(t, handler, test.requests) + }) + } +} - // blocks records request blocks by jsonrpc. ID. - // - // When an OnRequest step is encountered, it waits on the corresponding - // block. When a request with that ID is received, the block is closed. - var mu sync.Mutex - blocks := make(map[int64]chan struct{}) - for _, step := range test.steps { - if step.onRequest > 0 { - blocks[step.onRequest] = make(chan struct{}) - } - } +func testStreamableHandler(t *testing.T, handler http.Handler, requests []streamableRequest) { + httpServer := httptest.NewServer(handler) + defer httpServer.Close() - // signal when all synchronous requests have executed, so we can fail - // async requests that are blocked. - syncRequestsDone := make(chan struct{}) + // blocks records request blocks by jsonrpc. ID. + // + // When an OnRequest step is encountered, it waits on the corresponding + // block. When a request with that ID is received, the block is closed. + var mu sync.Mutex + blocks := make(map[int64]chan struct{}) + for _, req := range requests { + if req.onRequest > 0 { + blocks[req.onRequest] = make(chan struct{}) + } + } - // To avoid complicated accounting for session ID, just set the first - // non-empty session ID from a response. - var sessionID atomic.Value - sessionID.Store("") + // signal when all synchronous requests have executed, so we can fail + // async requests that are blocked. + syncRequestsDone := make(chan struct{}) + + // To avoid complicated accounting for session ID, just set the first + // non-empty session ID from a response. + var sessionID atomic.Value + sessionID.Store("") + + // doStep executes a single step. + doStep := func(t *testing.T, i int, request streamableRequest) { + if request.onRequest > 0 { + // Block the step until we've received the server->client request. + mu.Lock() + block := blocks[request.onRequest] + mu.Unlock() + select { + case <-block: + case <-syncRequestsDone: + t.Errorf("after all sync requests are complete, request still blocked on %d", request.onRequest) + return + } + } - // doStep executes a single step. - doStep := func(t *testing.T, step streamableRequest) { - if step.onRequest > 0 { - // Block the step until we've received the server->client request. + // Collect messages received during this request, unblock other steps + // when requests are received. + var got []jsonrpc.Message + out := make(chan jsonrpc.Message) + // Cancel the step if we encounter a request that isn't going to be + // handled. + ctx, cancel := context.WithCancel(context.Background()) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + + for m := range out { + if req, ok := m.(*jsonrpc.Request); ok && req.IsCall() { + // Encountered a server->client request. We should have a + // response queued. Otherwise, we may deadlock. mu.Lock() - block := blocks[step.onRequest] - mu.Unlock() - select { - case <-block: - case <-syncRequestsDone: - t.Errorf("after all sync requests are complete, request still blocked on %d", step.onRequest) - return + if block, ok := blocks[req.ID.Raw().(int64)]; ok { + close(block) + } else { + t.Errorf("no queued response for %v", req.ID) + cancel() } + mu.Unlock() } + got = append(got, m) + if request.closeAfter > 0 && len(got) == request.closeAfter { + cancel() + } + } + }() - // Collect messages received during this request, unblock other steps - // when requests are received. - var got []jsonrpc.Message - out := make(chan jsonrpc.Message) - // Cancel the step if we encounter a request that isn't going to be - // handled. - ctx, cancel := context.WithCancel(context.Background()) - - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - - for m := range out { - if req, ok := m.(*jsonrpc.Request); ok && req.IsCall() { - // Encountered a server->client request. We should have a - // response queued. Otherwise, we may deadlock. - mu.Lock() - if block, ok := blocks[req.ID.Raw().(int64)]; ok { - close(block) - } else { - t.Errorf("no queued response for %v", req.ID) - cancel() - } - mu.Unlock() - } - got = append(got, m) - if step.closeAfter > 0 && len(got) == step.closeAfter { - cancel() - } - } - }() - - gotSessionID, gotStatusCode, err := step.do(ctx, httpServer.URL, sessionID.Load().(string), out) + gotSessionID, gotStatusCode, err := request.do(ctx, httpServer.URL, sessionID.Load().(string), out) - // Don't fail on cancelled requests: error (if any) is handled - // elsewhere. - if err != nil && ctx.Err() == nil { - t.Fatal(err) - } + // Don't fail on cancelled requests: error (if any) is handled + // elsewhere. + if err != nil && ctx.Err() == nil { + t.Fatal(err) + } - if gotStatusCode != step.wantStatusCode { - t.Errorf("got status %d, want %d", gotStatusCode, step.wantStatusCode) - } - wg.Wait() + if gotStatusCode != request.wantStatusCode { + t.Errorf("request #%d: got status %d, want %d", i, gotStatusCode, request.wantStatusCode) + } + if got := gotSessionID != ""; got != request.wantSessionID { + t.Errorf("request #%d: got session id: %t, want %t", i, got, request.wantSessionID) + } + wg.Wait() - transform := cmpopts.AcyclicTransformer("jsonrpcid", func(id jsonrpc.ID) any { return id.Raw() }) - if diff := cmp.Diff(step.wantMessages, got, transform); diff != "" { - t.Errorf("received unexpected messages (-want +got):\n%s", diff) - } - sessionID.CompareAndSwap("", gotSessionID) + if !request.ignoreResponse { + transform := cmpopts.AcyclicTransformer("jsonrpcid", func(id jsonrpc.ID) any { return id.Raw() }) + if diff := cmp.Diff(request.wantMessages, got, transform); diff != "" { + t.Errorf("request #%d: received unexpected messages (-want +got):\n%s", i, diff) } + } + sessionID.CompareAndSwap("", gotSessionID) + } - var wg sync.WaitGroup - for _, step := range test.steps { - if step.async || step.onRequest > 0 { - wg.Add(1) - go func() { - defer wg.Done() - doStep(t, step) - }() - } else { - doStep(t, step) - } - } + var wg sync.WaitGroup + for i, request := range requests { + if request.async || request.onRequest > 0 { + wg.Add(1) + go func() { + defer wg.Done() + doStep(t, i, request) + }() + } else { + doStep(t, i, request) + } + } - // Fail any blocked responses if they weren't needed by a synchronous - // request. - close(syncRequestsDone) + // Fail any blocked responses if they weren't needed by a synchronous + // request. + close(syncRequestsDone) - wg.Wait() - }) - } + wg.Wait() } // A streamableRequest describes a single streamable HTTP request, consisting @@ -768,13 +792,15 @@ type streamableRequest struct { async bool // Request attributes - method string // HTTP request method + method string // HTTP request method (required) accept []string // if non-empty, the Accept header to use; otherwise the default header is used messages []jsonrpc.Message // messages to send closeAfter int // if nonzero, close after receiving this many messages wantStatusCode int // expected status code + ignoreResponse bool // if set, don't check the response messages wantMessages []jsonrpc.Message // expected messages to receive + wantSessionID bool // whether or not a session ID is expected in the response } // streamingRequest makes a request to the given streamable server with the @@ -840,7 +866,8 @@ func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string, newSessionID := resp.Header.Get("Mcp-Session-Id") - if strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") { + contentType := resp.Header.Get("Content-Type") + if strings.HasPrefix(contentType, "text/event-stream") { for evt, err := range scanEvents(resp.Body) { if err != nil { return newSessionID, resp.StatusCode, fmt.Errorf("reading events: %v", err) @@ -853,7 +880,7 @@ func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string, } out <- msg } - } else if strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") { + } else if strings.HasPrefix(contentType, "application/json") { data, err := io.ReadAll(resp.Body) if err != nil { return newSessionID, resp.StatusCode, fmt.Errorf("reading json body: %w", err) @@ -868,14 +895,13 @@ func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string, return newSessionID, resp.StatusCode, nil } -func mustMarshal(t *testing.T, v any) json.RawMessage { +func mustMarshal(v any) json.RawMessage { if v == nil { return nil } - t.Helper() data, err := json.Marshal(v) if err != nil { - t.Fatal(err) + panic(err) } return data } @@ -886,7 +912,7 @@ func TestStreamableClientTransportApplicationJSON(t *testing.T) { resp := func(id int64, result any, err error) *jsonrpc.Response { return &jsonrpc.Response{ ID: jsonrpc2.Int64ID(id), - Result: mustMarshal(t, result), + Result: mustMarshal(result), Error: err, } } @@ -970,72 +996,105 @@ func TestEventID(t *testing.T) { } func TestStreamableStateless(t *testing.T) { - // Test stateless mode behavior - ctx := context.Background() - - // This version of sayHi doesn't make a ping request (we can't respond to + // This version of sayHi expects // that request from our client). - sayHi := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[hiParams]]) (*CallToolResultFor[any], error) { - return &CallToolResultFor[any]{Content: []Content{&TextContent{Text: "hi " + req.Params.Arguments.Name}}}, nil + sayHi := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[hiParams]]) (*CallToolResult, error) { + if err := req.Session.Ping(ctx, nil); err == nil { + // ping should fail, but not break the connection + t.Errorf("ping succeeded unexpectedly") + } + return &CallToolResult{Content: []Content{&TextContent{Text: "hi " + req.Params.Arguments.Name}}}, nil } server := NewServer(testImpl, nil) AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi) - // Test stateless mode. - handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{ - GetSessionID: func() string { return "" }, - }) - httpServer := httptest.NewServer(handler) - defer httpServer.Close() - - checkRequest := func(body string) { - // Verify we can call tools/list directly without initialization in stateless mode - req, err := http.NewRequestWithContext(ctx, http.MethodPost, httpServer.URL, strings.NewReader(body)) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json, text/event-stream") + requests := []streamableRequest{ + { + method: "POST", + wantStatusCode: http.StatusOK, + messages: []jsonrpc.Message{req(1, "tools/list", struct{}{})}, + ignoreResponse: true, + wantSessionID: false, + }, + { + method: "POST", + wantStatusCode: http.StatusOK, + messages: []jsonrpc.Message{ + req(2, "tools/call", &CallToolParams{Name: "greet", Arguments: hiParams{Name: "World"}}), + }, + wantMessages: []jsonrpc.Message{ + resp(2, &CallToolResult{Content: []Content{&TextContent{Text: "hi World"}}}, nil), + }, + wantSessionID: false, + }, + { + method: "POST", + wantStatusCode: http.StatusOK, + messages: []jsonrpc.Message{ + req(2, "tools/call", &CallToolParams{Name: "greet", Arguments: hiParams{Name: "foo"}}), + }, + wantMessages: []jsonrpc.Message{ + resp(2, &CallToolResult{Content: []Content{&TextContent{Text: "hi foo"}}}, nil), + }, + wantSessionID: false, + }, + } - resp, err := http.DefaultClient.Do(req) + testClientCompatibility := func(t *testing.T, handler http.Handler) { + ctx := context.Background() + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + cs, err := NewClient(testImpl, nil).Connect(ctx, &StreamableClientTransport{Endpoint: httpServer.URL}, nil) if err != nil { t.Fatal(err) } - defer resp.Body.Close() - - // Verify that no session ID header is returned in stateless mode - if sessionID := resp.Header.Get(sessionIDHeader); sessionID != "" { - t.Errorf("%s = %s, want no session ID header", sessionIDHeader, sessionID) - } - if resp.StatusCode != http.StatusOK { - t.Errorf("Status code = %d; want successful response", resp.StatusCode) - } - - var events []Event - for event, err := range scanEvents(resp.Body) { - if err != nil { - t.Fatal(err) - } - events = append(events, event) - } - if len(events) != 1 { - t.Fatalf("got %d SSE events, want 1; events: %v", len(events), events) - } - msg, err := jsonrpc.DecodeMessage(events[0].Data) + res, err := cs.CallTool(ctx, &CallToolParams{Name: "greet", Arguments: hiParams{Name: "bar"}}) if err != nil { t.Fatal(err) } - jsonResp, ok := msg.(*jsonrpc.Response) - if !ok { - t.Errorf("event is %T, want response", jsonResp) - } - if jsonResp.Error != nil { - t.Errorf("request failed: %v", jsonResp.Error) + if got, want := textContent(t, res), "hi bar"; got != want { + t.Errorf("Result = %q, want %q", got, want) } } - checkRequest(`{"jsonrpc":"2.0","method":"tools/list","id":1,"params":{}}`) + handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{ + GetSessionID: func() string { return "" }, + Stateless: true, + }) + + // Test the default stateless mode. + t.Run("stateless", func(t *testing.T) { + testStreamableHandler(t, handler, requests) + testClientCompatibility(t, handler) + }) + + // Test a "distributed" variant of stateless mode, where it has non-empty + // session IDs, but is otherwise stateless. + // + // This can be used by tools to look up application state preserved across + // subsequent requests. + for i, req := range requests { + // Now, we want a session for all requests. + req.wantSessionID = true + requests[i] = req + } + distributableHandler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{ + Stateless: true, + }) + t.Run("distributed", func(t *testing.T) { + testStreamableHandler(t, distributableHandler, requests) + testClientCompatibility(t, handler) + }) +} - // Verify we can make another request without session ID - checkRequest(`{"jsonrpc":"2.0","method":"tools/call","id":2,"params":{"name":"greet","arguments":{"name":"World"}}}`) +func textContent(t *testing.T, res *CallToolResult) string { + t.Helper() + if len(res.Content) != 1 { + t.Fatalf("len(Content) = %d, want 1", len(res.Content)) + } + text, ok := res.Content[0].(*TextContent) + if !ok { + t.Fatalf("Content[0] is %T, want *TextContent", res.Content[0]) + } + return text.Text } diff --git a/mcp/transport.go b/mcp/transport.go index 6d25de33..8018910b 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -40,8 +40,8 @@ type Transport interface { type Connection interface { // Read reads the next message to process off the connection. // - // Read need not be safe for concurrent use: Read is called in a - // concurrency-safe manner by the JSON-RPC library. + // Connections must allow Read to be called concurrently with Close. In + // particular, calling Close should unblock a Read waiting for input. Read(context.Context) (jsonrpc.Message, error) // Write writes a new message to the connection.