1+ //go:build go1.24 && enablesynctest
2+
13/*
24 * Teleport
35 * Copyright (C) 2025 Gravitational, Inc.
@@ -26,10 +28,11 @@ import (
2628 "log/slog"
2729 "maps"
2830 "net/http"
31+ "os"
2932 "sync"
3033 "sync/atomic"
3134 "testing"
32- "time "
35+ "testing/synctest "
3336
3437 "github.com/gravitational/trace"
3538 mcpclient "github.com/mark3labs/mcp-go/client"
@@ -39,66 +42,86 @@ import (
3942 "github.com/stretchr/testify/assert"
4043 "github.com/stretchr/testify/require"
4144
45+ listenerutils "github.com/gravitational/teleport/lib/utils/listener"
46+ "github.com/gravitational/teleport/lib/utils/log/logtest"
4247 "github.com/gravitational/teleport/lib/utils/mcptest"
4348)
4449
50+ func TestMain (m * testing.M ) {
51+ logtest .InitLogger (testing .Verbose )
52+ os .Exit (m .Run ())
53+ }
54+
4555func TestReplaceHTTPResponse (t * testing.T ) {
46- t .Parallel ()
47- ctx := t .Context ()
48-
49- // Set up a server.
50- mcpServer := mcptest .NewServerWithVersion ("11.22.33" )
51- httpServer := mcpserver .NewTestStreamableHTTPServer (mcpServer )
52- t .Cleanup (httpServer .Close )
53-
54- // Set up a client with custom transport which calls "ReplaceHTTPResponse".
55- httpClientTransport := newTestReplaceHTTPResponseTransport ()
56- mcpClientTransport , err := mcpclienttransport .NewStreamableHTTP (
57- httpServer .URL + "/mcp" ,
58- mcpclienttransport .WithHTTPBasicClient (
59- & http.Client {Transport : httpClientTransport },
60- ),
61- mcpclienttransport .WithContinuousListening (),
62- )
63- require .NoError (t , err )
64- client := mcpclient .NewClient (mcpClientTransport )
65- require .NoError (t , client .Start (ctx ))
66-
67- // Initialize client and call a tool.
68- result := mcptest .MustInitializeClient (t , client )
69- require .Equal (t , "111.222.333" , result .ServerInfo .Version )
70- mcptest .MustCallServerTool (t , client )
71- require .Equal (t , uint32 (2 ), httpClientTransport .countMCPResponse .Load ())
72-
73- // Send notifications from server. Notifications will be sent through SSE.
74- require .EventuallyWithT (t , func (collect * assert.CollectT ) {
75- assert .Greater (collect , httpClientTransport .getCountMethods ()["GET" ], 0 )
76- }, 2 * time .Second , 100 * time .Millisecond , "client SSE connected" )
77- mcpServer .SendNotificationToAllClients ("notifications/test" , nil )
78- mcpServer .SendNotificationToAllClients ("notifications/test" , nil )
79- require .EventuallyWithT (t , func (collect * assert.CollectT ) {
80- assert .Equal (collect , uint32 (2 ), httpClientTransport .countMCPNotification .Load ())
81- }, 2 * time .Second , 100 * time .Millisecond , "expected to receive notification" )
82-
83- // Close client and count the requests.
84- require .NoError (t , client .Close ())
85- require .Equal (t , map [string ]int {
86- "GET" : 1 , // For listening on SSE events.
87- "POST" : 3 , // "initialize", "notifications/initialize", and "tools/call".
88- "DELETE" : 1 , // Close session.
89- }, httpClientTransport .getCountMethods ())
56+ synctest .Run (func () {
57+ // TODO(greedy52) replace [synctest.Run + t.Run] with [synctest.Test]
58+ // when switching to go1.25. The extra layer t.Run is a temp workaround
59+ // to synctest deadlock with go 1.24 since t.Context() is used in many
60+ // places for background goroutines to exit.
61+ t .Run ("synctest" , func (t * testing.T ) {
62+ ctx := t .Context ()
63+
64+ // Set up a server.
65+ mcpServer := mcptest .NewServerWithVersion ("11.22.33" )
66+ listener := makeHTTPServerWithInMemoryListener (t , mcpServer )
67+
68+ // Set up a client with custom transport which calls "ReplaceHTTPResponse".
69+ httpClientTransport := newTestReplaceHTTPResponseTransport (listener )
70+ mcpClientTransport , err := mcpclienttransport .NewStreamableHTTP (
71+ "http://memory/mcp" ,
72+ mcpclienttransport .WithHTTPBasicClient (
73+ & http.Client {Transport : httpClientTransport },
74+ ),
75+ mcpclienttransport .WithContinuousListening (),
76+ )
77+ require .NoError (t , err )
78+ var countMCPNotification atomic.Uint32
79+ client := mcpclient .NewClient (mcpClientTransport )
80+ client .OnNotification (func (notification mcp.JSONRPCNotification ) {
81+ countMCPNotification .Add (1 )
82+ })
83+ require .NoError (t , client .Start (ctx ))
84+
85+ // Initialize client and call a tool.
86+ result := mcptest .MustInitializeClient (t , client )
87+ require .Equal (t , "111.222.333" , result .ServerInfo .Version )
88+ mcptest .MustCallServerTool (t , client )
89+ require .Equal (t , uint32 (2 ), httpClientTransport .countMCPResponse .Load ())
90+
91+ // Send notifications from server. Notifications will be sent through SSE.
92+ synctest .Wait () // Wait for client to establish the GET connection.
93+ mcpServer .SendNotificationToAllClients ("notifications/test" , nil )
94+ mcpServer .SendNotificationToAllClients ("notifications/test" , nil )
95+ synctest .Wait () // Wait for client to receive notifications.
96+ require .Equal (t , uint32 (2 ), countMCPNotification .Load ())
97+
98+ // Close client and count the requests.
99+ require .NoError (t , client .Close ())
100+ synctest .Wait ()
101+ require .Equal (t , map [string ]int {
102+ "GET" : 1 , // For listening on SSE events.
103+ "POST" : 3 , // "initialize", "notifications/initialize", and "tools/call".
104+ "DELETE" : 1 , // Close session.
105+ }, httpClientTransport .getCountMethods ())
106+ })
107+ })
90108}
91109
92110type testReplaceHTTPResponseTransport struct {
93- countMethods map [string ]int
94- countMethodsMu sync.Mutex
95- countMCPResponse atomic.Uint32
96- countMCPNotification atomic. Uint32
111+ countMethods map [string ]int
112+ countMethodsMu sync.Mutex
113+ countMCPResponse atomic.Uint32
114+ client http. Client
97115}
98116
99- func newTestReplaceHTTPResponseTransport () * testReplaceHTTPResponseTransport {
117+ func newTestReplaceHTTPResponseTransport (inMemoryListener * listenerutils. InMemoryListener ) * testReplaceHTTPResponseTransport {
100118 return & testReplaceHTTPResponseTransport {
101119 countMethods : make (map [string ]int ),
120+ client : http.Client {
121+ Transport : & http.Transport {
122+ DialContext : inMemoryListener .DialContext ,
123+ },
124+ },
102125 }
103126}
104127
@@ -113,7 +136,7 @@ func (t *testReplaceHTTPResponseTransport) RoundTrip(r *http.Request) (*http.Res
113136 t .countMethods [r .Method ]++
114137 t .countMethodsMu .Unlock ()
115138
116- resp , err := http . DefaultClient .Do (r )
139+ resp , err := t . client .Do (r )
117140 if err != nil {
118141 return nil , trace .Wrap (err )
119142 }
@@ -147,55 +170,65 @@ func (t *testReplaceHTTPResponseTransport) ProcessResponse(_ context.Context, re
147170}
148171
149172func (t * testReplaceHTTPResponseTransport ) ProcessNotification (_ context.Context , notification * JSONRPCNotification ) mcp.JSONRPCMessage {
150- t .countMCPNotification .Add (1 )
151173 return notification
152174}
153175
154176func TestHTTPReaderWriter (t * testing.T ) {
155- t .Parallel ()
156- ctx := t .Context ()
177+ synctest .Run (func () {
178+ // TODO(greedy52) replace [synctest.Run + t.Run] with [synctest.Test]
179+ // when switching to go1.25. The extra layer t.Run is a temp workaround
180+ // to synctest deadlock with go 1.24 since t.Context() is used in many
181+ // places for background goroutines to exit.
182+ t .Run ("synctest" , func (t * testing.T ) {
183+ ctx := t .Context ()
157184
158- // Set up an MCP server.
159- mcpServer := mcptest .NewServer ()
160- httpServer := mcpserver .NewTestStreamableHTTPServer (mcpServer )
161- t .Cleanup (httpServer .Close )
185+ // Set up an MCP server.
186+ mcpServer := mcptest .NewServer ()
187+ listener := makeHTTPServerWithInMemoryListener (t , mcpServer )
162188
163- // Create a proxy that converts from stdio to HTTP.
164- clientStdin , writeToClient := io .Pipe ()
165- readFromClient , clientStdout := io .Pipe ()
166- t .Cleanup (func () {
167- assert .NoError (t , trace .NewAggregate (
168- clientStdin .Close (), writeToClient .Close (),
169- readFromClient .Close (), clientStdout .Close (),
170- ))
171- })
189+ // Create a proxy that converts from stdio to HTTP.
190+ clientStdin , writeToClient := io .Pipe ()
191+ readFromClient , clientStdout := io .Pipe ()
192+ t .Cleanup (func () {
193+ assert .NoError (t , trace .NewAggregate (
194+ clientStdin .Close (), writeToClient .Close (),
195+ readFromClient .Close (), clientStdout .Close (),
196+ ))
197+ })
172198
173- serverReaderWriter , err := NewHTTPReaderWriter (ctx , httpServer .URL , mcpclienttransport .WithContinuousListening ())
174- require .NoError (t , err )
175- defer serverReaderWriter .Close () // Send DELETE before server is shutdown
199+ serverReaderWriter , err := NewHTTPReaderWriter (
200+ ctx ,
201+ "http://memory/mcp" ,
202+ mcpclienttransport .WithContinuousListening (),
203+ mcpclienttransport .WithHTTPBasicClient (listener .MakeHTTPClient ()),
204+ )
205+ require .NoError (t , err )
206+ defer serverReaderWriter .Close () // Send DELETE before server is shutdown
207+
208+ clientTransportReader := NewStdioReader (readFromClient )
209+ clientWriter := NewStdioMessageWriter (writeToClient )
210+ proxyReaderWriter (t , clientTransportReader , clientWriter , serverReaderWriter , serverReaderWriter )
176211
177- clientTransportReader := NewStdioReader (readFromClient )
178- clientWriter := NewStdioMessageWriter (writeToClient )
179- proxyReaderWriter (t , clientTransportReader , clientWriter , serverReaderWriter , serverReaderWriter )
212+ // Make a "high-level" stdio MCP client and test the proxy.
213+ var receivedNotifications []mcp.JSONRPCNotification
214+ stdioClient := mcptest .NewStdioClient (t , clientStdin , clientStdout )
215+ stdioClient .OnNotification (func (notification mcp.JSONRPCNotification ) {
216+ receivedNotifications = append (receivedNotifications , notification )
217+ })
218+ mcptest .MustInitializeClient (t , stdioClient )
219+ mcptest .MustCallServerTool (t , stdioClient )
180220
181- // Make a "high-level" stdio MCP client and test the proxy.
182- notificationsChan := make (chan mcp.JSONRPCNotification , 1 )
183- stdioClient := mcptest .NewStdioClient (t , clientStdin , clientStdout )
184- stdioClient .OnNotification (func (notification mcp.JSONRPCNotification ) {
185- notificationsChan <- notification
221+ // Test listening notifications from server.
222+ // First do a synctest.Wait until the client establish the listening
223+ // stream before sending the notification from the server. Then do
224+ // another synctest.Wait for the client to receive the notification.
225+ synctest .Wait ()
226+ mcpServer .SendNotificationToAllClients ("notifications/test" , nil )
227+ synctest .Wait ()
228+ require .Len (t , receivedNotifications , 1 )
229+ require .Equal (t , "notifications/test" , receivedNotifications [0 ].Notification .Method )
230+ })
186231 })
187- mcptest .MustInitializeClient (t , stdioClient )
188- mcptest .MustCallServerTool (t , stdioClient )
189-
190- // Test listening notifications from server.
191- mcpServer .SendNotificationToAllClients ("notifications/test" , nil )
192- select {
193- case notification := <- notificationsChan :
194- require .NotNil (t , notification )
195- require .Equal (t , "notifications/test" , notification .Notification .Method )
196- case <- time .After (time .Second ):
197- require .Fail (t , "timeout waiting for notification" )
198- }
199232}
200233
201234func proxyReaderWriter (
@@ -214,3 +247,18 @@ func proxyReaderWriter(
214247 go clientMessageReader .Run (t .Context ())
215248 go serverMessageReader .Run (t .Context ())
216249}
250+
251+ // makeHTTPServerWithInMemoryListener starts a streamable-HTTP MCP server using
252+ // an InMemoryListener. InMemoryListener is good to use with synctest.
253+ func makeHTTPServerWithInMemoryListener (t * testing.T , mcpServer * mcpserver.MCPServer ) * listenerutils.InMemoryListener {
254+ t .Helper ()
255+ listener := listenerutils .NewInMemoryListener ()
256+ httpServer := & http.Server {
257+ Handler : mcpserver .NewStreamableHTTPServer (mcpServer ),
258+ }
259+ go httpServer .Serve (listener )
260+ t .Cleanup (func () {
261+ httpServer .Close ()
262+ })
263+ return listener
264+ }
0 commit comments