Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions internal/jsonrpc2/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -483,10 +483,32 @@ func (c *Connection) Cancel(id ID) {

// Wait blocks until the connection is fully closed, but does not itself close
// it.
//
// The returned error may be a read or write failure (other than io.EOF) that
// terminated the connection, rather than only the error produced by closing
// it.
func (c *Connection) Wait() error {
	// Name the flag so the call site reads clearly: the caller is Wait, so
	// read/write failures should be surfaced as well as the close error.
	const fromWait = true
	return c.wait(fromWait)
}

// wait blocks until the connection is closed, then reports the most relevant
// cause of its termination, if it terminated abnormally.
//
// When fromWait is true, the candidate causes are considered in a fixed
// priority order: the read error, then the write error (io.EOF is treated as a
// normal end of stream for both), and finally the close error. When fromWait
// is false, only the close error is reported; this allows the logic to be
// shared with Close, where we only want to expose the closeErr.
//
// (Previously, Wait also only returned the closeErr, which was misleading if
// the connection was broken for another reason).
func (c *Connection) wait(fromWait bool) error {
	var err error
	<-c.done
	c.updateInFlight(func(s *inFlightState) {
		// err deliberately starts out nil rather than as s.closeErr: seeding
		// it with the close error would let closeErr mask a genuine writeErr
		// whenever readErr happens to be io.EOF, breaking the priority order
		// documented above.
		if fromWait {
			if !errors.Is(s.readErr, io.EOF) {
				err = s.readErr
			}
			if err == nil && !errors.Is(s.writeErr, io.EOF) {
				err = s.writeErr
			}
		}
		// Fall back to the close error when nothing more specific was found
		// (or when called from Close).
		if err == nil {
			err = s.closeErr
		}
	})
	return err
}
Expand All @@ -502,8 +524,7 @@ func (c *Connection) Close() error {
// Stop handling new requests, and interrupt the reader (by closing the
// connection) as soon as the active requests finish.
c.updateInFlight(func(s *inFlightState) { s.connClosing = true })

return c.Wait()
return c.wait(false)
}

// readIncoming collects inbound messages from the reader and delivers them, either responding
Expand Down
70 changes: 53 additions & 17 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -1119,6 +1119,17 @@ type streamableClientConn struct {
sessionID string
}

// errSessionMissing is a sentinel error indicating that the session is known
// not to be present on the server (see [streamableClientConn.fail]).
//
// It is meant to be wrapped (with %w) in the error passed to fail, and
// detected with errors.Is.
//
// TODO(rfindley): should we expose this error value (and its corresponding
// API) to the user?
//
// The spec says that if the server returns 404, clients should reestablish
// a session. For now, we delegate that to the user, but do they need a way to
// differentiate a 'NotFound' error from other errors?
var errSessionMissing = errors.New("session not found")

var _ clientConnection = (*streamableClientConn)(nil)

func (c *streamableClientConn) sessionUpdated(state clientSessionState) {
Expand Down Expand Up @@ -1146,6 +1157,10 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) {
//
// If err is non-nil, it is terminal, and subsequent (or pending) Reads will
// fail.
//
// If err wraps errSessionMissing, the failure indicates that the session is no
// longer present on the server, and no final DELETE will be performed when
// closing the connection.
func (c *streamableClientConn) fail(err error) {
if err != nil {
c.failOnce.Do(func() {
Expand Down Expand Up @@ -1193,9 +1208,19 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
return err
}

var requestSummary string
switch msg := msg.(type) {
case *jsonrpc.Request:
requestSummary = fmt.Sprintf("sending %q", msg.Method)
case *jsonrpc.Response:
requestSummary = fmt.Sprintf("sending jsonrpc response #%d", msg.ID)
default:
panic("unreachable")
}

data, err := jsonrpc.EncodeMessage(msg)
if err != nil {
return err
return fmt.Errorf("%s: %v", requestSummary, err)
}

req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(data))
Expand All @@ -1208,9 +1233,21 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e

resp, err := c.client.Do(req)
if err != nil {
return fmt.Errorf("%s: %v", requestSummary, err)
}

// Section 2.5.3: "The server MAY terminate the session at any time, after
// which it MUST respond to requests containing that session ID with HTTP
// 404 Not Found."
if resp.StatusCode == http.StatusNotFound {
// Fail the session immediately, rather than relying on jsonrpc2 to fail
// (and close) it, because we want the call to Close to know that this
// session is missing (and therefore not send the DELETE).
err := fmt.Errorf("%s: failed to send: %w", requestSummary, errSessionMissing)
c.fail(err)
resp.Body.Close()
return err
}

if resp.StatusCode < 200 || resp.StatusCode >= 300 {
resp.Body.Close()
return fmt.Errorf("broken session: %v", resp.Status)
Expand All @@ -1233,16 +1270,6 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
return nil
}

var requestSummary string
switch msg := msg.(type) {
case *jsonrpc.Request:
requestSummary = fmt.Sprintf("sending %q", msg.Method)
case *jsonrpc.Response:
requestSummary = fmt.Sprintf("sending jsonrpc response #%d", msg.ID)
default:
panic("unreachable")
}

switch ct := resp.Header.Get("Content-Type"); ct {
case "application/json":
go c.handleJSON(requestSummary, resp)
Expand Down Expand Up @@ -1333,6 +1360,11 @@ func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *htt
resp.Body.Close()
return
}
// (see equivalent handling in [streamableClientConn.Write]).
if resp.StatusCode == http.StatusNotFound {
c.fail(fmt.Errorf("%s: failed to reconnect: %w", requestSummary, errSessionMissing))
return
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
resp.Body.Close()
c.fail(fmt.Errorf("%s: failed to reconnect: %v", requestSummary, http.StatusText(resp.StatusCode)))
Expand Down Expand Up @@ -1423,13 +1455,17 @@ func (c *streamableClientConn) Close() error {
c.cancel()
close(c.done)

req, err := http.NewRequest(http.MethodDelete, c.url, nil)
if err != nil {
c.closeErr = err
if errors.Is(c.failure(), errSessionMissing) {
// If the session is missing, no need to delete it.
} else {
c.setMCPHeaders(req)
if _, err := c.client.Do(req); err != nil {
req, err := http.NewRequest(http.MethodDelete, c.url, nil)
if err != nil {
c.closeErr = err
} else {
c.setMCPHeaders(req)
if _, err := c.client.Do(req); err != nil {
c.closeErr = err
}
}
}
})
Expand Down
52 changes: 51 additions & 1 deletion mcp/streamable_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ type streamableRequestKey struct {
type header map[string]string

type streamableResponse struct {
header header
header header // response headers
status int // or http.StatusOK
body string // or ""
optional bool // if set, request need not be sent
Expand Down Expand Up @@ -187,6 +187,56 @@ func TestStreamableClientTransportLifecycle(t *testing.T) {
}
}

func TestStreamableClientRedundantDelete(t *testing.T) {
ctx := context.Background()

// The lifecycle test verifies various behavior of the streamable client
// initialization:
// - check that it can handle application/json responses
// - check that it sends the negotiated protocol version
fake := &fakeStreamableServer{
t: t,
responses: fakeResponses{
{"POST", "", methodInitialize}: {
header: header{
"Content-Type": "application/json",
sessionIDHeader: "123",
},
body: jsonBody(t, initResp),
},
{"POST", "123", notificationInitialized}: {
status: http.StatusAccepted,
wantProtocolVersion: latestProtocolVersion,
},
{"GET", "123", ""}: {
status: http.StatusMethodNotAllowed,
optional: true,
},
{"POST", "123", methodListTools}: {
status: http.StatusNotFound,
},
},
}

httpServer := httptest.NewServer(fake)
defer httpServer.Close()

transport := &StreamableClientTransport{Endpoint: httpServer.URL}
client := NewClient(testImpl, nil)
session, err := client.Connect(ctx, transport, nil)
if err != nil {
t.Fatalf("client.Connect() failed: %v", err)
}
_, err = session.ListTools(ctx, nil)
if err == nil {
t.Errorf("Listing tools: got nil error, want non-nil")
}
_ = session.Wait() // must not hang
if missing := fake.missingRequests(); len(missing) > 0 {
t.Errorf("did not receive expected requests: %v", missing)
}
}

func TestStreamableClientGETHandling(t *testing.T) {
ctx := context.Background()

Expand Down
74 changes: 74 additions & 0 deletions mcp/streamable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,80 @@ func TestStreamableTransports(t *testing.T) {
}
}

// TestStreamableServerShutdown checks that shutting down the streamable HTTP
// server causes a connected client session's Wait to return, for each of the
// server configurations listed in the test table.
func TestStreamableServerShutdown(t *testing.T) {
ctx := context.Background()

// This test checks that closing the streamable HTTP server actually results
// in client session termination, provided one of the following holds:
// 1. The server is stateful, and therefore the hanging GET fails the connection.
// 2. The server is stateless, and the client uses a KeepAlive.
tests := []struct {
name string
stateless, keepalive bool
}{
{"stateful", false, false},
{"stateless with keepalive", true, true},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
server := NewServer(testImpl, nil)
// Add a tool, just so we can check things are working.
AddTool(server, &Tool{Name: "greet"}, sayHi)

// Every request is served by the same Server instance.
handler := NewStreamableHTTPHandler(
func(req *http.Request) *Server { return server },
&StreamableHTTPOptions{Stateless: test.stateless})

// When we shut down the server, we need to explicitly close ongoing
// connections. Otherwise, the hanging GET may never terminate.
//
// The hook must be registered before the server starts, hence
// NewUnstartedServer followed by an explicit Start.
httpServer := httptest.NewUnstartedServer(handler)
httpServer.Config.RegisterOnShutdown(func() {
for session := range server.Sessions() {
session.Close()
}
})
httpServer.Start()
defer httpServer.Close()

// Connect and run a tool.
var opts ClientOptions
if test.keepalive {
// A short keepalive interval keeps the test fast.
opts.KeepAlive = 50 * time.Millisecond
}
client := NewClient(testImpl, &opts)
clientSession, err := client.Connect(ctx, &StreamableClientTransport{
Endpoint: httpServer.URL,
MaxRetries: -1, // avoid slow tests during exponential retries
}, nil)
if err != nil {
t.Fatal(err)
}

params := &CallToolParams{
Name: "greet",
Arguments: map[string]any{"name": "foo"},
}
// Verify that we can call a tool.
if _, err := clientSession.CallTool(ctx, params); err != nil {
t.Fatalf("CallTool() failed: %v", err)
}

// Shut down the server. Sessions should terminate.
//
// Shutdown runs in its own goroutine so that the test goroutine is free
// to block in Wait below.
// NOTE(review): nothing waits for this goroutine, so t.Errorf could be
// called after the subtest has returned -- confirm whether that can
// happen in practice and whether it needs synchronization.
go func() {
if err := httpServer.Config.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) {
t.Errorf("closing http server: %v", err)
}
}()

// Wait may return an error (after all, the connection failed), but it
// should not hang.
t.Log("Client waiting")
_ = clientSession.Wait()
})
}
}

// TestClientReplay verifies that the client can recover from a mid-stream
// network failure and receive replayed messages (if replay is configured). It
// uses a proxy that is killed and restarted to simulate a recoverable network
Expand Down