diff --git a/internal/jsonrpc2/conn.go b/internal/jsonrpc2/conn.go index 49902b00..5549ee1c 100644 --- a/internal/jsonrpc2/conn.go +++ b/internal/jsonrpc2/conn.go @@ -483,10 +483,32 @@ func (c *Connection) Cancel(id ID) { // Wait blocks until the connection is fully closed, but does not close it. func (c *Connection) Wait() error { + return c.wait(true) +} + +// wait waits for the connection to close, and reports the most relevant cause +// of its termination, if abnormal. +// +// The fromWait argument allows this logic to be shared with Close, where we +// only want to expose the closeErr. +// +// (Previously, Wait also only returned the closeErr, which was misleading if +// the connection was broken for another reason). +func (c *Connection) wait(fromWait bool) error { var err error <-c.done c.updateInFlight(func(s *inFlightState) { - err = s.closeErr + if fromWait { + if !errors.Is(s.readErr, io.EOF) { + err = s.readErr + } + if err == nil && !errors.Is(s.writeErr, io.EOF) { + err = s.writeErr + } + } + if err == nil { + err = s.closeErr + } }) return err } @@ -502,8 +524,7 @@ func (c *Connection) Close() error { // Stop handling new requests, and interrupt the reader (by closing the // connection) as soon as the active requests finish. c.updateInFlight(func(s *inFlightState) { s.connClosing = true }) - - return c.Wait() + return c.wait(false) } // readIncoming collects inbound messages from the reader and delivers them, either responding diff --git a/mcp/streamable.go b/mcp/streamable.go index 1eef9a74..3e3cc6f8 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1119,6 +1119,17 @@ type streamableClientConn struct { sessionID string } +// errSessionMissing distinguishes if the session is known to not be present on +// the server (see [streamableClientConn.fail]). +// +// TODO(rfindley): should we expose this error value (and its corresponding +// API) to the user? +// +// The spec says that if the server returns 404, clients should reestablish +// a session. 
For now, we delegate that to the user, but do they need a way to +// differentiate a 'NotFound' error from other errors? +var errSessionMissing = errors.New("session not found") + var _ clientConnection = (*streamableClientConn)(nil) func (c *streamableClientConn) sessionUpdated(state clientSessionState) { @@ -1146,6 +1157,10 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) { // // If err is non-nil, it is terminal, and subsequent (or pending) Reads will // fail. +// +// If err wraps errSessionMissing, the failure indicates that the session is no +// longer present on the server, and no final DELETE will be performed when +// closing the connection. func (c *streamableClientConn) fail(err error) { if err != nil { c.failOnce.Do(func() { @@ -1193,9 +1208,19 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e return err } + var requestSummary string + switch msg := msg.(type) { + case *jsonrpc.Request: + requestSummary = fmt.Sprintf("sending %q", msg.Method) + case *jsonrpc.Response: + requestSummary = fmt.Sprintf("sending jsonrpc response #%d", msg.ID) + default: + panic("unreachable") + } + data, err := jsonrpc.EncodeMessage(msg) if err != nil { - return err + return fmt.Errorf("%s: %v", requestSummary, err) } req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(data)) @@ -1208,9 +1233,21 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e resp, err := c.client.Do(req) if err != nil { + return fmt.Errorf("%s: %v", requestSummary, err) + } + + // Section 2.5.3: "The server MAY terminate the session at any time, after + // which it MUST respond to requests containing that session ID with HTTP + // 404 Not Found." 
+ if resp.StatusCode == http.StatusNotFound { + // Fail the session immediately, rather than relying on jsonrpc2 to fail + // (and close) it, because we want the call to Close to know that this + // session is missing (and therefore not send the DELETE). + err := fmt.Errorf("%s: failed to send: %w", requestSummary, errSessionMissing) + c.fail(err) + resp.Body.Close() return err } - if resp.StatusCode < 200 || resp.StatusCode >= 300 { resp.Body.Close() return fmt.Errorf("broken session: %v", resp.Status) @@ -1233,16 +1270,6 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e return nil } - var requestSummary string - switch msg := msg.(type) { - case *jsonrpc.Request: - requestSummary = fmt.Sprintf("sending %q", msg.Method) - case *jsonrpc.Response: - requestSummary = fmt.Sprintf("sending jsonrpc response #%d", msg.ID) - default: - panic("unreachable") - } - switch ct := resp.Header.Get("Content-Type"); ct { case "application/json": go c.handleJSON(requestSummary, resp) @@ -1333,6 +1360,12 @@ func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *htt resp.Body.Close() return } + // (see equivalent handling in [streamableClientConn.Write]). + if resp.StatusCode == http.StatusNotFound { + resp.Body.Close() + c.fail(fmt.Errorf("%s: failed to reconnect: %w", requestSummary, errSessionMissing)) + return + } if resp.StatusCode < 200 || resp.StatusCode >= 300 { resp.Body.Close() c.fail(fmt.Errorf("%s: failed to reconnect: %v", requestSummary, http.StatusText(resp.StatusCode))) @@ -1423,13 +1456,17 @@ func (c *streamableClientConn) Close() error { c.cancel() close(c.done) - req, err := http.NewRequest(http.MethodDelete, c.url, nil) - if err != nil { - c.closeErr = err + if errors.Is(c.failure(), errSessionMissing) { + // If the session is missing, no need to delete it. 
+ } else { - c.setMCPHeaders(req) - if _, err := c.client.Do(req); err != nil { + req, err := http.NewRequest(http.MethodDelete, c.url, nil) + if err != nil { c.closeErr = err + } else { + c.setMCPHeaders(req) + if _, err := c.client.Do(req); err != nil { + c.closeErr = err + } } } }) diff --git a/mcp/streamable_client_test.go b/mcp/streamable_client_test.go index fe87b21c..001d3a64 100644 --- a/mcp/streamable_client_test.go +++ b/mcp/streamable_client_test.go @@ -29,7 +29,7 @@ type streamableRequestKey struct { type header map[string]string type streamableResponse struct { - header header + header header // response headers status int // or http.StatusOK body string // or "" optional bool // if set, request need not be sent @@ -187,6 +187,56 @@ func TestStreamableClientTransportLifecycle(t *testing.T) { } } +func TestStreamableClientRedundantDelete(t *testing.T) { + ctx := context.Background() + + // This test checks how the client handles a session that the server has + // terminated: + // - a POST answered with 404 Not Found must fail the call, and Wait must not hang + // - no DELETE response is registered, as the client should not send one + fake := &fakeStreamableServer{ + t: t, + responses: fakeResponses{ + {"POST", "", methodInitialize}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "123", + }, + body: jsonBody(t, initResp), + }, + {"POST", "123", notificationInitialized}: { + status: http.StatusAccepted, + wantProtocolVersion: latestProtocolVersion, + }, + {"GET", "123", ""}: { + status: http.StatusMethodNotAllowed, + optional: true, + }, + {"POST", "123", methodListTools}: { + status: http.StatusNotFound, + }, + }, + } + + httpServer := httptest.NewServer(fake) + defer httpServer.Close() + + transport := &StreamableClientTransport{Endpoint: httpServer.URL} + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport, nil) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + _, err = session.ListTools(ctx, nil) + if err == 
nil { + t.Errorf("Listing tools: got nil error, want non-nil") + } + _ = session.Wait() // must not hang + if missing := fake.missingRequests(); len(missing) > 0 { + t.Errorf("did not receive expected requests: %v", missing) + } +} + func TestStreamableClientGETHandling(t *testing.T) { ctx := context.Background() diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index a85fbec0..f0da3dc9 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -190,6 +190,80 @@ func TestStreamableTransports(t *testing.T) { } } +func TestStreamableServerShutdown(t *testing.T) { + ctx := context.Background() + + // This test checks that closing the streamable HTTP server actually results + // in client session termination, provided one of the following holds: + // 1. The server is stateful, and therefore the hanging GET fails the connection. + // 2. The server is stateless, and the client uses a KeepAlive. + tests := []struct { + name string + stateless, keepalive bool + }{ + {"stateful", false, false}, + {"stateless with keepalive", true, true}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + server := NewServer(testImpl, nil) + // Add a tool, just so we can check things are working. + AddTool(server, &Tool{Name: "greet"}, sayHi) + + handler := NewStreamableHTTPHandler( + func(req *http.Request) *Server { return server }, + &StreamableHTTPOptions{Stateless: test.stateless}) + + // When we shut down the server, we need to explicitly close ongoing + // connections. Otherwise, the hanging GET may never terminate. + httpServer := httptest.NewUnstartedServer(handler) + httpServer.Config.RegisterOnShutdown(func() { + for session := range server.Sessions() { + session.Close() + } + }) + httpServer.Start() + defer httpServer.Close() + + // Connect and run a tool. 
+ var opts ClientOptions + if test.keepalive { + opts.KeepAlive = 50 * time.Millisecond + } + client := NewClient(testImpl, &opts) + clientSession, err := client.Connect(ctx, &StreamableClientTransport{ + Endpoint: httpServer.URL, + MaxRetries: -1, // avoid slow tests during exponential retries + }, nil) + if err != nil { + t.Fatal(err) + } + + params := &CallToolParams{ + Name: "greet", + Arguments: map[string]any{"name": "foo"}, + } + // Verify that we can call a tool. + if _, err := clientSession.CallTool(ctx, params); err != nil { + t.Fatalf("CallTool() failed: %v", err) + } + + // Shut down the server. Sessions should terminate. + go func() { + if err := httpServer.Config.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) { + t.Errorf("closing http server: %v", err) + } + }() + + // Wait may return an error (after all, the connection failed), but it + // should not hang. + t.Log("Client waiting") + _ = clientSession.Wait() + }) + } +} + // TestClientReplay verifies that the client can recover from a mid-stream // network failure and receive replayed messages (if replay is configured). It // uses a proxy that is killed and restarted to simulate a recoverable network