Skip to content

Commit 84508c7

Browse files
findleyrgopherbot
authored andcommitted
internal/mcp: provide a customizable http.Client to client transports
Allow customizing the http.Client used for HTTP MCP client transports, by adding client options structs. Change-Id: I2297acb136f8d0f7fa70d58cd244a6a81cc89751 Reviewed-on: https://go-review.googlesource.com/c/tools/+/682756 Reviewed-by: Jonathan Amsterdam <[email protected]> Auto-Submit: Robert Findley <[email protected]> LUCI-TryBot-Result: Go LUCI <[email protected]>
1 parent e5d7ba0 commit 84508c7

File tree

8 files changed

+91
-16
lines changed

8 files changed

+91
-16
lines changed

gopls/internal/cmd/mcp_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ func MyFun() {}
132132
t.Logf("failed %d, trying again", i)
133133
time.Sleep(50 * time.Millisecond << i) // retry with exponential backoff
134134
}
135-
serverConn, err := client.Connect(ctx, mcp.NewSSEClientTransport("http://"+addr))
135+
serverConn, err := client.Connect(ctx, mcp.NewSSEClientTransport("http://"+addr, nil))
136136
if err != nil {
137137
// This shouldn't happen because we already waited for the http server to start listening.
138138
t.Fatalf("connecting to server: %v", err)

gopls/internal/test/marker/marker_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1014,7 +1014,7 @@ func newEnv(t *testing.T, cache *cache.Cache, files, proxyFiles map[string][]byt
10141014
var mcpSession *mcp.ClientSession
10151015
if enableMCP {
10161016
client := mcp.NewClient("test", "v1.0.0", nil)
1017-
mcpSession, err = client.Connect(ctx, mcp.NewSSEClientTransport(mcpServer.URL))
1017+
mcpSession, err = client.Connect(ctx, mcp.NewSSEClientTransport(mcpServer.URL, nil))
10181018
if err != nil {
10191019
t.Fatalf("fail to connect to mcp server: %v", err)
10201020
}

internal/mcp/design/design.md

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,9 +187,17 @@ The SSE client transport is simpler, and hopefully self-explanatory.
187187
```go
188188
type SSEClientTransport struct { /* ... */ }
189189

190+
// SSEClientTransportOptions provides options for the [NewSSEClientTransport]
191+
// constructor.
192+
type SSEClientTransportOptions struct {
193+
// HTTPClient is the client to use for making HTTP requests. If nil,
194+
// http.DefaultClient is used.
195+
HTTPClient *http.Client
196+
}
197+
190198
// NewSSEClientTransport returns a new client transport that connects to the
191199
// SSE server at the provided URL.
192-
func NewSSEClientTransport(url string) (*SSEClientTransport, error) {
200+
func NewSSEClientTransport(url string, opts *SSEClientTransportOptions) (*SSEClientTransport, error)
193201

194202
// Connect connects through the client endpoint.
195203
func (*SSEClientTransport) Connect(ctx context.Context) (Connection, error)
@@ -217,7 +225,16 @@ func (*StreamableServerTransport) Connect(context.Context) (Connection, error)
217225

218226
// The streamable client handles reconnection transparently to the user.
219227
type StreamableClientTransport struct { /* ... */ }
220-
func NewStreamableClientTransport(url string) *StreamableClientTransport {
228+
229+
// StreamableClientTransportOptions provides options for the
230+
// [NewStreamableClientTransport] constructor.
231+
type StreamableClientTransportOptions struct {
232+
// HTTPClient is the client to use for making HTTP requests. If nil,
233+
// http.DefaultClient is used.
234+
HTTPClient *http.Client
235+
}
236+
237+
func NewStreamableClientTransport(url string, opts *StreamableClientTransportOptions) *StreamableClientTransport
221238
func (*StreamableClientTransport) Connect(context.Context) (Connection, error)
222239
```
223240

internal/mcp/sse.go

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -322,20 +322,33 @@ func (s sseServerConn) Close() error {
322322
// https://modelcontextprotocol.io/specification/2024-11-05/basic/transports
323323
type SSEClientTransport struct {
324324
sseEndpoint *url.URL
325+
opts SSEClientTransportOptions
326+
}
327+
328+
// SSEClientTransportOptions provides options for the [NewSSEClientTransport]
329+
// constructor.
330+
type SSEClientTransportOptions struct {
331+
// HTTPClient is the client to use for making HTTP requests. If nil,
332+
// http.DefaultClient is used.
333+
HTTPClient *http.Client
325334
}
326335

327336
// NewSSEClientTransport returns a new client transport that connects to the
328337
// SSE server at the provided URL.
329338
//
330339
// NewSSEClientTransport panics if the given URL is invalid.
331-
func NewSSEClientTransport(baseURL string) *SSEClientTransport {
340+
func NewSSEClientTransport(baseURL string, opts *SSEClientTransportOptions) *SSEClientTransport {
332341
url, err := url.Parse(baseURL)
333342
if err != nil {
334343
panic(fmt.Sprintf("invalid base url: %v", err))
335344
}
336-
return &SSEClientTransport{
345+
t := &SSEClientTransport{
337346
sseEndpoint: url,
338347
}
348+
if opts != nil {
349+
t.opts = *opts
350+
}
351+
return t
339352
}
340353

341354
// Connect connects through the client endpoint.
@@ -344,8 +357,12 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) {
344357
if err != nil {
345358
return nil, err
346359
}
360+
httpClient := c.opts.HTTPClient
361+
if httpClient == nil {
362+
httpClient = http.DefaultClient
363+
}
347364
req.Header.Set("Accept", "text/event-stream")
348-
resp, err := http.DefaultClient.Do(req)
365+
resp, err := httpClient.Do(req)
349366
if err != nil {
350367
return nil, err
351368
}

internal/mcp/sse_example_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ func ExampleSSEHandler() {
3333
defer httpServer.Close()
3434

3535
ctx := context.Background()
36-
transport := mcp.NewSSEClientTransport(httpServer.URL)
36+
transport := mcp.NewSSEClientTransport(httpServer.URL, nil)
3737
client := mcp.NewClient("test", "v1.0.0", nil)
3838
cs, err := client.Connect(ctx, transport)
3939
if err != nil {

internal/mcp/sse_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ func TestSSEServer(t *testing.T) {
3434
httpServer := httptest.NewServer(sseHandler)
3535
defer httpServer.Close()
3636

37-
clientTransport := NewSSEClientTransport(httpServer.URL)
37+
clientTransport := NewSSEClientTransport(httpServer.URL, nil)
3838

3939
c := NewClient("testClient", "v1.0.0", nil)
4040
cs, err := c.Connect(ctx, clientTransport)

internal/mcp/streamable.go

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -571,13 +571,26 @@ func (t *StreamableServerTransport) Close() error {
571571
//
572572
// TODO(rfindley): support retries and resumption tokens.
573573
type StreamableClientTransport struct {
574-
url string
574+
url string
575+
opts StreamableClientTransportOptions
576+
}
577+
578+
// StreamableClientTransportOptions provides options for the
579+
// [NewStreamableClientTransport] constructor.
580+
type StreamableClientTransportOptions struct {
581+
// HTTPClient is the client to use for making HTTP requests. If nil,
582+
// http.DefaultClient is used.
583+
HTTPClient *http.Client
575584
}
576585

577586
// NewStreamableClientTransport returns a new client transport that connects to
578587
// the streamable HTTP server at the provided URL.
579-
func NewStreamableClientTransport(url string) *StreamableClientTransport {
580-
return &StreamableClientTransport{url: url}
588+
func NewStreamableClientTransport(url string, opts *StreamableClientTransportOptions) *StreamableClientTransport {
589+
t := &StreamableClientTransport{url: url}
590+
if opts != nil {
591+
t.opts = *opts
592+
}
593+
return t
581594
}
582595

583596
// Connect implements the [Transport] interface.
@@ -589,9 +602,13 @@ func NewStreamableClientTransport(url string) *StreamableClientTransport {
589602
// When closed, the connection issues a DELETE request to terminate the logical
590603
// session.
591604
func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, error) {
605+
client := t.opts.HTTPClient
606+
if client == nil {
607+
client = http.DefaultClient
608+
}
592609
return &streamableClientConn{
593610
url: t.url,
594-
client: http.DefaultClient,
611+
client: client,
595612
incoming: make(chan []byte, 100),
596613
done: make(chan struct{}),
597614
}, nil

internal/mcp/streamable_test.go

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ import (
1111
"fmt"
1212
"io"
1313
"net/http"
14+
"net/http/cookiejar"
1415
"net/http/httptest"
16+
"net/url"
1517
"strings"
1618
"sync"
1719
"sync/atomic"
@@ -32,13 +34,35 @@ func TestStreamableTransports(t *testing.T) {
3234
server := NewServer("testServer", "v1.0.0", nil)
3335
server.AddTools(NewServerTool("greet", "say hi", sayHi))
3436

35-
// 2. Start an httptest.Server with the StreamableHTTPHandler.
37+
// 2. Start an httptest.Server with the StreamableHTTPHandler, wrapped in a
38+
// cookie-checking middleware.
3639
handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil)
37-
httpServer := httptest.NewServer(handler)
40+
httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
41+
cookie, err := r.Cookie("test-cookie")
42+
if err != nil {
43+
t.Errorf("missing cookie: %v", err)
44+
} else if cookie.Value != "test-value" {
45+
t.Errorf("got cookie %q, want %q", cookie.Value, "test-value")
46+
}
47+
handler.ServeHTTP(w, r)
48+
}))
3849
defer httpServer.Close()
3950

4051
// 3. Create a client and connect it to the server using our StreamableClientTransport.
41-
transport := NewStreamableClientTransport(httpServer.URL)
52+
// Check that all requests honor a custom client.
53+
jar, err := cookiejar.New(nil)
54+
if err != nil {
55+
t.Fatal(err)
56+
}
57+
u, err := url.Parse(httpServer.URL)
58+
if err != nil {
59+
t.Fatal(err)
60+
}
61+
jar.SetCookies(u, []*http.Cookie{{Name: "test-cookie", Value: "test-value"}})
62+
httpClient := &http.Client{Jar: jar}
63+
transport := NewStreamableClientTransport(httpServer.URL, &StreamableClientTransportOptions{
64+
HTTPClient: httpClient,
65+
})
4266
client := NewClient("testClient", "v1.0.0", nil)
4367
session, err := client.Connect(ctx, transport)
4468
if err != nil {

0 commit comments

Comments
 (0)