Skip to content

Commit 5d64d61

Browse files
authored
mcp: fix context propagation in StreamableClientTransport (#514)
## Problem Context propagation was broken in `StreamableClientTransport` because: 1. `Connect()` used `context.Background()` instead of the parent context 2. `Close()` created a race condition where DELETE requests were cancelled before completion ## Solution - Use parent context when creating the connection context in `Connect()` - Reorder `Close()` operations to perform cleanup DELETE before cancelling context ## Impact - Request-scoped values (auth tokens, trace IDs) now propagate correctly to background HTTP operations - Eliminates race condition in cleanup operations - Maintains proper Go context semantics Fixes #513
1 parent 33ff851 commit 5d64d61

File tree

2 files changed

+65
-10
lines changed

2 files changed

+65
-10
lines changed

mcp/streamable.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,7 +1018,7 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
10181018
// Create a new cancellable context that will manage the connection's lifecycle.
10191019
// This is crucial for cleanly shutting down the background SSE listener by
10201020
// cancelling its blocking network operations, which prevents hangs on exit.
1021-
connCtx, cancel := context.WithCancel(context.Background())
1021+
connCtx, cancel := context.WithCancel(ctx)
10221022
conn := &streamableClientConn{
10231023
url: t.Endpoint,
10241024
client: client,
@@ -1230,7 +1230,7 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
12301230

12311231
// testAuth controls whether a fake Authorization header is added to outgoing requests.
12321232
// TODO: replace with a better mechanism when client-side auth is in place.
1233-
var testAuth = false
1233+
var testAuth atomic.Bool
12341234

12351235
func (c *streamableClientConn) setMCPHeaders(req *http.Request) {
12361236
c.mu.Lock()
@@ -1242,7 +1242,7 @@ func (c *streamableClientConn) setMCPHeaders(req *http.Request) {
12421242
if c.sessionID != "" {
12431243
req.Header.Set(sessionIDHeader, c.sessionID)
12441244
}
1245-
if testAuth {
1245+
if testAuth.Load() {
12461246
req.Header.Set("Authorization", "Bearer foo")
12471247
}
12481248
}
@@ -1394,14 +1394,10 @@ func (c *streamableClientConn) reconnect(lastEventID string) (*http.Response, er
13941394
// Close implements the [Connection] interface.
13951395
func (c *streamableClientConn) Close() error {
13961396
c.closeOnce.Do(func() {
1397-
// Cancel any hanging network requests.
1398-
c.cancel()
1399-
close(c.done)
1400-
14011397
if errors.Is(c.failure(), errSessionMissing) {
14021398
// If the session is missing, no need to delete it.
14031399
} else {
1404-
req, err := http.NewRequest(http.MethodDelete, c.url, nil)
1400+
req, err := http.NewRequestWithContext(c.ctx, http.MethodDelete, c.url, nil)
14051401
if err != nil {
14061402
c.closeErr = err
14071403
} else {
@@ -1411,6 +1407,10 @@ func (c *streamableClientConn) Close() error {
14111407
}
14121408
}
14131409
}
1410+
1411+
// Cancel any hanging network requests after cleanup.
1412+
c.cancel()
1413+
close(c.done)
14141414
})
14151415
return c.closeErr
14161416
}

mcp/streamable_test.go

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1310,8 +1310,9 @@ func textContent(t *testing.T, res *CallToolResult) string {
13101310
}
13111311

13121312
func TestTokenInfo(t *testing.T) {
1313-
defer func(b bool) { testAuth = b }(testAuth)
1314-
testAuth = true
1313+
oldAuth := testAuth.Load()
1314+
defer testAuth.Store(oldAuth)
1315+
testAuth.Store(true)
13151316
ctx := context.Background()
13161317

13171318
// Create a server with a tool that returns TokenInfo.
@@ -1430,3 +1431,57 @@ func TestStreamableGET(t *testing.T) {
14301431
t.Errorf("GET with session ID: got status %d, want %d", got, want)
14311432
}
14321433
}
1434+
1435+
func TestStreamableClientContextPropagation(t *testing.T) {
1436+
type contextKey string
1437+
const testKey = contextKey("test-key")
1438+
const testValue = "test-value"
1439+
1440+
ctx, cancel := context.WithCancel(context.Background())
1441+
defer cancel()
1442+
ctx2 := context.WithValue(ctx, testKey, testValue)
1443+
1444+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
1445+
switch req.Method {
1446+
case "POST":
1447+
w.Header().Set("Content-Type", "application/json")
1448+
w.Header().Set("Mcp-Session-Id", "test-session")
1449+
w.WriteHeader(http.StatusOK)
1450+
w.Write([]byte(`{"jsonrpc":"2.0","id":1,"result":{"protocolVersion":"2025-03-26","capabilities":{},"serverInfo":{"name":"test","version":"1.0"}}}`))
1451+
case "GET":
1452+
w.Header().Set("Content-Type", "text/event-stream")
1453+
w.WriteHeader(http.StatusOK)
1454+
case "DELETE":
1455+
w.WriteHeader(http.StatusNoContent)
1456+
}
1457+
}))
1458+
defer server.Close()
1459+
1460+
transport := &StreamableClientTransport{Endpoint: server.URL}
1461+
conn, err := transport.Connect(ctx2)
1462+
if err != nil {
1463+
t.Fatalf("Connect failed: %v", err)
1464+
}
1465+
defer conn.Close()
1466+
1467+
streamableConn, ok := conn.(*streamableClientConn)
1468+
if !ok {
1469+
t.Fatalf("Expected *streamableClientConn, got %T", conn)
1470+
}
1471+
1472+
if got := streamableConn.ctx.Value(testKey); got != testValue {
1473+
t.Errorf("Context value not propagated: got %v, want %v", got, testValue)
1474+
}
1475+
1476+
if streamableConn.ctx.Done() == nil {
1477+
t.Error("Connection context is not cancellable")
1478+
}
1479+
1480+
cancel()
1481+
select {
1482+
case <-streamableConn.ctx.Done():
1483+
case <-time.After(100 * time.Millisecond):
1484+
t.Error("Connection context was not cancelled when parent was cancelled")
1485+
}
1486+
1487+
}

0 commit comments

Comments
 (0)