From 8ccf26c8c728d465cd0c6ed97e9c1390c10b0441 Mon Sep 17 00:00:00 2001 From: Manuel Ibar Date: Tue, 30 Sep 2025 16:31:53 -0300 Subject: [PATCH 1/2] mcp: add ModifyRequest to HTTP client transports Add ModifyRequest func(*http.Request) field to both SSEClientTransport and StreamableClientTransport. This callback is invoked before each outgoing HTTP request, allowing users to add headers, authentication, or other request modifications. This provides a simpler alternative to implementing custom RoundTrippers for common use cases like adding authorization headers or request IDs. Fixes #533 --- mcp/sse.go | 28 +++++-- mcp/sse_test.go | 125 ++++++++++++++++++++++++++++++ mcp/streamable.go | 51 ++++++++----- mcp/streamable_test.go | 138 ++++++++++++++++++++++++++++++++++ mcp/transport_example_test.go | 45 +++++++++++ 5 files changed, 361 insertions(+), 26 deletions(-) diff --git a/mcp/sse.go b/mcp/sse.go index 7f644918..4302230c 100644 --- a/mcp/sse.go +++ b/mcp/sse.go @@ -329,6 +329,10 @@ type SSEClientTransport struct { // HTTPClient is the client to use for making HTTP requests. If nil, // http.DefaultClient is used. HTTPClient *http.Client + + // If set, ModifyRequest is called before each outgoing HTTP request made by the client + // connection. It can be used to, for example, add headers to outgoing requests. + ModifyRequest func(*http.Request) } // Connect connects through the client endpoint. @@ -346,6 +350,9 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) { httpClient = http.DefaultClient } req.Header.Set("Accept", "text/event-stream") + if c.ModifyRequest != nil { + c.ModifyRequest(req) + } resp, err := httpClient.Do(req) if err != nil { return nil, err @@ -372,11 +379,12 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) { // From here on, the stream takes ownership of resp.Body. s := &sseClientConn{ - client: httpClient, - msgEndpoint: msgEndpoint, - incoming: make(chan []byte, 100), - body: resp.Body, - done: make(chan struct{}), + client: httpClient, + msgEndpoint: msgEndpoint, + modifyRequest: c.ModifyRequest, + incoming: make(chan []byte, 100), + body: resp.Body, + done: make(chan struct{}), } go func() { @@ -403,9 +411,10 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) { // - Reads are SSE 'message' events, and pushes them onto a buffered channel. // - Close terminates the GET request. type sseClientConn struct { - client *http.Client // HTTP client to use for requests - msgEndpoint *url.URL // session endpoint for POSTs - incoming chan []byte // queue of incoming messages + client *http.Client // HTTP client to use for requests + msgEndpoint *url.URL // session endpoint for POSTs + modifyRequest func(*http.Request) // optional callback to modify outgoing requests + incoming chan []byte // queue of incoming messages mu sync.Mutex body io.ReadCloser // body of the hanging GET @@ -456,6 +465,9 @@ func (c *sseClientConn) Write(ctx context.Context, msg jsonrpc.Message) error { return err } req.Header.Set("Content-Type", "application/json") + if c.modifyRequest != nil { + c.modifyRequest(req) + } resp, err := c.client.Do(req) if err != nil { return err diff --git a/mcp/sse_test.go b/mcp/sse_test.go index 25435ff3..281ca68d 100644 --- a/mcp/sse_test.go +++ b/mcp/sse_test.go @@ -11,6 +11,7 @@ import ( "io" "net/http" "net/http/httptest" + "sync" "sync/atomic" "testing" @@ -131,3 +132,127 @@ type roundTripperFunc func(*http.Request) (*http.Response, error) func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { return f(req) } + +func TestSSEClientModifyRequest(t *testing.T) { + ctx := context.Background() + + // Track all HTTP requests + var mu sync.Mutex + var requestMethods []string + var requestHeaders []http.Header + + // Create a server + server := NewServer(testImpl, nil) + AddTool(server, &Tool{Name: "greet"}, sayHi) + + sseHandler := NewSSEHandler(func(*http.Request) *Server { return server }, nil) + httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + mu.Lock() + requestMethods = append(requestMethods, req.Method) + requestHeaders = append(requestHeaders, req.Header.Clone()) + mu.Unlock() + sseHandler.ServeHTTP(w, req) + })) + defer httpServer.Close() + + // Create transport with ModifyRequest + clientTransport := &SSEClientTransport{ + Endpoint: httpServer.URL, + ModifyRequest: func(req *http.Request) { + req.Header.Set("X-Custom-Header", "test-value") + req.Header.Set("Authorization", "Bearer test-token") + }, + } + + c := NewClient(testImpl, nil) + cs, err := c.Connect(ctx, clientTransport, nil) + if err != nil { + t.Fatalf("Connect failed: %v", err) + } + defer cs.Close() + + // Call a tool (which will make a POST request) + _, err = cs.CallTool(ctx, &CallToolParams{ + Name: "greet", + Arguments: map[string]any{"Name": "user"}, + }) + if err != nil { + t.Fatalf("CallTool failed: %v", err) + } + + // Verify that we have both GET and POST requests + mu.Lock() + defer mu.Unlock() + + if len(requestMethods) < 2 { + t.Fatalf("Expected at least 2 requests (GET and POST), got %d", len(requestMethods)) + } + + // Verify GET request has custom headers + foundGET := false + for i, method := range requestMethods { + if method == "GET" { + foundGET = true + if got := requestHeaders[i].Get("X-Custom-Header"); got != "test-value" { + t.Errorf("GET request: X-Custom-Header = %q, want %q", got, "test-value") + } + if got := requestHeaders[i].Get("Authorization"); got != "Bearer test-token" { + t.Errorf("GET request: Authorization = %q, want %q", got, "Bearer test-token") + } + } + } + if !foundGET { + t.Error("No GET request found") + } + + // Verify POST request has custom headers + foundPOST := false + for i, method := range requestMethods { + if method == "POST" { + foundPOST = true + if got := requestHeaders[i].Get("X-Custom-Header"); got != "test-value" { + t.Errorf("POST request: X-Custom-Header = %q, want %q", got, "test-value") + } + if got := requestHeaders[i].Get("Authorization"); got != "Bearer test-token" { + t.Errorf("POST request: Authorization = %q, want %q", got, "Bearer test-token") + } + } + } + if !foundPOST { + t.Error("No POST request found") + } +} + +func TestSSEClientModifyRequestNil(t *testing.T) { + ctx := context.Background() + + // Create a server + server := NewServer(testImpl, nil) + AddTool(server, &Tool{Name: "greet"}, sayHi) + + sseHandler := NewSSEHandler(func(*http.Request) *Server { return server }, nil) + httpServer := httptest.NewServer(sseHandler) + defer httpServer.Close() + + // Create transport with nil ModifyRequest (should not panic) + clientTransport := &SSEClientTransport{ + Endpoint: httpServer.URL, + ModifyRequest: nil, // explicitly nil + } + + c := NewClient(testImpl, nil) + cs, err := c.Connect(ctx, clientTransport, nil) + if err != nil { + t.Fatalf("Connect failed: %v", err) + } + defer cs.Close() + + // Call a tool - should work normally + _, err = cs.CallTool(ctx, &CallToolParams{ + Name: "greet", + Arguments: map[string]any{"Name": "user"}, + }) + if err != nil { + t.Fatalf("CallTool failed: %v", err) + } +} diff --git a/mcp/streamable.go b/mcp/streamable.go index a96386d9..10f96442 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -983,6 +983,10 @@ type StreamableClientTransport struct { // It defaults to 5. To disable retries, use a negative number. MaxRetries int + // If set, ModifyRequest is called before each outgoing HTTP request made by the client + // connection. It can be used to, for example, add headers to outgoing requests. + ModifyRequest func(*http.Request) + // TODO(rfindley): propose exporting these. // If strict is set, the transport is in 'strict mode', where any violation // of the MCP spec causes a failure. @@ -1029,29 +1033,31 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er // cancelling its blocking network operations, which prevents hangs on exit. connCtx, cancel := context.WithCancel(ctx) conn := &streamableClientConn{ - url: t.Endpoint, - client: client, - incoming: make(chan jsonrpc.Message, 10), - done: make(chan struct{}), - maxRetries: maxRetries, - strict: t.strict, - logger: t.logger, - ctx: connCtx, - cancel: cancel, - failed: make(chan struct{}), + url: t.Endpoint, + client: client, + modifyRequest: t.ModifyRequest, + incoming: make(chan jsonrpc.Message, 10), + done: make(chan struct{}), + maxRetries: maxRetries, + strict: t.strict, + logger: t.logger, + ctx: connCtx, + cancel: cancel, + failed: make(chan struct{}), } return conn, nil } type streamableClientConn struct { - url string - client *http.Client - ctx context.Context - cancel context.CancelFunc - incoming chan jsonrpc.Message - maxRetries int - strict bool // from [StreamableClientTransport.strict] - logger *slog.Logger // from [StreamableClientTransport.logger] + url string + client *http.Client + modifyRequest func(*http.Request) // from [StreamableClientTransport.ModifyRequest] + ctx context.Context + cancel context.CancelFunc + incoming chan jsonrpc.Message + maxRetries int + strict bool // from [StreamableClientTransport.strict] + logger *slog.Logger // from [StreamableClientTransport.logger] // Guard calls to Close, as it may be called multiple times. closeOnce sync.Once @@ -1188,6 +1194,9 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json, text/event-stream") c.setMCPHeaders(req) + if c.modifyRequest != nil { + c.modifyRequest(req) + } resp, err := c.client.Do(req) if err != nil { @@ -1448,6 +1457,9 @@ func (c *streamableClientConn) Close() error { c.closeErr = err } else { c.setMCPHeaders(req) + if c.modifyRequest != nil { + c.modifyRequest(req) + } if _, err := c.client.Do(req); err != nil { c.closeErr = err } @@ -1474,6 +1486,9 @@ func (c *streamableClientConn) establishSSE(lastEventID string) (*http.Response, req.Header.Set("Last-Event-ID", lastEventID) } req.Header.Set("Accept", "text/event-stream") + if c.modifyRequest != nil { + c.modifyRequest(req) + } return c.client.Do(req) } diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 3576d2b5..13954c88 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -1485,3 +1485,141 @@ func TestStreamableClientContextPropagation(t *testing.T) { } } + +func TestStreamableClientModifyRequest(t *testing.T) { + ctx := context.Background() + + // Track all HTTP requests + var mu sync.Mutex + var requestMethods []string + var requestHeaders []http.Header + + // Create a server + server := NewServer(testImpl, nil) + AddTool(server, &Tool{Name: "greet"}, sayHi) + + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + mu.Lock() + requestMethods = append(requestMethods, req.Method) + requestHeaders = append(requestHeaders, req.Header.Clone()) + mu.Unlock() + handler.ServeHTTP(w, req) + })) + defer httpServer.Close() + + // Create transport with ModifyRequest + transport := &StreamableClientTransport{ + Endpoint: httpServer.URL, + ModifyRequest: func(req *http.Request) { + req.Header.Set("X-Custom-Header", "test-value") + req.Header.Set("Authorization", "Bearer test-token") + }, + } + + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport, nil) + if err != nil { + t.Fatalf("Connect failed: %v", err) + } + + // Call a tool (which will make POST and potentially GET requests) + _, err = session.CallTool(ctx, &CallToolParams{ + Name: "greet", + Arguments: map[string]any{"Name": "user"}, + }) + if err != nil { + t.Fatalf("CallTool failed: %v", err) + } + + // Close the session (which will make a DELETE request) + session.Close() + + // Verify that we have POST and DELETE requests + mu.Lock() + defer mu.Unlock() + + if len(requestMethods) < 2 { + t.Fatalf("Expected at least 2 requests (POST and DELETE), got %d", len(requestMethods)) + } + + // Verify POST request has custom headers + foundPOST := false + for i, method := range requestMethods { + if method == "POST" { + foundPOST = true + if got := requestHeaders[i].Get("X-Custom-Header"); got != "test-value" { + t.Errorf("POST request: X-Custom-Header = %q, want %q", got, "test-value") + } + if got := requestHeaders[i].Get("Authorization"); got != "Bearer test-token" { + t.Errorf("POST request: Authorization = %q, want %q", got, "Bearer test-token") + } + } + } + if !foundPOST { + t.Error("No POST request found") + } + + // Verify DELETE request has custom headers + foundDELETE := false + for i, method := range requestMethods { + if method == "DELETE" { + foundDELETE = true + if got := requestHeaders[i].Get("X-Custom-Header"); got != "test-value" { + t.Errorf("DELETE request: X-Custom-Header = %q, want %q", got, "test-value") + } + if got := requestHeaders[i].Get("Authorization"); got != "Bearer test-token" { + t.Errorf("DELETE request: Authorization = %q, want %q", got, "Bearer test-token") + } + } + } + if !foundDELETE { + t.Error("No DELETE request found") + } + + // Verify GET request has custom headers (if any) + for i, method := range requestMethods { + if method == "GET" { + if got := requestHeaders[i].Get("X-Custom-Header"); got != "test-value" { + t.Errorf("GET request: X-Custom-Header = %q, want %q", got, "test-value") + } + if got := requestHeaders[i].Get("Authorization"); got != "Bearer test-token" { + t.Errorf("GET request: Authorization = %q, want %q", got, "Bearer test-token") + } + } + } +} + +func TestStreamableClientModifyRequestNil(t *testing.T) { + ctx := context.Background() + + // Create a server + server := NewServer(testImpl, nil) + AddTool(server, &Tool{Name: "greet"}, sayHi) + + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + // Create transport with nil ModifyRequest (should not panic) + transport := &StreamableClientTransport{ + Endpoint: httpServer.URL, + ModifyRequest: nil, // explicitly nil + } + + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport, nil) + if err != nil { + t.Fatalf("Connect failed: %v", err) + } + defer session.Close() + + // Call a tool - should work normally + _, err = session.CallTool(ctx, &CallToolParams{ + Name: "greet", + Arguments: map[string]any{"Name": "user"}, + }) + if err != nil { + t.Fatalf("CallTool failed: %v", err) + } +} diff --git a/mcp/transport_example_test.go b/mcp/transport_example_test.go index ab54a422..9128fc10 100644 --- a/mcp/transport_example_test.go +++ b/mcp/transport_example_test.go @@ -12,6 +12,7 @@ import ( "context" "fmt" "log" + "net/http" "slices" "strings" @@ -51,3 +52,47 @@ func ExampleLoggingTransport() { } // !-loggingtransport + +func ExampleStreamableClientTransport_ModifyRequest() { + // Create a transport with ModifyRequest to add authentication headers + transport := &mcp.StreamableClientTransport{ + Endpoint: "https://example.com/mcp", + ModifyRequest: func(req *http.Request) { + req.Header.Set("Authorization", "Bearer my-secret-token") + req.Header.Set("X-Request-ID", "req-12345") + }, + } + + ctx := context.Background() + client := mcp.NewClient(&mcp.Implementation{Name: "my-client", Version: "v1.0.0"}, nil) + session, err := client.Connect(ctx, transport, nil) + if err != nil { + log.Fatal(err) + } + defer session.Close() + + // All HTTP requests (POST, GET, DELETE) will have the custom headers + // added by ModifyRequest before being sent to the server. +} + +// This example demonstrates how to use ModifyRequest with SSEClientTransport. +func ExampleSSEClientTransport_ModifyRequest() { + // Create a transport with ModifyRequest + transport := &mcp.SSEClientTransport{ + Endpoint: "https://example.com/sse", + ModifyRequest: func(req *http.Request) { + req.Header.Set("Authorization", "Bearer my-token") + req.Header.Set("X-Custom-Header", "custom-value") + }, + } + + ctx := context.Background() + client := mcp.NewClient(&mcp.Implementation{Name: "my-client", Version: "v1.0.0"}, nil) + session, err := client.Connect(ctx, transport, nil) + if err != nil { + log.Fatal(err) + } + defer session.Close() + + // All HTTP requests will have the custom headers +} From e16b01f06556e6c86087466ae90666141e64b637 Mon Sep 17 00:00:00 2001 From: Manuel Ibar Date: Tue, 30 Sep 2025 16:36:41 -0300 Subject: [PATCH 2/2] mcp: fix formatting in sse.go --- mcp/sse.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mcp/sse.go b/mcp/sse.go index 4302230c..03790cd9 100644 --- a/mcp/sse.go +++ b/mcp/sse.go @@ -411,10 +411,10 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) { // - Reads are SSE 'message' events, and pushes them onto a buffered channel. // - Close terminates the GET request. type sseClientConn struct { - client *http.Client // HTTP client to use for requests - msgEndpoint *url.URL // session endpoint for POSTs - modifyRequest func(*http.Request) // optional callback to modify outgoing requests - incoming chan []byte // queue of incoming messages + client *http.Client // HTTP client to use for requests + msgEndpoint *url.URL // session endpoint for POSTs + modifyRequest func(*http.Request) // optional callback to modify outgoing requests + incoming chan []byte // queue of incoming messages mu sync.Mutex body io.ReadCloser // body of the hanging GET