diff --git a/mcp/sse.go b/mcp/sse.go index 0a1f9b1b..d1b52599 100644 --- a/mcp/sse.go +++ b/mcp/sse.go @@ -391,6 +391,7 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) { // From here on, the stream takes ownership of resp.Body. s := &sseClientConn{ + client: httpClient, sseEndpoint: c.sseEndpoint, msgEndpoint: msgEndpoint, incoming: make(chan []byte, 100), @@ -511,9 +512,10 @@ func scanEvents(r io.Reader) iter.Seq2[event, error] { // - Reads are SSE 'message' events, and pushes them onto a buffered channel. // - Close terminates the GET request. type sseClientConn struct { - sseEndpoint *url.URL // SSE endpoint for the GET - msgEndpoint *url.URL // session endpoint for POSTs - incoming chan []byte // queue of incoming messages + client *http.Client // HTTP client to use for requests + sseEndpoint *url.URL // SSE endpoint for the GET + msgEndpoint *url.URL // session endpoint for POSTs + incoming chan []byte // queue of incoming messages mu sync.Mutex body io.ReadCloser // body of the hanging GET @@ -564,7 +566,7 @@ func (c *sseClientConn) Write(ctx context.Context, msg JSONRPCMessage) error { return err } req.Header.Set("Content-Type", "application/json") - resp, err := http.DefaultClient.Do(req) + resp, err := c.client.Do(req) if err != nil { return err } diff --git a/mcp/sse_test.go b/mcp/sse_test.go index e1df9536..153185d3 100644 --- a/mcp/sse_test.go +++ b/mcp/sse_test.go @@ -10,6 +10,7 @@ import ( "net/http" "net/http/httptest" "strings" + "sync/atomic" "testing" "github.com/google/go-cmp/cmp" @@ -34,7 +35,17 @@ func TestSSEServer(t *testing.T) { httpServer := httptest.NewServer(sseHandler) defer httpServer.Close() - clientTransport := NewSSEClientTransport(httpServer.URL, nil) + var customClientUsed int64 + customClient := &http.Client{ + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + atomic.AddInt64(&customClientUsed, 1) + return http.DefaultTransport.RoundTrip(req) + }), + } + + clientTransport := NewSSEClientTransport(httpServer.URL, &SSEClientTransportOptions{ + HTTPClient: customClient, + }) c := NewClient("testClient", "v1.0.0", nil) cs, err := c.Connect(ctx, clientTransport) @@ -61,6 +72,11 @@ func TestSSEServer(t *testing.T) { t.Errorf("tools/call 'greet' mismatch (-want +got):\n%s", diff) } + // Verify that customClient was used + if atomic.LoadInt64(&customClientUsed) == 0 { + t.Error("Expected custom HTTP client to be used, but it wasn't") + } + // Test that closing either end of the connection terminates the other // end. if closeServerFirst { @@ -162,3 +178,10 @@ func TestScanEvents(t *testing.T) { }) } } + +// roundTripperFunc is a helper to create a custom RoundTripper +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +}