-
Notifications
You must be signed in to change notification settings - Fork 293
mcp/streamable: add resumability for the Streamable transport #133
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,6 +26,7 @@ import ( | |
| "github.com/google/go-cmp/cmp/cmpopts" | ||
| "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" | ||
| "github.com/modelcontextprotocol/go-sdk/jsonrpc" | ||
| "github.com/modelcontextprotocol/go-sdk/jsonschema" | ||
| ) | ||
|
|
||
| func TestStreamableTransports(t *testing.T) { | ||
|
|
@@ -108,37 +109,37 @@ func TestStreamableTransports(t *testing.T) { | |
| } | ||
| } | ||
|
|
||
| // TestClientReplayAfterProxyBreak verifies that the client can recover from a | ||
| // TestClientReplay verifies that the client can recover from a | ||
| // mid-stream network failure and receive replayed messages. It uses a proxy | ||
| // that is killed and restarted to simulate a recoverable network outage. | ||
| func TestClientReplayAfterProxyBreak(t *testing.T) { | ||
| func TestClientReplay(t *testing.T) { | ||
| notifications := make(chan string, 10) | ||
|
||
| // 1. Configure the real MCP server. | ||
| server := NewServer(testImpl, nil) | ||
|
|
||
| // Use a channel to synchronize the server's message sending with the test's | ||
| // proxy-killing action. | ||
| serverReadyToKillProxy := make(chan struct{}) | ||
| var serverClosed sync.WaitGroup | ||
| AddTool(server, &Tool{Name: "multiMessageTool"}, func(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[any]) (*CallToolResultFor[any], error) { | ||
| go func() { | ||
| bgCtx := context.Background() | ||
| // Send the first two messages immediately. | ||
| _ = ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg1"}) | ||
| _ = ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg2"}) | ||
|
|
||
| // Signal the test that it can now kill the proxy. | ||
| serverClosed.Add(1) | ||
| close(serverReadyToKillProxy) | ||
| // Wait for the test to kill the proxy before sending the rest. | ||
| serverClosed.Wait() | ||
|
|
||
| // These messages should be queued for replay by the server after | ||
| // the client's connection drops. | ||
| _ = ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg3"}) | ||
| _ = ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg4"}) | ||
| }() | ||
| return &CallToolResultFor[any]{}, nil | ||
| }) | ||
| serverClosed := make(chan struct{}) | ||
| server.AddTool(&Tool{Name: "multiMessageTool", InputSchema: &jsonschema.Schema{}}, | ||
| func(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[map[string]any]) (*CallToolResult, error) { | ||
| go func() { | ||
| bgCtx := context.Background() | ||
| // Send the first two messages immediately. | ||
| ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg1"}) | ||
| ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg2"}) | ||
|
|
||
| // Signal the test that it can now kill the proxy. | ||
| close(serverReadyToKillProxy) | ||
| <-serverClosed | ||
|
|
||
| // These messages should be queued for replay by the server after | ||
| // the client's connection drops. | ||
| ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg3"}) | ||
| ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg4"}) | ||
| }() | ||
| return &CallToolResult{}, nil | ||
| }) | ||
| realServer := httptest.NewServer(NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil)) | ||
| defer realServer.Close() | ||
| realServerURL, err := url.Parse(realServer.URL) | ||
|
|
@@ -152,52 +153,25 @@ func TestClientReplayAfterProxyBreak(t *testing.T) { | |
| proxyAddr := proxy.Listener.Addr().String() // Get the address to restart it later. | ||
|
|
||
| // 3. Configure the client to connect to the proxy with default options. | ||
| clientTransport := NewStreamableClientTransport(proxy.URL, &StreamableClientTransportOptions{ | ||
| ReconnectOptions: &StreamableReconnectOptions{ | ||
| maxDelay: 50 * time.Millisecond, | ||
| MaxRetries: 5, | ||
| growFactor: 1.0, | ||
| initialDelay: 10 * time.Millisecond, | ||
| }, | ||
| }) | ||
|
|
||
| ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) | ||
| defer cancel() | ||
|
|
||
| // 4. Connect, perform handshake, and trigger the tool. | ||
| conn, err := clientTransport.Connect(ctx) | ||
| client := NewClient(testImpl, &ClientOptions{ | ||
| ProgressNotificationHandler: func(ctx context.Context, cc *ClientSession, params *ProgressNotificationParams) { | ||
| notifications <- params.Message | ||
| }}) | ||
| clientSession, err := client.Connect(ctx, NewStreamableClientTransport(proxy.URL, nil)) | ||
| if err != nil { | ||
| t.Fatalf("Connect() failed: %v", err) | ||
| } | ||
|
|
||
| // Perform handshake. | ||
| initReq := &jsonrpc.Request{ID: jsonrpc2.Int64ID(100), Method: "initialize", Params: mustMarshal(t, &InitializeParams{})} | ||
| if err := conn.Write(ctx, initReq); err != nil { | ||
| t.Fatalf("Write(initialize) failed: %v", err) | ||
| } | ||
| if _, err := conn.Read(ctx); err != nil { | ||
| t.Fatalf("Read(initialize resp) failed: %v", err) | ||
| } | ||
| if err := conn.Write(ctx, &jsonrpc.Request{Method: "initialized", Params: mustMarshal(t, &InitializedParams{})}); err != nil { | ||
| t.Fatalf("Write(initialized) failed: %v", err) | ||
| } | ||
|
|
||
| callReq := &jsonrpc.Request{ID: jsonrpc2.Int64ID(1), Method: "tools/call", Params: mustMarshal(t, &CallToolParams{Name: "multiMessageTool"})} | ||
| if err := conn.Write(ctx, callReq); err != nil { | ||
| t.Fatalf("Write(tool/call) failed: %v", err) | ||
| t.Fatalf("client.Connect() failed: %v", err) | ||
| } | ||
| defer clientSession.Close() | ||
| clientSession.CallTool(ctx, &CallToolParams{Name: "multiMessageTool"}) | ||
|
|
||
| // 5. Read and verify messages until the server signals it's ready for the proxy kill. | ||
| receivedNotifications := readProgressNotifications(t, ctx, conn, 2) | ||
|
|
||
| wantReceived := []jsonrpc.Message{ | ||
| &jsonrpc.Request{Method: "notifications/progress", Params: mustMarshal(t, &ProgressNotificationParams{Message: "msg1"})}, | ||
| &jsonrpc.Request{Method: "notifications/progress", Params: mustMarshal(t, &ProgressNotificationParams{Message: "msg2"})}, | ||
| } | ||
| transform := cmpopts.AcyclicTransformer("jsonrpcid", func(id jsonrpc.ID) any { return id.Raw() }) | ||
| // 4. Read and verify messages until the server signals it's ready for the proxy kill. | ||
| receivedNotifications := readProgressNotifications(t, ctx, notifications, 2) | ||
|
|
||
| if diff := cmp.Diff(wantReceived, receivedNotifications, transform); diff != "" { | ||
| t.Errorf("Recovered notifications mismatch (-want +got):\n%s", diff) | ||
| wantReceived := []string{"msg1", "msg2"} | ||
| if diff := cmp.Diff(wantReceived, receivedNotifications); diff != "" { | ||
| t.Errorf("Received notifications mismatch (-want +got):\n%s", diff) | ||
| } | ||
|
|
||
| select { | ||
|
|
@@ -207,13 +181,13 @@ func TestClientReplayAfterProxyBreak(t *testing.T) { | |
| t.Fatalf("Context timed out before server was ready to kill proxy") | ||
| } | ||
|
|
||
| // 6. Simulate a total network failure by closing the proxy. | ||
| // 5. Simulate a total network failure by closing the proxy. | ||
| t.Log("--- Killing proxy to simulate network failure ---") | ||
| proxy.CloseClientConnections() | ||
| proxy.Close() | ||
| serverClosed.Done() | ||
| close(serverClosed) | ||
|
|
||
| // 7. Simulate network recovery by restarting the proxy on the same address. | ||
| // 6. Simulate network recovery by restarting the proxy on the same address. | ||
| t.Logf("--- Restarting proxy on %s ---", proxyAddr) | ||
| listener, err := net.Listen("tcp", proxyAddr) | ||
| if err != nil { | ||
|
|
@@ -223,39 +197,37 @@ func TestClientReplayAfterProxyBreak(t *testing.T) { | |
| go restartedProxy.Serve(listener) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. How is this goroutine stopped?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. restartedProxy.Close() will terminate it. "Serve always returns a non-nil error and closes l. After [Server.Shutdown] or [Server.Close], the returned error is [ErrServerClosed]." |
||
| defer restartedProxy.Close() | ||
|
|
||
| // 8. Continue reading from the same connection object. | ||
| // 7. Continue reading from the same connection object. | ||
| // Its internal logic should successfully retry, reconnect to the new proxy, | ||
| // and receive the replayed messages. | ||
| recoveredNotifications := readProgressNotifications(t, ctx, conn, 2) | ||
| recoveredNotifications := readProgressNotifications(t, ctx, notifications, 2) | ||
|
|
||
| // 9. Verify the correct messages were received on the recovered connection. | ||
| wantRecovered := []jsonrpc.Message{ | ||
| &jsonrpc.Request{Method: "notifications/progress", Params: mustMarshal(t, &ProgressNotificationParams{Message: "msg3"})}, | ||
| &jsonrpc.Request{Method: "notifications/progress", Params: mustMarshal(t, &ProgressNotificationParams{Message: "msg4"})}, | ||
| } | ||
| // 8. Verify the correct messages were received on the recovered connection. | ||
| wantRecovered := []string{"msg3", "msg4"} | ||
|
|
||
| if diff := cmp.Diff(wantRecovered, recoveredNotifications, transform); diff != "" { | ||
| if diff := cmp.Diff(wantRecovered, recoveredNotifications); diff != "" { | ||
| t.Errorf("Recovered notifications mismatch (-want +got):\n%s", diff) | ||
| } | ||
| } | ||
|
|
||
| // Helper to read a specific number of progress notifications. | ||
| func readProgressNotifications(t *testing.T, ctx context.Context, conn Connection, count int) []jsonrpc.Message { | ||
| func readProgressNotifications(t *testing.T, ctx context.Context, notifications chan string, count int) []string { | ||
| t.Helper() | ||
| var notifications []jsonrpc.Message | ||
| for len(notifications) < count && ctx.Err() == nil { | ||
| msg, err := conn.Read(ctx) | ||
| if err != nil { | ||
| t.Fatalf("Failed to read notification: %v", err) | ||
| } | ||
| if req, ok := msg.(*jsonrpc.Request); ok && req.Method == "notifications/progress" { | ||
| notifications = append(notifications, req) | ||
| var collectedNotifications []string | ||
| for { | ||
| select { | ||
| case n := <-notifications: | ||
| collectedNotifications = append(collectedNotifications, n) | ||
| if len(collectedNotifications) == count { | ||
| return collectedNotifications | ||
| } | ||
| case <-ctx.Done(): | ||
| if len(collectedNotifications) != count { | ||
| t.Fatalf("readProgressNotifications(): did not receive expected notifications, got %d, want %d", len(collectedNotifications), count) | ||
| } | ||
| return collectedNotifications | ||
| } | ||
| } | ||
| if len(notifications) != count { | ||
| t.Fatalf("Expected to read %d notifications, but got %d", count, len(notifications)) | ||
| } | ||
| return notifications | ||
| } | ||
|
|
||
| func TestStreamableServerTransport(t *testing.T) { | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.