77from mcp .client .session import ClientSession
88from mcp .server import Server
99from mcp .shared .session import BaseSession
10- from mcp .types import CallToolRequest , LoggingMessageNotification , RequestParams
10+ from mcp .types import CallToolRequest , LoggingMessageNotification
1111from pydantic import TypeAdapter
1212
1313from logfire ._internal .utils import handle_internal_errors
@@ -42,26 +42,22 @@ async def send_request(self: Any, request: Any, *args: Any, **kwargs: Any):
4242 span_name += f' { root .params .name } '
4343
4444 with logfire_instance .span (span_name , ** attributes ) as span :
45- with handle_internal_errors :
46- if propagate_otel_context : # pragma: no branch
47- carrier = get_context ()
48- if params := getattr (root , 'params' , None ):
49- if meta := getattr (params , 'meta' , None ): # pragma: no cover # TODO
50- dumped_meta = meta .model_dump ()
51- else :
52- dumped_meta = {}
53- # Prioritise existing values in meta over the context carrier.
54- # RequestParams.Meta should allow basically anything, we're being extra careful here.
55- params .meta = RequestParams .Meta .model_validate ({** carrier , ** dumped_meta })
56- else :
57- root .params = _request_params_type_adapter (type (root )).validate_python ({'_meta' : carrier }) # type: ignore
58-
45+ _attach_context_to_request (root )
5946 result = await original_send_request (self , request , * args , ** kwargs )
6047 span .set_attribute ('response' , result )
6148 return result
6249
6350 BaseSession .send_request = send_request
6451
52+ original_send_notification = BaseSession .send_notification # type: ignore
53+
54+ @functools .wraps (original_send_notification ) # type: ignore
55+ async def send_notification (self : Any , notification : Any , * args : Any , ** kwargs : Any ):
56+ _attach_context_to_request (notification .root )
57+ return await original_send_notification (self , notification , * args , ** kwargs )
58+
59+ BaseSession .send_notification = send_notification
60+
6561 original_received_notification = ClientSession ._received_notification # type: ignore
6662
6763 @functools .wraps (original_received_notification )
@@ -77,7 +73,8 @@ async def _received_notification(self: Any, notification: Any, *args: Any, **kwa
7773 span_name = 'MCP server log'
7874 if params .logger :
7975 span_name += f' from { params .logger } '
80- logfire_instance .log (level , span_name , attributes = dict (data = params .data ))
76+ with _request_context (notification .root ):
77+ logfire_instance .log (level , span_name , attributes = dict (data = params .data ))
8178 await original_received_notification (self , notification , * args , ** kwargs )
8279
8380 ClientSession ._received_notification = _received_notification # type: ignore
@@ -105,18 +102,39 @@ async def _handle_request(self: Any, message: Any, request: Any, *args: Any, **k
105102
106103 @contextmanager
107104 def _handle_request_with_context (request : Any , span_name : str ):
108- with ExitStack () as exit_stack :
109- if ( # pragma: no branch
110- propagate_otel_context
111- and (params := getattr (request , 'params' , None ))
112- and (meta := getattr (params , 'meta' , None ))
113- ):
114- exit_stack .enter_context (attach_context (meta .model_dump ()))
105+ with _request_context (request ):
115106 if method := getattr (request , 'method' , None ): # pragma: no branch
116107 span_name += f': { method } '
117108 with logfire_instance .span (span_name , request = request ):
118109 yield
119110
111+ @contextmanager
112+ def _request_context (request : Any ):
113+ with ExitStack () as exit_stack :
114+ with handle_internal_errors :
115+ if ( # pragma: no branch
116+ propagate_otel_context
117+ and (params := getattr (request , 'params' , None ))
118+ and (meta := getattr (params , 'meta' , None ))
119+ ):
120+ exit_stack .enter_context (attach_context (meta .model_dump ()))
121+ yield
122+
123+ def _attach_context_to_request (root : Any ):
124+ if not propagate_otel_context : # pragma: no cover
125+ return
126+ carrier = get_context ()
127+ if params := getattr (root , 'params' , None ):
128+ if meta := getattr (params , 'meta' , None ): # pragma: no cover # TODO
129+ dumped_meta = meta .model_dump ()
130+ else :
131+ dumped_meta = {}
132+ # Prioritise existing values in meta over the context carrier.
133+ # RequestParams.Meta should allow basically anything, we're being extra careful here.
134+ params .meta = type (params ).Meta .model_validate ({** carrier , ** dumped_meta }) # type: ignore
135+ else :
136+ root .params = _request_params_type_adapter (type (root )).validate_python ({'_meta' : carrier }) # type: ignore
137+
120138
121139@functools .lru_cache
122140def _request_params_type_adapter (root_type : Any ):
0 commit comments