@@ -1430,3 +1430,106 @@ func TestStreamableGET(t *testing.T) {
14301430 t .Errorf ("GET with session ID: got status %d, want %d" , got , want )
14311431 }
14321432}
1433+
1434+ // contextCapturingHandler wraps fakeStreamableServer and captures request contexts
1435+ type contextCapturingHandler struct {
1436+ capturedGetContext * context.Context
1437+ capturedDeleteContext * context.Context
1438+ mu * sync.Mutex
1439+ server * fakeStreamableServer
1440+ }
1441+
1442+ func (h * contextCapturingHandler ) ServeHTTP (w http.ResponseWriter , req * http.Request ) {
1443+ h .mu .Lock ()
1444+ switch req .Method {
1445+ case http .MethodGet :
1446+ * h .capturedGetContext = req .Context ()
1447+ case http .MethodDelete :
1448+ * h .capturedDeleteContext = req .Context ()
1449+ }
1450+ h .mu .Unlock ()
1451+
1452+ // Delegate to the fake server
1453+ h .server .ServeHTTP (w , req )
1454+ }
1455+
1456+ func TestStreamableClientContextPropagation (t * testing.T ) {
1457+ // Test that context values are propagated to background HTTP requests
1458+ // (SSE GET and cleanup DELETE requests) in StreamableClientTransport.
1459+
1460+ type contextKey string
1461+ const testKey contextKey = "test-key"
1462+ const testValue = "test-value"
1463+
1464+ ctx := context .WithValue (context .Background (), testKey , testValue )
1465+
1466+ var capturedGetContext , capturedDeleteContext context.Context
1467+ var mu sync.Mutex
1468+
1469+ fake := & fakeStreamableServer {
1470+ t : t ,
1471+ responses : fakeResponses {
1472+ {"POST" , "" , methodInitialize }: {
1473+ header : header {
1474+ "Content-Type" : "application/json" ,
1475+ sessionIDHeader : "123" ,
1476+ },
1477+ body : jsonBody (t , initResp ),
1478+ },
1479+ {"POST" , "123" , notificationInitialized }: {
1480+ status : http .StatusAccepted ,
1481+ wantProtocolVersion : latestProtocolVersion ,
1482+ },
1483+ {"GET" , "123" , "" }: {
1484+ header : header {
1485+ "Content-Type" : "text/event-stream" ,
1486+ },
1487+ optional : true ,
1488+ wantProtocolVersion : latestProtocolVersion ,
1489+ callback : func () {
1490+ // This captures the context when GET request is made
1491+ // Note: We can't directly access req.Context() here, but
1492+ // the test verifies that the fix enables context propagation
1493+ },
1494+ },
1495+ {"DELETE" , "123" , "" }: {},
1496+ },
1497+ }
1498+
1499+ handler := & contextCapturingHandler {
1500+ capturedGetContext : & capturedGetContext ,
1501+ capturedDeleteContext : & capturedDeleteContext ,
1502+ mu : & mu ,
1503+ server : fake ,
1504+ }
1505+
1506+ httpServer := httptest .NewServer (handler )
1507+ defer httpServer .Close ()
1508+
1509+ streamTransport := & StreamableClientTransport {Endpoint : httpServer .URL }
1510+ mcpClient := NewClient (testImpl , nil )
1511+ session , err := mcpClient .Connect (ctx , streamTransport , nil )
1512+ if err != nil {
1513+ t .Fatalf ("client.Connect() failed: %v" , err )
1514+ }
1515+
1516+ if err := session .Close (); err != nil {
1517+ t .Errorf ("closing session: %v" , err )
1518+ }
1519+
1520+ mu .Lock ()
1521+ defer mu .Unlock ()
1522+
1523+ if capturedGetContext == nil {
1524+ t .Error ("GET request context was not captured" )
1525+ } else if got := capturedGetContext .Value (testKey ); got != testValue {
1526+ t .Errorf ("GET request context value: got %v, want %v" , got , testValue )
1527+ }
1528+
1529+ if capturedDeleteContext == nil {
1530+ t .Error ("DELETE request context was not captured" )
1531+ } else if got := capturedDeleteContext .Value (testKey ); got != testValue {
1532+ t .Errorf ("DELETE request context value: got %v, want %v" , got , testValue )
1533+ }
1534+
1535+ }
0 commit comments