Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions mcp/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,10 @@ type SSEClientTransport struct {
// HTTPClient is the client to use for making HTTP requests. If nil,
// http.DefaultClient is used.
HTTPClient *http.Client
// ModifyRequest, if non-nil, is called before each outbound HTTP request.
// It can be used to set headers (for example, auth headers) prior to sending
// the request.
ModifyRequest func(*http.Request)
}

// Connect connects through the client endpoint.
Expand All @@ -346,6 +350,9 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) {
httpClient = http.DefaultClient
}
req.Header.Set("Accept", "text/event-stream")
if c.ModifyRequest != nil {
c.ModifyRequest(req)
}
resp, err := httpClient.Do(req)
if err != nil {
return nil, err
Expand All @@ -372,11 +379,12 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) {

// From here on, the stream takes ownership of resp.Body.
s := &sseClientConn{
client: httpClient,
msgEndpoint: msgEndpoint,
incoming: make(chan []byte, 100),
body: resp.Body,
done: make(chan struct{}),
client: httpClient,
msgEndpoint: msgEndpoint,
incoming: make(chan []byte, 100),
body: resp.Body,
done: make(chan struct{}),
modifyRequest: c.ModifyRequest,
}

go func() {
Expand All @@ -403,9 +411,10 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) {
// - Reads are SSE 'message' events, and pushes them onto a buffered channel.
// - Close terminates the GET request.
type sseClientConn struct {
client *http.Client // HTTP client to use for requests
msgEndpoint *url.URL // session endpoint for POSTs
incoming chan []byte // queue of incoming messages
client *http.Client // HTTP client to use for requests
msgEndpoint *url.URL // session endpoint for POSTs
incoming chan []byte // queue of incoming messages
modifyRequest func(*http.Request)

mu sync.Mutex
body io.ReadCloser // body of the hanging GET
Expand Down Expand Up @@ -456,6 +465,9 @@ func (c *sseClientConn) Write(ctx context.Context, msg jsonrpc.Message) error {
return err
}
req.Header.Set("Content-Type", "application/json")
if c.modifyRequest != nil {
c.modifyRequest(req)
}
resp, err := c.client.Do(req)
if err != nil {
return err
Expand Down
53 changes: 53 additions & 0 deletions mcp/sse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,59 @@ func TestSSEServer(t *testing.T) {
}
}

func TestSSEClientTransportModifyRequest(t *testing.T) {
ctx := context.Background()
const headerName = "X-Test-Header"
const headerValue = "abc123"

server := NewServer(testImpl, nil)
sseHandler := NewSSEHandler(func(*http.Request) *Server { return server }, nil)

var getCount atomic.Int32
var postCount atomic.Int32

httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if got := req.Header.Get(headerName); got != headerValue {
t.Errorf("header mismatch for %s: got %q, want %q", req.Method, got, headerValue)
}
switch req.Method {
case http.MethodGet:
getCount.Add(1)
case http.MethodPost:
postCount.Add(1)
}
sseHandler.ServeHTTP(w, req)
}))
defer httpServer.Close()

clientTransport := &SSEClientTransport{
Endpoint: httpServer.URL,
ModifyRequest: func(req *http.Request) {
req.Header.Set(headerName, headerValue)
},
}

c := NewClient(testImpl, nil)
cs, err := c.Connect(ctx, clientTransport, nil)
if err != nil {
t.Fatal(err)
}
if err := cs.Ping(ctx, nil); err != nil {
t.Fatal(err)
}
if err := cs.Close(); err != nil {
t.Fatal(err)
}
cs.Wait()

if getCount.Load() == 0 {
t.Errorf("expected GET request to include %s header", headerName)
}
if postCount.Load() == 0 {
t.Errorf("expected POST request to include %s header", headerName)
}
}

// roundTripperFunc is a helper to create a custom RoundTripper
type roundTripperFunc func(*http.Request) (*http.Response, error)

Expand Down
51 changes: 33 additions & 18 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -1004,6 +1004,10 @@ func (c *streamableServerConn) Close() error {
type StreamableClientTransport struct {
Endpoint string
HTTPClient *http.Client
// ModifyRequest, if non-nil, is invoked before every outbound HTTP request.
// It can be used to set headers (for example, auth headers) or otherwise
// adjust the request before it is sent.
ModifyRequest func(*http.Request)
// MaxRetries is the maximum number of times to attempt a reconnect before giving up.
// It defaults to 5. To disable retries, use a negative number.
MaxRetries int
Expand Down Expand Up @@ -1054,29 +1058,31 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
// cancelling its blocking network operations, which prevents hangs on exit.
connCtx, cancel := context.WithCancel(ctx)
conn := &streamableClientConn{
url: t.Endpoint,
client: client,
incoming: make(chan jsonrpc.Message, 10),
done: make(chan struct{}),
maxRetries: maxRetries,
strict: t.strict,
logger: t.logger,
ctx: connCtx,
cancel: cancel,
failed: make(chan struct{}),
url: t.Endpoint,
client: client,
incoming: make(chan jsonrpc.Message, 10),
done: make(chan struct{}),
maxRetries: maxRetries,
strict: t.strict,
logger: t.logger,
ctx: connCtx,
cancel: cancel,
modifyRequest: t.ModifyRequest,
failed: make(chan struct{}),
}
return conn, nil
}

type streamableClientConn struct {
url string
client *http.Client
ctx context.Context
cancel context.CancelFunc
incoming chan jsonrpc.Message
maxRetries int
strict bool // from [StreamableClientTransport.strict]
logger *slog.Logger // from [StreamableClientTransport.logger]
url string
client *http.Client
ctx context.Context
cancel context.CancelFunc
incoming chan jsonrpc.Message
maxRetries int
strict bool // from [StreamableClientTransport.strict]
logger *slog.Logger // from [StreamableClientTransport.logger]
modifyRequest func(*http.Request) // from [StreamableClientTransport.ModifyRequest]

// Guard calls to Close, as it may be called multiple times.
closeOnce sync.Once
Expand Down Expand Up @@ -1213,6 +1219,9 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json, text/event-stream")
c.setMCPHeaders(req)
if c.modifyRequest != nil {
c.modifyRequest(req)
}

resp, err := c.client.Do(req)
if err != nil {
Expand Down Expand Up @@ -1473,6 +1482,9 @@ func (c *streamableClientConn) Close() error {
c.closeErr = err
} else {
c.setMCPHeaders(req)
if c.modifyRequest != nil {
c.modifyRequest(req)
}
if _, err := c.client.Do(req); err != nil {
c.closeErr = err
}
Expand All @@ -1499,6 +1511,9 @@ func (c *streamableClientConn) establishSSE(lastEventID string) (*http.Response,
req.Header.Set("Last-Event-ID", lastEventID)
}
req.Header.Set("Accept", "text/event-stream")
if c.modifyRequest != nil {
c.modifyRequest(req)
}

return c.client.Do(req)
}
Expand Down
69 changes: 69 additions & 0 deletions mcp/streamable_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ type streamableResponse struct {
status int // or http.StatusOK
body string // or ""
optional bool // if set, request need not be sent
wantHeader header // optional request headers to verify
wantProtocolVersion string // if "", unchecked
callback func() // if set, called after the request is handled
}
Expand Down Expand Up @@ -96,6 +97,11 @@ func (s *fakeStreamableServer) ServeHTTP(w http.ResponseWriter, req *http.Reques
http.Error(w, "no response", http.StatusInternalServerError)
return
}
for hdr, want := range resp.wantHeader {
if got := req.Header.Get(hdr); got != want {
s.t.Errorf("%v: header %q = %q, want %q", key, hdr, got, want)
}
}
if resp.callback != nil {
defer resp.callback()
}
Expand Down Expand Up @@ -187,6 +193,69 @@ func TestStreamableClientTransportLifecycle(t *testing.T) {
}
}

func TestStreamableClientTransportModifyRequest(t *testing.T) {
ctx := context.Background()
const headerName = "X-Test-Header"
const headerValue = "abc123"

fake := &fakeStreamableServer{
t: t,
responses: fakeResponses{
{"POST", "", methodInitialize}: {
header: header{
"Content-Type": "application/json",
sessionIDHeader: "123",
},
body: jsonBody(t, initResp),
wantHeader: header{
headerName: headerValue,
},
},
{"POST", "123", notificationInitialized}: {
status: http.StatusAccepted,
wantHeader: header{
headerName: headerValue,
},
},
{"GET", "123", ""}: {
header: header{
"Content-Type": "text/event-stream",
},
optional: true,
wantHeader: header{
headerName: headerValue,
},
},
{"DELETE", "123", ""}: {
wantHeader: header{
headerName: headerValue,
},
},
},
}

httpServer := httptest.NewServer(fake)
defer httpServer.Close()

transport := &StreamableClientTransport{
Endpoint: httpServer.URL,
ModifyRequest: func(req *http.Request) {
req.Header.Set(headerName, headerValue)
},
}
client := NewClient(testImpl, nil)
session, err := client.Connect(ctx, transport, nil)
if err != nil {
t.Fatalf("client.Connect() failed: %v", err)
}
if err := session.Close(); err != nil {
t.Errorf("closing session: %v", err)
}
if missing := fake.missingRequests(); len(missing) > 0 {
t.Errorf("did not receive expected requests: %v", missing)
}
}

func TestStreamableClientRedundantDelete(t *testing.T) {
ctx := context.Background()

Expand Down