Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions mcp/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,10 @@ type SSEClientTransport struct {
// HTTPClient is the client to use for making HTTP requests. If nil,
// http.DefaultClient is used.
HTTPClient *http.Client

// If set, ModifyRequest is called before each outgoing HTTP request made by the client
// connection. It can be used to, for example, add headers to outgoing requests.
ModifyRequest func(*http.Request)
}

// Connect connects through the client endpoint.
Expand All @@ -346,6 +350,9 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) {
httpClient = http.DefaultClient
}
req.Header.Set("Accept", "text/event-stream")
if c.ModifyRequest != nil {
c.ModifyRequest(req)
}
resp, err := httpClient.Do(req)
if err != nil {
return nil, err
Expand All @@ -372,11 +379,12 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) {

// From here on, the stream takes ownership of resp.Body.
s := &sseClientConn{
client: httpClient,
msgEndpoint: msgEndpoint,
incoming: make(chan []byte, 100),
body: resp.Body,
done: make(chan struct{}),
client: httpClient,
msgEndpoint: msgEndpoint,
modifyRequest: c.ModifyRequest,
incoming: make(chan []byte, 100),
body: resp.Body,
done: make(chan struct{}),
}

go func() {
Expand All @@ -403,9 +411,10 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) {
// - Reads are SSE 'message' events, and pushes them onto a buffered channel.
// - Close terminates the GET request.
type sseClientConn struct {
client *http.Client // HTTP client to use for requests
msgEndpoint *url.URL // session endpoint for POSTs
incoming chan []byte // queue of incoming messages
client *http.Client // HTTP client to use for requests
msgEndpoint *url.URL // session endpoint for POSTs
modifyRequest func(*http.Request) // optional callback to modify outgoing requests
incoming chan []byte // queue of incoming messages

mu sync.Mutex
body io.ReadCloser // body of the hanging GET
Expand Down Expand Up @@ -456,6 +465,9 @@ func (c *sseClientConn) Write(ctx context.Context, msg jsonrpc.Message) error {
return err
}
req.Header.Set("Content-Type", "application/json")
if c.modifyRequest != nil {
c.modifyRequest(req)
}
resp, err := c.client.Do(req)
if err != nil {
return err
Expand Down
125 changes: 125 additions & 0 deletions mcp/sse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"io"
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"

Expand Down Expand Up @@ -131,3 +132,127 @@ type roundTripperFunc func(*http.Request) (*http.Response, error)
func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}

func TestSSEClientModifyRequest(t *testing.T) {
ctx := context.Background()

// Track all HTTP requests
var mu sync.Mutex
var requestMethods []string
var requestHeaders []http.Header

// Create a server
server := NewServer(testImpl, nil)
AddTool(server, &Tool{Name: "greet"}, sayHi)

sseHandler := NewSSEHandler(func(*http.Request) *Server { return server }, nil)
httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
mu.Lock()
requestMethods = append(requestMethods, req.Method)
requestHeaders = append(requestHeaders, req.Header.Clone())
mu.Unlock()
sseHandler.ServeHTTP(w, req)
}))
defer httpServer.Close()

// Create transport with ModifyRequest
clientTransport := &SSEClientTransport{
Endpoint: httpServer.URL,
ModifyRequest: func(req *http.Request) {
req.Header.Set("X-Custom-Header", "test-value")
req.Header.Set("Authorization", "Bearer test-token")
},
}

c := NewClient(testImpl, nil)
cs, err := c.Connect(ctx, clientTransport, nil)
if err != nil {
t.Fatalf("Connect failed: %v", err)
}
defer cs.Close()

// Call a tool (which will make a POST request)
_, err = cs.CallTool(ctx, &CallToolParams{
Name: "greet",
Arguments: map[string]any{"Name": "user"},
})
if err != nil {
t.Fatalf("CallTool failed: %v", err)
}

// Verify that we have both GET and POST requests
mu.Lock()
defer mu.Unlock()

if len(requestMethods) < 2 {
t.Fatalf("Expected at least 2 requests (GET and POST), got %d", len(requestMethods))
}

// Verify GET request has custom headers
foundGET := false
for i, method := range requestMethods {
if method == "GET" {
foundGET = true
if got := requestHeaders[i].Get("X-Custom-Header"); got != "test-value" {
t.Errorf("GET request: X-Custom-Header = %q, want %q", got, "test-value")
}
if got := requestHeaders[i].Get("Authorization"); got != "Bearer test-token" {
t.Errorf("GET request: Authorization = %q, want %q", got, "Bearer test-token")
}
}
}
if !foundGET {
t.Error("No GET request found")
}

// Verify POST request has custom headers
foundPOST := false
for i, method := range requestMethods {
if method == "POST" {
foundPOST = true
if got := requestHeaders[i].Get("X-Custom-Header"); got != "test-value" {
t.Errorf("POST request: X-Custom-Header = %q, want %q", got, "test-value")
}
if got := requestHeaders[i].Get("Authorization"); got != "Bearer test-token" {
t.Errorf("POST request: Authorization = %q, want %q", got, "Bearer test-token")
}
}
}
if !foundPOST {
t.Error("No POST request found")
}
}

func TestSSEClientModifyRequestNil(t *testing.T) {
ctx := context.Background()

// Create a server
server := NewServer(testImpl, nil)
AddTool(server, &Tool{Name: "greet"}, sayHi)

sseHandler := NewSSEHandler(func(*http.Request) *Server { return server }, nil)
httpServer := httptest.NewServer(sseHandler)
defer httpServer.Close()

// Create transport with nil ModifyRequest (should not panic)
clientTransport := &SSEClientTransport{
Endpoint: httpServer.URL,
ModifyRequest: nil, // explicitly nil
}

c := NewClient(testImpl, nil)
cs, err := c.Connect(ctx, clientTransport, nil)
if err != nil {
t.Fatalf("Connect failed: %v", err)
}
defer cs.Close()

// Call a tool - should work normally
_, err = cs.CallTool(ctx, &CallToolParams{
Name: "greet",
Arguments: map[string]any{"Name": "user"},
})
if err != nil {
t.Fatalf("CallTool failed: %v", err)
}
}
51 changes: 33 additions & 18 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -983,6 +983,10 @@ type StreamableClientTransport struct {
// It defaults to 5. To disable retries, use a negative number.
MaxRetries int

// If set, ModifyRequest is called before each outgoing HTTP request made by the client
// connection. It can be used to, for example, add headers to outgoing requests.
ModifyRequest func(*http.Request)

// TODO(rfindley): propose exporting these.
// If strict is set, the transport is in 'strict mode', where any violation
// of the MCP spec causes a failure.
Expand Down Expand Up @@ -1029,29 +1033,31 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
// cancelling its blocking network operations, which prevents hangs on exit.
connCtx, cancel := context.WithCancel(ctx)
conn := &streamableClientConn{
url: t.Endpoint,
client: client,
incoming: make(chan jsonrpc.Message, 10),
done: make(chan struct{}),
maxRetries: maxRetries,
strict: t.strict,
logger: t.logger,
ctx: connCtx,
cancel: cancel,
failed: make(chan struct{}),
url: t.Endpoint,
client: client,
modifyRequest: t.ModifyRequest,
incoming: make(chan jsonrpc.Message, 10),
done: make(chan struct{}),
maxRetries: maxRetries,
strict: t.strict,
logger: t.logger,
ctx: connCtx,
cancel: cancel,
failed: make(chan struct{}),
}
return conn, nil
}

type streamableClientConn struct {
url string
client *http.Client
ctx context.Context
cancel context.CancelFunc
incoming chan jsonrpc.Message
maxRetries int
strict bool // from [StreamableClientTransport.strict]
logger *slog.Logger // from [StreamableClientTransport.logger]
url string
client *http.Client
modifyRequest func(*http.Request) // from [StreamableClientTransport.ModifyRequest]
ctx context.Context
cancel context.CancelFunc
incoming chan jsonrpc.Message
maxRetries int
strict bool // from [StreamableClientTransport.strict]
logger *slog.Logger // from [StreamableClientTransport.logger]

// Guard calls to Close, as it may be called multiple times.
closeOnce sync.Once
Expand Down Expand Up @@ -1188,6 +1194,9 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json, text/event-stream")
c.setMCPHeaders(req)
if c.modifyRequest != nil {
c.modifyRequest(req)
}

resp, err := c.client.Do(req)
if err != nil {
Expand Down Expand Up @@ -1448,6 +1457,9 @@ func (c *streamableClientConn) Close() error {
c.closeErr = err
} else {
c.setMCPHeaders(req)
if c.modifyRequest != nil {
c.modifyRequest(req)
}
if _, err := c.client.Do(req); err != nil {
c.closeErr = err
}
Expand All @@ -1474,6 +1486,9 @@ func (c *streamableClientConn) establishSSE(lastEventID string) (*http.Response,
req.Header.Set("Last-Event-ID", lastEventID)
}
req.Header.Set("Accept", "text/event-stream")
if c.modifyRequest != nil {
c.modifyRequest(req)
}

return c.client.Do(req)
}
Expand Down
Loading