Skip to content

Commit 4b5e210

Browse files
committed
mcp: fix sseClientConn to use custom HTTP client
The sseClientConn was using http.DefaultClient for requests instead of the custom client provided via SSEClientTransportOptions. This change stores the HTTP client in the sseClientConn struct and uses it for all requests. Also adds test to verify the custom client is properly used.
1 parent aebd244 commit 4b5e210

File tree

2 files changed

+34
-5
lines changed

2 files changed

+34
-5
lines changed

mcp/sse.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,7 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) {
391391

392392
// From here on, the stream takes ownership of resp.Body.
393393
s := &sseClientConn{
394+
client: httpClient,
394395
sseEndpoint: c.sseEndpoint,
395396
msgEndpoint: msgEndpoint,
396397
incoming: make(chan []byte, 100),
@@ -511,9 +512,10 @@ func scanEvents(r io.Reader) iter.Seq2[event, error] {
511512
// - Reads are SSE 'message' events, and pushes them onto a buffered channel.
512513
// - Close terminates the GET request.
513514
type sseClientConn struct {
514-
sseEndpoint *url.URL // SSE endpoint for the GET
515-
msgEndpoint *url.URL // session endpoint for POSTs
516-
incoming chan []byte // queue of incoming messages
515+
client *http.Client // HTTP client to use for requests
516+
sseEndpoint *url.URL // SSE endpoint for the GET
517+
msgEndpoint *url.URL // session endpoint for POSTs
518+
incoming chan []byte // queue of incoming messages
517519

518520
mu sync.Mutex
519521
body io.ReadCloser // body of the hanging GET
@@ -564,7 +566,7 @@ func (c *sseClientConn) Write(ctx context.Context, msg JSONRPCMessage) error {
564566
return err
565567
}
566568
req.Header.Set("Content-Type", "application/json")
567-
resp, err := http.DefaultClient.Do(req)
569+
resp, err := c.client.Do(req)
568570
if err != nil {
569571
return err
570572
}

mcp/sse_test.go

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"net/http"
1111
"net/http/httptest"
1212
"strings"
13+
"sync/atomic"
1314
"testing"
1415

1516
"github.com/google/go-cmp/cmp"
@@ -34,7 +35,19 @@ func TestSSEServer(t *testing.T) {
3435
httpServer := httptest.NewServer(sseHandler)
3536
defer httpServer.Close()
3637

37-
clientTransport := NewSSEClientTransport(httpServer.URL, nil)
38+
var customClientUsed int64
39+
customClient := &http.Client{
40+
Transport: &roundTripperFunc{
41+
fn: func(req *http.Request) (*http.Response, error) {
42+
atomic.AddInt64(&customClientUsed, 1)
43+
return http.DefaultTransport.RoundTrip(req)
44+
},
45+
},
46+
}
47+
48+
clientTransport := NewSSEClientTransport(httpServer.URL, &SSEClientTransportOptions{
49+
HTTPClient: customClient,
50+
})
3851

3952
c := NewClient("testClient", "v1.0.0", nil)
4053
cs, err := c.Connect(ctx, clientTransport)
@@ -61,6 +74,11 @@ func TestSSEServer(t *testing.T) {
6174
t.Errorf("tools/call 'greet' mismatch (-want +got):\n%s", diff)
6275
}
6376

77+
// Verify that customClient was used
78+
if atomic.LoadInt64(&customClientUsed) == 0 {
79+
t.Error("Expected custom HTTP client to be used, but it wasn't")
80+
}
81+
6482
// Test that closing either end of the connection terminates the other
6583
// end.
6684
if closeServerFirst {
@@ -162,3 +180,12 @@ func TestScanEvents(t *testing.T) {
162180
})
163181
}
164182
}
183+
184+
// roundTripperFunc is a helper to create a custom RoundTripper
185+
type roundTripperFunc struct {
186+
fn func(*http.Request) (*http.Response, error)
187+
}
188+
189+
func (f *roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
190+
return f.fn(req)
191+
}

0 commit comments

Comments
 (0)