77
88import httpx
99from mcp import ClientSession
10+ from mcp .client .session import SamplingFnT
1011
1112from ..logging import logger
12- from ..task_managers import ConnectionManager , SseConnectionManager , StreamableHttpConnectionManager
13+ from ..task_managers import SseConnectionManager , StreamableHttpConnectionManager
1314from .base import BaseConnector
1415
1516
@@ -27,6 +28,7 @@ def __init__(
2728 headers : dict [str , str ] | None = None ,
2829 timeout : float = 5 ,
2930 sse_read_timeout : float = 60 * 5 ,
31+ sampling_callback : SamplingFnT | None = None ,
3032 ):
3133 """Initialize a new HTTP connector.
3234
@@ -36,8 +38,9 @@ def __init__(
3638 headers: Optional additional headers.
3739 timeout: Timeout for HTTP operations in seconds.
3840 sse_read_timeout: Timeout for SSE read operations in seconds.
41+ sampling_callback: Optional sampling callback.
3942 """
40- super ().__init__ ()
43+ super ().__init__ (sampling_callback = sampling_callback )
4144 self .base_url = base_url .rstrip ("/" )
4245 self .auth_token = auth_token
4346 self .headers = headers or {}
@@ -46,14 +49,6 @@ def __init__(
4649 self .timeout = timeout
4750 self .sse_read_timeout = sse_read_timeout
4851
49- async def _setup_client (self , connection_manager : ConnectionManager ) -> None :
50- """Set up the client session with the provided connection manager."""
51-
52- self ._connection_manager = connection_manager
53- read_stream , write_stream = await self ._connection_manager .start ()
54- self .client_session = ClientSession (read_stream , write_stream , sampling_callback = None )
55- await self .client_session .__aenter__ ()
56-
5752 async def connect (self ) -> None :
5853 """Establish a connection to the MCP implementation."""
5954 if self ._connected :
@@ -76,7 +71,9 @@ async def connect(self) -> None:
7671 read_stream , write_stream = await connection_manager .start ()
7772
7873 # Test if this actually works by trying to create a client session and initialize it
79- test_client = ClientSession (read_stream , write_stream , sampling_callback = None )
74+ test_client = ClientSession (
75+ read_stream , write_stream , sampling_callback = self .sampling_callback , client_info = self .client_info
76+ )
8077 await test_client .__aenter__ ()
8178
8279 try :
@@ -154,7 +151,12 @@ async def connect(self) -> None:
154151 read_stream , write_stream = await connection_manager .start ()
155152
156153 # Create the client session for SSE
157- self .client_session = ClientSession (read_stream , write_stream , sampling_callback = None )
154+ self .client_session = ClientSession (
155+ read_stream ,
156+ write_stream ,
157+ sampling_callback = self .sampling_callback ,
158+ client_info = self .client_info ,
159+ )
158160 await self .client_session .__aenter__ ()
159161 self .transport_type = "SSE"
160162
0 commit comments