diff --git a/mcp/sse.go b/mcp/sse.go index 7f644918..e041db4e 100644 --- a/mcp/sse.go +++ b/mcp/sse.go @@ -329,6 +329,10 @@ type SSEClientTransport struct { // HTTPClient is the client to use for making HTTP requests. If nil, // http.DefaultClient is used. HTTPClient *http.Client + // ModifyRequest, if non-nil, is called before each outbound HTTP request. + // It can be used to set headers (for example, auth headers) prior to sending + // the request. + ModifyRequest func(*http.Request) } // Connect connects through the client endpoint. @@ -346,6 +350,9 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) { httpClient = http.DefaultClient } req.Header.Set("Accept", "text/event-stream") + if c.ModifyRequest != nil { + c.ModifyRequest(req) + } resp, err := httpClient.Do(req) if err != nil { return nil, err @@ -372,11 +379,12 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) { // From here on, the stream takes ownership of resp.Body. s := &sseClientConn{ - client: httpClient, - msgEndpoint: msgEndpoint, - incoming: make(chan []byte, 100), - body: resp.Body, - done: make(chan struct{}), + client: httpClient, + msgEndpoint: msgEndpoint, + incoming: make(chan []byte, 100), + body: resp.Body, + done: make(chan struct{}), + modifyRequest: c.ModifyRequest, } go func() { @@ -403,9 +411,10 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) { // - Reads are SSE 'message' events, and pushes them onto a buffered channel. // - Close terminates the GET request. type sseClientConn struct { - client *http.Client // HTTP client to use for requests - msgEndpoint *url.URL // session endpoint for POSTs - incoming chan []byte // queue of incoming messages + client *http.Client // HTTP client to use for requests + msgEndpoint *url.URL // session endpoint for POSTs + incoming chan []byte // queue of incoming messages + modifyRequest func(*http.Request) mu sync.Mutex body io.ReadCloser // body of the hanging GET @@ -456,6 +465,9 @@ func (c *sseClientConn) Write(ctx context.Context, msg jsonrpc.Message) error { return err } req.Header.Set("Content-Type", "application/json") + if c.modifyRequest != nil { + c.modifyRequest(req) + } resp, err := c.client.Do(req) if err != nil { return err diff --git a/mcp/sse_test.go b/mcp/sse_test.go index 25435ff3..c61d920d 100644 --- a/mcp/sse_test.go +++ b/mcp/sse_test.go @@ -125,6 +125,59 @@ func TestSSEServer(t *testing.T) { } } +func TestSSEClientTransportModifyRequest(t *testing.T) { + ctx := context.Background() + const headerName = "X-Test-Header" + const headerValue = "abc123" + + server := NewServer(testImpl, nil) + sseHandler := NewSSEHandler(func(*http.Request) *Server { return server }, nil) + + var getCount atomic.Int32 + var postCount atomic.Int32 + + httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if got := req.Header.Get(headerName); got != headerValue { + t.Errorf("header mismatch for %s: got %q, want %q", req.Method, got, headerValue) + } + switch req.Method { + case http.MethodGet: + getCount.Add(1) + case http.MethodPost: + postCount.Add(1) + } + sseHandler.ServeHTTP(w, req) + })) + defer httpServer.Close() + + clientTransport := &SSEClientTransport{ + Endpoint: httpServer.URL, + ModifyRequest: func(req *http.Request) { + req.Header.Set(headerName, headerValue) + }, + } + + c := NewClient(testImpl, nil) + cs, err := c.Connect(ctx, clientTransport, nil) + if err != nil { + t.Fatal(err) + } + if err := cs.Ping(ctx, nil); err != nil { + t.Fatal(err) + } + if err := cs.Close(); err != nil { + t.Fatal(err) + } + cs.Wait() + + if getCount.Load() == 0 { + t.Errorf("expected GET request to include %s header", headerName) + } + if postCount.Load() == 0 { + t.Errorf("expected POST request to include %s header", headerName) + } +} + // roundTripperFunc is a helper to create a custom RoundTripper type roundTripperFunc func(*http.Request) (*http.Response, error) diff --git a/mcp/streamable.go b/mcp/streamable.go index 20eb13d5..940cf0a2 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1004,6 +1004,10 @@ func (c *streamableServerConn) Close() error { type StreamableClientTransport struct { Endpoint string HTTPClient *http.Client + // ModifyRequest, if non-nil, is invoked before every outbound HTTP request. + // It can be used to set headers (for example, auth headers) or otherwise + // adjust the request before it is sent. + ModifyRequest func(*http.Request) // MaxRetries is the maximum number of times to attempt a reconnect before giving up. // It defaults to 5. To disable retries, use a negative number. MaxRetries int @@ -1054,29 +1058,31 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er // cancelling its blocking network operations, which prevents hangs on exit. connCtx, cancel := context.WithCancel(ctx) conn := &streamableClientConn{ - url: t.Endpoint, - client: client, - incoming: make(chan jsonrpc.Message, 10), - done: make(chan struct{}), - maxRetries: maxRetries, - strict: t.strict, - logger: t.logger, - ctx: connCtx, - cancel: cancel, - failed: make(chan struct{}), + url: t.Endpoint, + client: client, + incoming: make(chan jsonrpc.Message, 10), + done: make(chan struct{}), + maxRetries: maxRetries, + strict: t.strict, + logger: t.logger, + ctx: connCtx, + cancel: cancel, + modifyRequest: t.ModifyRequest, + failed: make(chan struct{}), } return conn, nil } type streamableClientConn struct { - url string - client *http.Client - ctx context.Context - cancel context.CancelFunc - incoming chan jsonrpc.Message - maxRetries int - strict bool // from [StreamableClientTransport.strict] - logger *slog.Logger // from [StreamableClientTransport.logger] + url string + client *http.Client + ctx context.Context + cancel context.CancelFunc + incoming chan jsonrpc.Message + maxRetries int + strict bool // from [StreamableClientTransport.strict] + logger *slog.Logger // from [StreamableClientTransport.logger] + modifyRequest func(*http.Request) // from [StreamableClientTransport.ModifyRequest] // Guard calls to Close, as it may be called multiple times. closeOnce sync.Once @@ -1213,6 +1219,9 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json, text/event-stream") c.setMCPHeaders(req) + if c.modifyRequest != nil { + c.modifyRequest(req) + } resp, err := c.client.Do(req) if err != nil { @@ -1473,6 +1482,9 @@ func (c *streamableClientConn) Close() error { c.closeErr = err } else { c.setMCPHeaders(req) + if c.modifyRequest != nil { + c.modifyRequest(req) + } if _, err := c.client.Do(req); err != nil { c.closeErr = err } @@ -1499,6 +1511,9 @@ func (c *streamableClientConn) establishSSE(lastEventID string) (*http.Response, req.Header.Set("Last-Event-ID", lastEventID) } req.Header.Set("Accept", "text/event-stream") + if c.modifyRequest != nil { + c.modifyRequest(req) + } return c.client.Do(req) } diff --git a/mcp/streamable_client_test.go b/mcp/streamable_client_test.go index 9116677b..296a9635 100644 --- a/mcp/streamable_client_test.go +++ b/mcp/streamable_client_test.go @@ -33,6 +33,7 @@ type streamableResponse struct { status int // or http.StatusOK body string // or "" optional bool // if set, request need not be sent + wantHeader header // optional request headers to verify wantProtocolVersion string // if "", unchecked callback func() // if set, called after the request is handled } @@ -96,6 +97,11 @@ func (s *fakeStreamableServer) ServeHTTP(w http.ResponseWriter, req *http.Reques http.Error(w, "no response", http.StatusInternalServerError) return } + for hdr, want := range resp.wantHeader { + if got := req.Header.Get(hdr); got != want { + s.t.Errorf("%v: header %q = %q, want %q", key, hdr, got, want) + } + } if resp.callback != nil { defer resp.callback() } @@ -187,6 +193,69 @@ func TestStreamableClientTransportLifecycle(t *testing.T) { } } +func TestStreamableClientTransportModifyRequest(t *testing.T) { + ctx := context.Background() + const headerName = "X-Test-Header" + const headerValue = "abc123" + + fake := &fakeStreamableServer{ + t: t, + responses: fakeResponses{ + {"POST", "", methodInitialize}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "123", + }, + body: jsonBody(t, initResp), + wantHeader: header{ + headerName: headerValue, + }, + }, + {"POST", "123", notificationInitialized}: { + status: http.StatusAccepted, + wantHeader: header{ + headerName: headerValue, + }, + }, + {"GET", "123", ""}: { + header: header{ + "Content-Type": "text/event-stream", + }, + optional: true, + wantHeader: header{ + headerName: headerValue, + }, + }, + {"DELETE", "123", ""}: { + wantHeader: header{ + headerName: headerValue, + }, + }, + }, + } + + httpServer := httptest.NewServer(fake) + defer httpServer.Close() + + transport := &StreamableClientTransport{ + Endpoint: httpServer.URL, + ModifyRequest: func(req *http.Request) { + req.Header.Set(headerName, headerValue) + }, + } + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport, nil) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + if err := session.Close(); err != nil { + t.Errorf("closing session: %v", err) + } + if missing := fake.missingRequests(); len(missing) > 0 { + t.Errorf("did not receive expected requests: %v", missing) + } +} + func TestStreamableClientRedundantDelete(t *testing.T) { ctx := context.Background()