Skip to content

Commit d936e1f

Browse files
authored
[v18][flaky] TestReplaceHTTPResponse and TestHTTPReaderWriter to use synctest (#61308)
1 parent 2a9476b commit d936e1f

File tree

1 file changed

+138
-90
lines changed

1 file changed

+138
-90
lines changed

lib/utils/mcputils/http_test.go

Lines changed: 138 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
//go:build go1.24 && enablesynctest
2+
13
/*
24
* Teleport
35
* Copyright (C) 2025 Gravitational, Inc.
@@ -26,10 +28,11 @@ import (
2628
"log/slog"
2729
"maps"
2830
"net/http"
31+
"os"
2932
"sync"
3033
"sync/atomic"
3134
"testing"
32-
"time"
35+
"testing/synctest"
3336

3437
"github.com/gravitational/trace"
3538
mcpclient "github.com/mark3labs/mcp-go/client"
@@ -39,66 +42,86 @@ import (
3942
"github.com/stretchr/testify/assert"
4043
"github.com/stretchr/testify/require"
4144

45+
listenerutils "github.com/gravitational/teleport/lib/utils/listener"
46+
"github.com/gravitational/teleport/lib/utils/log/logtest"
4247
"github.com/gravitational/teleport/lib/utils/mcptest"
4348
)
4449

50+
func TestMain(m *testing.M) {
51+
logtest.InitLogger(testing.Verbose)
52+
os.Exit(m.Run())
53+
}
54+
4555
func TestReplaceHTTPResponse(t *testing.T) {
46-
t.Parallel()
47-
ctx := t.Context()
48-
49-
// Set up a server.
50-
mcpServer := mcptest.NewServerWithVersion("11.22.33")
51-
httpServer := mcpserver.NewTestStreamableHTTPServer(mcpServer)
52-
t.Cleanup(httpServer.Close)
53-
54-
// Set up a client with custom transport which calls "ReplaceHTTPResponse".
55-
httpClientTransport := newTestReplaceHTTPResponseTransport()
56-
mcpClientTransport, err := mcpclienttransport.NewStreamableHTTP(
57-
httpServer.URL+"/mcp",
58-
mcpclienttransport.WithHTTPBasicClient(
59-
&http.Client{Transport: httpClientTransport},
60-
),
61-
mcpclienttransport.WithContinuousListening(),
62-
)
63-
require.NoError(t, err)
64-
client := mcpclient.NewClient(mcpClientTransport)
65-
require.NoError(t, client.Start(ctx))
66-
67-
// Initialize client and call a tool.
68-
result := mcptest.MustInitializeClient(t, client)
69-
require.Equal(t, "111.222.333", result.ServerInfo.Version)
70-
mcptest.MustCallServerTool(t, client)
71-
require.Equal(t, uint32(2), httpClientTransport.countMCPResponse.Load())
72-
73-
// Send notifications from server. Notifications will be sent through SSE.
74-
require.EventuallyWithT(t, func(collect *assert.CollectT) {
75-
assert.Greater(collect, httpClientTransport.getCountMethods()["GET"], 0)
76-
}, 2*time.Second, 100*time.Millisecond, "client SSE connected")
77-
mcpServer.SendNotificationToAllClients("notifications/test", nil)
78-
mcpServer.SendNotificationToAllClients("notifications/test", nil)
79-
require.EventuallyWithT(t, func(collect *assert.CollectT) {
80-
assert.Equal(collect, uint32(2), httpClientTransport.countMCPNotification.Load())
81-
}, 2*time.Second, 100*time.Millisecond, "expected to receive notification")
82-
83-
// Close client and count the requests.
84-
require.NoError(t, client.Close())
85-
require.Equal(t, map[string]int{
86-
"GET": 1, // For listening on SSE events.
87-
"POST": 3, // "initialize", "notifications/initialize", and "tools/call".
88-
"DELETE": 1, // Close session.
89-
}, httpClientTransport.getCountMethods())
56+
synctest.Run(func() {
57+
// TODO(greedy52) replace [synctest.Run + t.Run] with [synctest.Test]
58+
// when switching to go1.25. The extra layer t.Run is a temp workaround
59+
// to synctest deadlock with go 1.24 since t.Context() is used in many
60+
// places for background goroutines to exit.
61+
t.Run("synctest", func(t *testing.T) {
62+
ctx := t.Context()
63+
64+
// Set up a server.
65+
mcpServer := mcptest.NewServerWithVersion("11.22.33")
66+
listener := makeHTTPServerWithInMemoryListener(t, mcpServer)
67+
68+
// Set up a client with custom transport which calls "ReplaceHTTPResponse".
69+
httpClientTransport := newTestReplaceHTTPResponseTransport(listener)
70+
mcpClientTransport, err := mcpclienttransport.NewStreamableHTTP(
71+
"http://memory/mcp",
72+
mcpclienttransport.WithHTTPBasicClient(
73+
&http.Client{Transport: httpClientTransport},
74+
),
75+
mcpclienttransport.WithContinuousListening(),
76+
)
77+
require.NoError(t, err)
78+
var countMCPNotification atomic.Uint32
79+
client := mcpclient.NewClient(mcpClientTransport)
80+
client.OnNotification(func(notification mcp.JSONRPCNotification) {
81+
countMCPNotification.Add(1)
82+
})
83+
require.NoError(t, client.Start(ctx))
84+
85+
// Initialize client and call a tool.
86+
result := mcptest.MustInitializeClient(t, client)
87+
require.Equal(t, "111.222.333", result.ServerInfo.Version)
88+
mcptest.MustCallServerTool(t, client)
89+
require.Equal(t, uint32(2), httpClientTransport.countMCPResponse.Load())
90+
91+
// Send notifications from server. Notifications will be sent through SSE.
92+
synctest.Wait() // Wait for client to establish the GET connection.
93+
mcpServer.SendNotificationToAllClients("notifications/test", nil)
94+
mcpServer.SendNotificationToAllClients("notifications/test", nil)
95+
synctest.Wait() // Wait for client to receive notifications.
96+
require.Equal(t, uint32(2), countMCPNotification.Load())
97+
98+
// Close client and count the requests.
99+
require.NoError(t, client.Close())
100+
synctest.Wait()
101+
require.Equal(t, map[string]int{
102+
"GET": 1, // For listening on SSE events.
103+
"POST": 3, // "initialize", "notifications/initialize", and "tools/call".
104+
"DELETE": 1, // Close session.
105+
}, httpClientTransport.getCountMethods())
106+
})
107+
})
90108
}
91109

92110
type testReplaceHTTPResponseTransport struct {
93-
countMethods map[string]int
94-
countMethodsMu sync.Mutex
95-
countMCPResponse atomic.Uint32
96-
countMCPNotification atomic.Uint32
111+
countMethods map[string]int
112+
countMethodsMu sync.Mutex
113+
countMCPResponse atomic.Uint32
114+
client http.Client
97115
}
98116

99-
func newTestReplaceHTTPResponseTransport() *testReplaceHTTPResponseTransport {
117+
func newTestReplaceHTTPResponseTransport(inMemoryListener *listenerutils.InMemoryListener) *testReplaceHTTPResponseTransport {
100118
return &testReplaceHTTPResponseTransport{
101119
countMethods: make(map[string]int),
120+
client: http.Client{
121+
Transport: &http.Transport{
122+
DialContext: inMemoryListener.DialContext,
123+
},
124+
},
102125
}
103126
}
104127

@@ -113,7 +136,7 @@ func (t *testReplaceHTTPResponseTransport) RoundTrip(r *http.Request) (*http.Res
113136
t.countMethods[r.Method]++
114137
t.countMethodsMu.Unlock()
115138

116-
resp, err := http.DefaultClient.Do(r)
139+
resp, err := t.client.Do(r)
117140
if err != nil {
118141
return nil, trace.Wrap(err)
119142
}
@@ -147,55 +170,65 @@ func (t *testReplaceHTTPResponseTransport) ProcessResponse(_ context.Context, re
147170
}
148171

149172
func (t *testReplaceHTTPResponseTransport) ProcessNotification(_ context.Context, notification *JSONRPCNotification) mcp.JSONRPCMessage {
150-
t.countMCPNotification.Add(1)
151173
return notification
152174
}
153175

154176
func TestHTTPReaderWriter(t *testing.T) {
155-
t.Parallel()
156-
ctx := t.Context()
177+
synctest.Run(func() {
178+
// TODO(greedy52) replace [synctest.Run + t.Run] with [synctest.Test]
179+
// when switching to go1.25. The extra layer t.Run is a temp workaround
180+
// to synctest deadlock with go 1.24 since t.Context() is used in many
181+
// places for background goroutines to exit.
182+
t.Run("synctest", func(t *testing.T) {
183+
ctx := t.Context()
157184

158-
// Set up an MCP server.
159-
mcpServer := mcptest.NewServer()
160-
httpServer := mcpserver.NewTestStreamableHTTPServer(mcpServer)
161-
t.Cleanup(httpServer.Close)
185+
// Set up an MCP server.
186+
mcpServer := mcptest.NewServer()
187+
listener := makeHTTPServerWithInMemoryListener(t, mcpServer)
162188

163-
// Create a proxy that converts from stdio to HTTP.
164-
clientStdin, writeToClient := io.Pipe()
165-
readFromClient, clientStdout := io.Pipe()
166-
t.Cleanup(func() {
167-
assert.NoError(t, trace.NewAggregate(
168-
clientStdin.Close(), writeToClient.Close(),
169-
readFromClient.Close(), clientStdout.Close(),
170-
))
171-
})
189+
// Create a proxy that converts from stdio to HTTP.
190+
clientStdin, writeToClient := io.Pipe()
191+
readFromClient, clientStdout := io.Pipe()
192+
t.Cleanup(func() {
193+
assert.NoError(t, trace.NewAggregate(
194+
clientStdin.Close(), writeToClient.Close(),
195+
readFromClient.Close(), clientStdout.Close(),
196+
))
197+
})
172198

173-
serverReaderWriter, err := NewHTTPReaderWriter(ctx, httpServer.URL, mcpclienttransport.WithContinuousListening())
174-
require.NoError(t, err)
175-
defer serverReaderWriter.Close() // Send DELETE before server is shutdown
199+
serverReaderWriter, err := NewHTTPReaderWriter(
200+
ctx,
201+
"http://memory/mcp",
202+
mcpclienttransport.WithContinuousListening(),
203+
mcpclienttransport.WithHTTPBasicClient(listener.MakeHTTPClient()),
204+
)
205+
require.NoError(t, err)
206+
defer serverReaderWriter.Close() // Send DELETE before server is shutdown
207+
208+
clientTransportReader := NewStdioReader(readFromClient)
209+
clientWriter := NewStdioMessageWriter(writeToClient)
210+
proxyReaderWriter(t, clientTransportReader, clientWriter, serverReaderWriter, serverReaderWriter)
176211

177-
clientTransportReader := NewStdioReader(readFromClient)
178-
clientWriter := NewStdioMessageWriter(writeToClient)
179-
proxyReaderWriter(t, clientTransportReader, clientWriter, serverReaderWriter, serverReaderWriter)
212+
// Make a "high-level" stdio MCP client and test the proxy.
213+
var receivedNotifications []mcp.JSONRPCNotification
214+
stdioClient := mcptest.NewStdioClient(t, clientStdin, clientStdout)
215+
stdioClient.OnNotification(func(notification mcp.JSONRPCNotification) {
216+
receivedNotifications = append(receivedNotifications, notification)
217+
})
218+
mcptest.MustInitializeClient(t, stdioClient)
219+
mcptest.MustCallServerTool(t, stdioClient)
180220

181-
// Make a "high-level" stdio MCP client and test the proxy.
182-
notificationsChan := make(chan mcp.JSONRPCNotification, 1)
183-
stdioClient := mcptest.NewStdioClient(t, clientStdin, clientStdout)
184-
stdioClient.OnNotification(func(notification mcp.JSONRPCNotification) {
185-
notificationsChan <- notification
221+
// Test listening notifications from server.
222+
// First do a synctest.Wait until the client establish the listening
223+
// stream before sending the notification from the server. Then do
224+
// another synctest.Wait for the client to receive the notification.
225+
synctest.Wait()
226+
mcpServer.SendNotificationToAllClients("notifications/test", nil)
227+
synctest.Wait()
228+
require.Len(t, receivedNotifications, 1)
229+
require.Equal(t, "notifications/test", receivedNotifications[0].Notification.Method)
230+
})
186231
})
187-
mcptest.MustInitializeClient(t, stdioClient)
188-
mcptest.MustCallServerTool(t, stdioClient)
189-
190-
// Test listening notifications from server.
191-
mcpServer.SendNotificationToAllClients("notifications/test", nil)
192-
select {
193-
case notification := <-notificationsChan:
194-
require.NotNil(t, notification)
195-
require.Equal(t, "notifications/test", notification.Notification.Method)
196-
case <-time.After(time.Second):
197-
require.Fail(t, "timeout waiting for notification")
198-
}
199232
}
200233

201234
func proxyReaderWriter(
@@ -214,3 +247,18 @@ func proxyReaderWriter(
214247
go clientMessageReader.Run(t.Context())
215248
go serverMessageReader.Run(t.Context())
216249
}
250+
251+
// makeHTTPServerWithInMemoryListener starts a streamable-HTTP MCP server using
252+
// an InMemoryListener. InMemoryListener is good to use with synctest.
253+
func makeHTTPServerWithInMemoryListener(t *testing.T, mcpServer *mcpserver.MCPServer) *listenerutils.InMemoryListener {
254+
t.Helper()
255+
listener := listenerutils.NewInMemoryListener()
256+
httpServer := &http.Server{
257+
Handler: mcpserver.NewStreamableHTTPServer(mcpServer),
258+
}
259+
go httpServer.Serve(listener)
260+
t.Cleanup(func() {
261+
httpServer.Close()
262+
})
263+
return listener
264+
}

0 commit comments

Comments
 (0)