20 changes: 8 additions & 12 deletions mcp/streamable.go
@@ -826,12 +826,12 @@ func (s *streamableClientConn) handleSSE(initialResp *http.Response) {
}

// The stream was interrupted or ended by the server. Attempt to reconnect.
newResp, reconnectErr := s.reconnect(lastEventID)
if reconnectErr != nil {
newResp, err := s.reconnect(lastEventID)
if err != nil {
// All reconnection attempts failed. Set the final error, close the
// connection, and exit the goroutine.
s.mu.Lock()
s.err = reconnectErr
s.err = err
s.mu.Unlock()
s.Close()
return
@@ -881,9 +881,9 @@ func (s *streamableClientConn) reconnect(lastEventID string) (*http.Response, error) {
case <-s.done:
return nil, fmt.Errorf("connection closed by client during reconnect")
case <-time.After(calculateReconnectDelay(s.ReconnectOptions, attempt)):
resp, reconnectErr := s.establishSSE(lastEventID)
if reconnectErr != nil {
finalErr = reconnectErr // Store the error and try again.
resp, err := s.establishSSE(lastEventID)
if err != nil {
finalErr = err // Store the error and try again.
continue
}

@@ -910,10 +910,7 @@ func isResumable(resp *http.Response) bool {
return false
}

if !strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") {
return false
}
return true
return strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream")
}

// Close implements the [Connection] interface.
@@ -958,8 +955,7 @@ func (s *streamableClientConn) establishSSE(lastEventID string) (*http.Response, error) {
}
req.Header.Set("Accept", "text/event-stream")

resp, err := s.client.Do(req)
return resp, err
return s.client.Do(req)
}

// calculateReconnectDelay calculates a delay using exponential backoff with full jitter.
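The body of calculateReconnectDelay is collapsed in this view. As a rough, hypothetical sketch of exponential backoff with full jitter — reusing the option field names (initialDelay, growFactor, maxDelay) that appear in the test's removed ReconnectOptions below, not the SDK's actual implementation — it could look like:

package sketch

import (
	"math"
	"math/rand"
	"time"
)

// reconnectOptions is a hypothetical stand-in for StreamableReconnectOptions;
// the field names are taken from the removed test options in streamable_test.go.
type reconnectOptions struct {
	initialDelay time.Duration
	growFactor   float64
	maxDelay     time.Duration
}

// calculateReconnectDelaySketch illustrates exponential backoff with full
// jitter: the ceiling grows by growFactor on every attempt, is capped at
// maxDelay, and the actual delay is drawn uniformly at random from [0, ceiling).
func calculateReconnectDelaySketch(opts reconnectOptions, attempt int) time.Duration {
	ceiling := float64(opts.initialDelay) * math.Pow(opts.growFactor, float64(attempt))
	if limit := float64(opts.maxDelay); ceiling > limit {
		ceiling = limit
	}
	return time.Duration(rand.Float64() * ceiling)
}

With growFactor 1.0, as in the test options this commit removes, the ceiling never grows and only the jitter varies the delay between retries.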
146 changes: 59 additions & 87 deletions mcp/streamable_test.go
@@ -26,6 +26,7 @@ import (
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
"github.com/modelcontextprotocol/go-sdk/jsonschema"
)

func TestStreamableTransports(t *testing.T) {
@@ -108,37 +109,37 @@ func TestStreamableTransports(t *testing.T) {
}
}

// TestClientReplayAfterProxyBreak verifies that the client can recover from a
// TestClientReplay verifies that the client can recover from a
// mid-stream network failure and receive replayed messages. It uses a proxy
// that is killed and restarted to simulate a recoverable network outage.
func TestClientReplayAfterProxyBreak(t *testing.T) {
func TestClientReplay(t *testing.T) {
notifications := make(chan string, 10)
Contributor: Document how the channel buffer size affects correctness. (e.g. must be at least as large as the number of notifications sent in the tool handler on L125?)

Contributor Author: Looks like buffer size has no effect since we collect them all. Removed the size of 10.
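To illustrate the point in this thread with a standalone toy (hypothetical, not part of the PR): because the test drains every notification from the channel, a buffer only changes whether the sending goroutine blocks, not which messages eventually arrive.

package main

import "fmt"

func main() {
	msgs := make(chan string) // unbuffered: each send blocks until the reader takes it
	go func() {
		for _, m := range []string{"msg1", "msg2", "msg3", "msg4"} {
			msgs <- m // may block, but nothing is ever dropped
		}
	}()
	var got []string
	for len(got) < 4 {
		got = append(got, <-msgs)
	}
	fmt.Println(got) // [msg1 msg2 msg3 msg4], with or without a buffer
}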
// 1. Configure the real MCP server.
server := NewServer(testImpl, nil)

// Use a channel to synchronize the server's message sending with the test's
// proxy-killing action.
serverReadyToKillProxy := make(chan struct{})
var serverClosed sync.WaitGroup
AddTool(server, &Tool{Name: "multiMessageTool"}, func(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[any]) (*CallToolResultFor[any], error) {
go func() {
bgCtx := context.Background()
// Send the first two messages immediately.
_ = ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg1"})
_ = ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg2"})

// Signal the test that it can now kill the proxy.
serverClosed.Add(1)
close(serverReadyToKillProxy)
// Wait for the test to kill the proxy before sending the rest.
serverClosed.Wait()

// These messages should be queued for replay by the server after
// the client's connection drops.
_ = ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg3"})
_ = ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg4"})
}()
return &CallToolResultFor[any]{}, nil
})
serverClosed := make(chan struct{})
server.AddTool(&Tool{Name: "multiMessageTool", InputSchema: &jsonschema.Schema{}},
func(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[map[string]any]) (*CallToolResult, error) {
go func() {
bgCtx := context.Background()
// Send the first two messages immediately.
ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg1"})
ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg2"})

// Signal the test that it can now kill the proxy.
close(serverReadyToKillProxy)
<-serverClosed

// These messages should be queued for replay by the server after
// the client's connection drops.
ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg3"})
ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg4"})
}()
return &CallToolResult{}, nil
})
realServer := httptest.NewServer(NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil))
defer realServer.Close()
realServerURL, err := url.Parse(realServer.URL)
@@ -152,52 +153,25 @@ func TestClientReplayAfterProxyBreak(t *testing.T) {
proxyAddr := proxy.Listener.Addr().String() // Get the address to restart it later.

// 3. Configure the client to connect to the proxy with default options.
clientTransport := NewStreamableClientTransport(proxy.URL, &StreamableClientTransportOptions{
ReconnectOptions: &StreamableReconnectOptions{
maxDelay: 50 * time.Millisecond,
MaxRetries: 5,
growFactor: 1.0,
initialDelay: 10 * time.Millisecond,
},
})

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

// 4. Connect, perform handshake, and trigger the tool.
conn, err := clientTransport.Connect(ctx)
client := NewClient(testImpl, &ClientOptions{
ProgressNotificationHandler: func(ctx context.Context, cc *ClientSession, params *ProgressNotificationParams) {
notifications <- params.Message
}})
clientSession, err := client.Connect(ctx, NewStreamableClientTransport(proxy.URL, nil))
if err != nil {
t.Fatalf("Connect() failed: %v", err)
}

// Perform handshake.
initReq := &jsonrpc.Request{ID: jsonrpc2.Int64ID(100), Method: "initialize", Params: mustMarshal(t, &InitializeParams{})}
if err := conn.Write(ctx, initReq); err != nil {
t.Fatalf("Write(initialize) failed: %v", err)
}
if _, err := conn.Read(ctx); err != nil {
t.Fatalf("Read(initialize resp) failed: %v", err)
}
if err := conn.Write(ctx, &jsonrpc.Request{Method: "initialized", Params: mustMarshal(t, &InitializedParams{})}); err != nil {
t.Fatalf("Write(initialized) failed: %v", err)
}

callReq := &jsonrpc.Request{ID: jsonrpc2.Int64ID(1), Method: "tools/call", Params: mustMarshal(t, &CallToolParams{Name: "multiMessageTool"})}
if err := conn.Write(ctx, callReq); err != nil {
t.Fatalf("Write(tool/call) failed: %v", err)
t.Fatalf("client.Connect() failed: %v", err)
}
defer clientSession.Close()
clientSession.CallTool(ctx, &CallToolParams{Name: "multiMessageTool"})

// 5. Read and verify messages until the server signals it's ready for the proxy kill.
receivedNotifications := readProgressNotifications(t, ctx, conn, 2)

wantReceived := []jsonrpc.Message{
&jsonrpc.Request{Method: "notifications/progress", Params: mustMarshal(t, &ProgressNotificationParams{Message: "msg1"})},
&jsonrpc.Request{Method: "notifications/progress", Params: mustMarshal(t, &ProgressNotificationParams{Message: "msg2"})},
}
transform := cmpopts.AcyclicTransformer("jsonrpcid", func(id jsonrpc.ID) any { return id.Raw() })
// 4. Read and verify messages until the server signals it's ready for the proxy kill.
receivedNotifications := readProgressNotifications(t, ctx, notifications, 2)

if diff := cmp.Diff(wantReceived, receivedNotifications, transform); diff != "" {
t.Errorf("Recovered notifications mismatch (-want +got):\n%s", diff)
wantReceived := []string{"msg1", "msg2"}
if diff := cmp.Diff(wantReceived, receivedNotifications); diff != "" {
t.Errorf("Received notifications mismatch (-want +got):\n%s", diff)
}

select {
@@ -207,13 +181,13 @@ func TestClientReplayAfterProxyBreak(t *testing.T) {
t.Fatalf("Context timed out before server was ready to kill proxy")
}

// 6. Simulate a total network failure by closing the proxy.
// 5. Simulate a total network failure by closing the proxy.
t.Log("--- Killing proxy to simulate network failure ---")
proxy.CloseClientConnections()
proxy.Close()
serverClosed.Done()
close(serverClosed)

// 7. Simulate network recovery by restarting the proxy on the same address.
// 6. Simulate network recovery by restarting the proxy on the same address.
t.Logf("--- Restarting proxy on %s ---", proxyAddr)
listener, err := net.Listen("tcp", proxyAddr)
if err != nil {
@@ -223,39 +197,37 @@ func TestClientReplayAfterProxyBreak(t *testing.T) {
go restartedProxy.Serve(listener)
Contributor: How is this goroutine stopped?

Contributor Author: restartedProxy.Close() will terminate it.

"Serve always returns a non-nil error and closes l. After [Server.Shutdown] or [Server.Close], the returned error is [ErrServerClosed]."
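As a standalone illustration of the lifecycle quoted above (a hypothetical example, not test code from this PR): closing the http.Server unblocks the Serve goroutine, which returns ErrServerClosed.

package main

import (
	"errors"
	"fmt"
	"net"
	"net/http"
)

func main() {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		panic(err)
	}
	srv := &http.Server{}

	served := make(chan error, 1)
	go func() { served <- srv.Serve(ln) }() // runs until the server is closed

	srv.Close() // closes the listener and unblocks Serve

	fmt.Println(errors.Is(<-served, http.ErrServerClosed)) // true
}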
defer restartedProxy.Close()

// 8. Continue reading from the same connection object.
// 7. Continue reading from the same connection object.
// Its internal logic should successfully retry, reconnect to the new proxy,
// and receive the replayed messages.
recoveredNotifications := readProgressNotifications(t, ctx, conn, 2)
recoveredNotifications := readProgressNotifications(t, ctx, notifications, 2)

// 9. Verify the correct messages were received on the recovered connection.
wantRecovered := []jsonrpc.Message{
&jsonrpc.Request{Method: "notifications/progress", Params: mustMarshal(t, &ProgressNotificationParams{Message: "msg3"})},
&jsonrpc.Request{Method: "notifications/progress", Params: mustMarshal(t, &ProgressNotificationParams{Message: "msg4"})},
}
// 8. Verify the correct messages were received on the recovered connection.
wantRecovered := []string{"msg3", "msg4"}

if diff := cmp.Diff(wantRecovered, recoveredNotifications, transform); diff != "" {
if diff := cmp.Diff(wantRecovered, recoveredNotifications); diff != "" {
t.Errorf("Recovered notifications mismatch (-want +got):\n%s", diff)
}
}

// Helper to read a specific number of progress notifications.
func readProgressNotifications(t *testing.T, ctx context.Context, conn Connection, count int) []jsonrpc.Message {
func readProgressNotifications(t *testing.T, ctx context.Context, notifications chan string, count int) []string {
t.Helper()
var notifications []jsonrpc.Message
for len(notifications) < count && ctx.Err() == nil {
msg, err := conn.Read(ctx)
if err != nil {
t.Fatalf("Failed to read notification: %v", err)
}
if req, ok := msg.(*jsonrpc.Request); ok && req.Method == "notifications/progress" {
notifications = append(notifications, req)
var collectedNotifications []string
for {
select {
case n := <-notifications:
collectedNotifications = append(collectedNotifications, n)
if len(collectedNotifications) == count {
return collectedNotifications
}
case <-ctx.Done():
if len(collectedNotifications) != count {
t.Fatalf("readProgressNotifications(): did not receive expected notifications, got %d, want %d", len(collectedNotifications), count)
}
return collectedNotifications
}
}
if len(notifications) != count {
t.Fatalf("Expected to read %d notifications, but got %d", count, len(notifications))
}
return notifications
}

func TestStreamableServerTransport(t *testing.T) {