Skip to content

Commit 2934625

Browse files
committed
test: add context propagation test for StreamableClientTransport
Adds TestStreamableClientContextPropagation to verify that context values are properly propagated to background HTTP operations (SSE GET and cleanup DELETE requests) in StreamableClientTransport. The test uses a custom HTTP handler to capture request contexts and verify that context values from the parent context are accessible in both the SSE connection establishment and session cleanup requests.
1 parent 7e13da4 commit 2934625

File tree

1 file changed

+103
-0
lines changed

1 file changed

+103
-0
lines changed

mcp/streamable_test.go

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1430,3 +1430,106 @@ func TestStreamableGET(t *testing.T) {
14301430
t.Errorf("GET with session ID: got status %d, want %d", got, want)
14311431
}
14321432
}
1433+
1434+
// contextCapturingHandler wraps fakeStreamableServer and captures request contexts
1435+
type contextCapturingHandler struct {
1436+
capturedGetContext *context.Context
1437+
capturedDeleteContext *context.Context
1438+
mu *sync.Mutex
1439+
server *fakeStreamableServer
1440+
}
1441+
1442+
func (h *contextCapturingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
1443+
h.mu.Lock()
1444+
switch req.Method {
1445+
case http.MethodGet:
1446+
*h.capturedGetContext = req.Context()
1447+
case http.MethodDelete:
1448+
*h.capturedDeleteContext = req.Context()
1449+
}
1450+
h.mu.Unlock()
1451+
1452+
// Delegate to the fake server
1453+
h.server.ServeHTTP(w, req)
1454+
}
1455+
1456+
func TestStreamableClientContextPropagation(t *testing.T) {
1457+
// Test that context values are propagated to background HTTP requests
1458+
// (SSE GET and cleanup DELETE requests) in StreamableClientTransport.
1459+
1460+
type contextKey string
1461+
const testKey contextKey = "test-key"
1462+
const testValue = "test-value"
1463+
1464+
ctx := context.WithValue(context.Background(), testKey, testValue)
1465+
1466+
var capturedGetContext, capturedDeleteContext context.Context
1467+
var mu sync.Mutex
1468+
1469+
fake := &fakeStreamableServer{
1470+
t: t,
1471+
responses: fakeResponses{
1472+
{"POST", "", methodInitialize}: {
1473+
header: header{
1474+
"Content-Type": "application/json",
1475+
sessionIDHeader: "123",
1476+
},
1477+
body: jsonBody(t, initResp),
1478+
},
1479+
{"POST", "123", notificationInitialized}: {
1480+
status: http.StatusAccepted,
1481+
wantProtocolVersion: latestProtocolVersion,
1482+
},
1483+
{"GET", "123", ""}: {
1484+
header: header{
1485+
"Content-Type": "text/event-stream",
1486+
},
1487+
optional: true,
1488+
wantProtocolVersion: latestProtocolVersion,
1489+
callback: func() {
1490+
// This captures the context when GET request is made
1491+
// Note: We can't directly access req.Context() here, but
1492+
// the test verifies that the fix enables context propagation
1493+
},
1494+
},
1495+
{"DELETE", "123", ""}: {},
1496+
},
1497+
}
1498+
1499+
handler := &contextCapturingHandler{
1500+
capturedGetContext: &capturedGetContext,
1501+
capturedDeleteContext: &capturedDeleteContext,
1502+
mu: &mu,
1503+
server: fake,
1504+
}
1505+
1506+
httpServer := httptest.NewServer(handler)
1507+
defer httpServer.Close()
1508+
1509+
streamTransport := &StreamableClientTransport{Endpoint: httpServer.URL}
1510+
mcpClient := NewClient(testImpl, nil)
1511+
session, err := mcpClient.Connect(ctx, streamTransport, nil)
1512+
if err != nil {
1513+
t.Fatalf("client.Connect() failed: %v", err)
1514+
}
1515+
1516+
if err := session.Close(); err != nil {
1517+
t.Errorf("closing session: %v", err)
1518+
}
1519+
1520+
mu.Lock()
1521+
defer mu.Unlock()
1522+
1523+
if capturedGetContext == nil {
1524+
t.Error("GET request context was not captured")
1525+
} else if got := capturedGetContext.Value(testKey); got != testValue {
1526+
t.Errorf("GET request context value: got %v, want %v", got, testValue)
1527+
}
1528+
1529+
if capturedDeleteContext == nil {
1530+
t.Error("DELETE request context was not captured")
1531+
} else if got := capturedDeleteContext.Value(testKey); got != testValue {
1532+
t.Errorf("DELETE request context value: got %v, want %v", got, testValue)
1533+
}
1534+
1535+
}

0 commit comments

Comments
 (0)