Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -1018,7 +1018,7 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
// Create a new cancellable context that will manage the connection's lifecycle.
// This is crucial for cleanly shutting down the background SSE listener by
// cancelling its blocking network operations, which prevents hangs on exit.
connCtx, cancel := context.WithCancel(context.Background())
connCtx, cancel := context.WithCancel(ctx)
conn := &streamableClientConn{
url: t.Endpoint,
client: client,
Expand Down Expand Up @@ -1230,7 +1230,7 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e

// testAuth controls whether a fake Authorization header is added to outgoing requests.
// TODO: replace with a better mechanism when client-side auth is in place.
var testAuth = false
var testAuth atomic.Bool

func (c *streamableClientConn) setMCPHeaders(req *http.Request) {
c.mu.Lock()
Expand All @@ -1242,7 +1242,7 @@ func (c *streamableClientConn) setMCPHeaders(req *http.Request) {
if c.sessionID != "" {
req.Header.Set(sessionIDHeader, c.sessionID)
}
if testAuth {
if testAuth.Load() {
req.Header.Set("Authorization", "Bearer foo")
}
}
Expand Down Expand Up @@ -1394,14 +1394,10 @@ func (c *streamableClientConn) reconnect(lastEventID string) (*http.Response, er
// Close implements the [Connection] interface.
func (c *streamableClientConn) Close() error {
c.closeOnce.Do(func() {
// Cancel any hanging network requests.
c.cancel()
close(c.done)

if errors.Is(c.failure(), errSessionMissing) {
// If the session is missing, no need to delete it.
} else {
req, err := http.NewRequest(http.MethodDelete, c.url, nil)
req, err := http.NewRequestWithContext(c.ctx, http.MethodDelete, c.url, nil)
if err != nil {
c.closeErr = err
} else {
Expand All @@ -1411,6 +1407,10 @@ func (c *streamableClientConn) Close() error {
}
}
}

// Cancel any hanging network requests after cleanup.
c.cancel()
close(c.done)
})
return c.closeErr
}
Expand Down
78 changes: 76 additions & 2 deletions mcp/streamable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1310,8 +1310,9 @@ func textContent(t *testing.T, res *CallToolResult) string {
}

func TestTokenInfo(t *testing.T) {
defer func(b bool) { testAuth = b }(testAuth)
testAuth = true
oldAuth := testAuth.Load()
defer testAuth.Store(oldAuth)
testAuth.Store(true)
ctx := context.Background()

// Create a server with a tool that returns TokenInfo.
Expand Down Expand Up @@ -1430,3 +1431,76 @@ func TestStreamableGET(t *testing.T) {
t.Errorf("GET with session ID: got status %d, want %d", got, want)
}
}

func TestStreamableClientContextPropagation(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

var getCtx, deleteCtx context.Context
var mu sync.Mutex

handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
mu.Lock()
switch req.Method {
case http.MethodGet:
if getCtx == nil {
getCtx = req.Context()
}
case http.MethodDelete:
if deleteCtx == nil {
deleteCtx = req.Context()
}
}
mu.Unlock()

fake := &fakeStreamableServer{
t: t,
responses: fakeResponses{
{"POST", "", methodInitialize}: {
header: header{
"Content-Type": "application/json",
sessionIDHeader: "123",
},
body: jsonBody(t, initResp),
},
{"POST", "123", notificationInitialized}: {
status: http.StatusAccepted,
wantProtocolVersion: latestProtocolVersion,
},
{"GET", "123", ""}: {
header: header{
"Content-Type": "text/event-stream",
},
optional: true,
wantProtocolVersion: latestProtocolVersion,
},
{"DELETE", "123", ""}: {},
},
}
fake.ServeHTTP(w, req)
})

httpServer := httptest.NewServer(handler)
defer httpServer.Close()

transport := &StreamableClientTransport{Endpoint: httpServer.URL}
client := NewClient(testImpl, nil)
session, err := client.Connect(ctx, transport, nil)
if err != nil {
t.Fatalf("client.Connect() failed: %v", err)
}

if err := session.Close(); err != nil {
t.Errorf("session.Close() failed: %v", err)
}

mu.Lock()
defer mu.Unlock()

if getCtx != nil && getCtx.Done() == nil {
Copy link
Contributor

@samthanawalla samthanawalla Sep 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test only checks that the context is cancellable but doesn't check if the values are passed. Can we change this test to verify that the value of some testkey is preserved?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I addressed this comment on this commit

t.Error("GET request context is not cancellable")
}
if deleteCtx != nil && deleteCtx.Done() == nil {
t.Error("DELETE request context is not cancellable")
}
}