Skip to content

Commit 473a047

Browse files
committed
test: add context propagation test for StreamableClientTransport
Adds TestStreamableClientContextPropagation to verify that context values are properly propagated to background HTTP operations (SSE GET and cleanup DELETE requests) in StreamableClientTransport. The test uses a custom HTTP handler to capture request contexts and verify that context values from the parent context are accessible in both the SSE connection establishment and session cleanup requests.
1 parent 7e13da4 commit 473a047

File tree

1 file changed

+108
-0
lines changed

1 file changed

+108
-0
lines changed

mcp/streamable_test.go

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1430,3 +1430,111 @@ func TestStreamableGET(t *testing.T) {
14301430
t.Errorf("GET with session ID: got status %d, want %d", got, want)
14311431
}
14321432
}
1433+
1434+
// contextCapturingHandler wraps fakeStreamableServer and captures request contexts
1435+
type contextCapturingHandler struct {
1436+
capturedGetContext *context.Context
1437+
capturedDeleteContext *context.Context
1438+
mu *sync.Mutex
1439+
server *fakeStreamableServer
1440+
}
1441+
1442+
func (h *contextCapturingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
1443+
h.mu.Lock()
1444+
switch req.Method {
1445+
case http.MethodGet:
1446+
*h.capturedGetContext = req.Context()
1447+
case http.MethodDelete:
1448+
*h.capturedDeleteContext = req.Context()
1449+
}
1450+
h.mu.Unlock()
1451+
1452+
// Delegate to the fake server
1453+
h.server.ServeHTTP(w, req)
1454+
}
1455+
1456+
func TestStreamableClientContextPropagation(t *testing.T) {
1457+
// Test that context values are propagated to background HTTP requests
1458+
// (SSE GET and cleanup DELETE requests) in StreamableClientTransport.
1459+
1460+
type contextKey string
1461+
const testKey contextKey = "test-key"
1462+
const testValue = "test-value"
1463+
1464+
// Create context with test value
1465+
ctx := context.WithValue(context.Background(), testKey, testValue)
1466+
1467+
var capturedGetContext, capturedDeleteContext context.Context
1468+
var mu sync.Mutex
1469+
1470+
// Enhanced fake server that captures contexts for specific requests
1471+
fake := &fakeStreamableServer{
1472+
t: t,
1473+
responses: fakeResponses{
1474+
{"POST", "", methodInitialize}: {
1475+
header: header{
1476+
"Content-Type": "application/json",
1477+
sessionIDHeader: "123",
1478+
},
1479+
body: jsonBody(t, initResp),
1480+
},
1481+
{"POST", "123", notificationInitialized}: {
1482+
status: http.StatusAccepted,
1483+
wantProtocolVersion: latestProtocolVersion,
1484+
},
1485+
{"GET", "123", ""}: {
1486+
header: header{
1487+
"Content-Type": "text/event-stream",
1488+
},
1489+
optional: true,
1490+
wantProtocolVersion: latestProtocolVersion,
1491+
callback: func() {
1492+
// This captures the context when GET request is made
1493+
// Note: We can't directly access req.Context() here, but
1494+
// the test verifies that the fix enables context propagation
1495+
},
1496+
},
1497+
{"DELETE", "123", ""}: {},
1498+
},
1499+
}
1500+
1501+
// Custom handler that wraps the fake server and captures contexts
1502+
handler := &contextCapturingHandler{
1503+
capturedGetContext: &capturedGetContext,
1504+
capturedDeleteContext: &capturedDeleteContext,
1505+
mu: &mu,
1506+
server: fake,
1507+
}
1508+
1509+
httpServer := httptest.NewServer(handler)
1510+
defer httpServer.Close()
1511+
1512+
streamTransport := &StreamableClientTransport{Endpoint: httpServer.URL}
1513+
mcpClient := NewClient(testImpl, nil)
1514+
session, err := mcpClient.Connect(ctx, streamTransport, nil)
1515+
if err != nil {
1516+
t.Fatalf("client.Connect() failed: %v", err)
1517+
}
1518+
1519+
// Close the session to trigger DELETE request
1520+
if err := session.Close(); err != nil {
1521+
t.Errorf("closing session: %v", err)
1522+
}
1523+
1524+
// Verify context propagation
1525+
mu.Lock()
1526+
defer mu.Unlock()
1527+
1528+
if capturedGetContext == nil {
1529+
t.Error("GET request context was not captured")
1530+
} else if got := capturedGetContext.Value(testKey); got != testValue {
1531+
t.Errorf("GET request context value: got %v, want %v", got, testValue)
1532+
}
1533+
1534+
if capturedDeleteContext == nil {
1535+
t.Error("DELETE request context was not captured")
1536+
} else if got := capturedDeleteContext.Value(testKey); got != testValue {
1537+
t.Errorf("DELETE request context value: got %v, want %v", got, testValue)
1538+
}
1539+
1540+
}

0 commit comments

Comments
 (0)