Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions mcp/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) {

// From here on, the stream takes ownership of resp.Body.
s := &sseClientConn{
client: httpClient,
sseEndpoint: c.sseEndpoint,
msgEndpoint: msgEndpoint,
incoming: make(chan []byte, 100),
Expand Down Expand Up @@ -511,9 +512,10 @@ func scanEvents(r io.Reader) iter.Seq2[event, error] {
// - Reads are SSE 'message' events, and pushes them onto a buffered channel.
// - Close terminates the GET request.
type sseClientConn struct {
sseEndpoint *url.URL // SSE endpoint for the GET
msgEndpoint *url.URL // session endpoint for POSTs
incoming chan []byte // queue of incoming messages
client *http.Client // HTTP client to use for requests
sseEndpoint *url.URL // SSE endpoint for the GET
msgEndpoint *url.URL // session endpoint for POSTs
incoming chan []byte // queue of incoming messages

mu sync.Mutex
body io.ReadCloser // body of the hanging GET
Expand Down Expand Up @@ -564,7 +566,7 @@ func (c *sseClientConn) Write(ctx context.Context, msg JSONRPCMessage) error {
return err
}
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
resp, err := c.client.Do(req)
if err != nil {
return err
}
Expand Down
25 changes: 24 additions & 1 deletion mcp/sse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"

"github.com/google/go-cmp/cmp"
Expand All @@ -34,7 +35,17 @@ func TestSSEServer(t *testing.T) {
httpServer := httptest.NewServer(sseHandler)
defer httpServer.Close()

clientTransport := NewSSEClientTransport(httpServer.URL, nil)
var customClientUsed int64
customClient := &http.Client{
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
atomic.AddInt64(&customClientUsed, 1)
return http.DefaultTransport.RoundTrip(req)
}),
}

clientTransport := NewSSEClientTransport(httpServer.URL, &SSEClientTransportOptions{
HTTPClient: customClient,
})

c := NewClient("testClient", "v1.0.0", nil)
cs, err := c.Connect(ctx, clientTransport)
Expand All @@ -61,6 +72,11 @@ func TestSSEServer(t *testing.T) {
t.Errorf("tools/call 'greet' mismatch (-want +got):\n%s", diff)
}

// Verify that customClient was used
if atomic.LoadInt64(&customClientUsed) == 0 {
t.Error("Expected custom HTTP client to be used, but it wasn't")
}

// Test that closing either end of the connection terminates the other
// end.
if closeServerFirst {
Expand Down Expand Up @@ -162,3 +178,10 @@ func TestScanEvents(t *testing.T) {
})
}
}

// roundTripperFunc is a helper to create a custom RoundTripper
type roundTripperFunc func(*http.Request) (*http.Response, error)

func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
Loading