33import functools
44from typing import TYPE_CHECKING , Any
55
6- from mcp .shared .session import BaseSession , SendRequestT
7- from mcp .types import CallToolRequest
6+ from mcp .client .session import ClientSession
7+ from mcp .shared .session import BaseSession
8+ from mcp .types import CallToolRequest , LoggingMessageNotification
9+
10+ from logfire ._internal .utils import handle_internal_errors
811
912if TYPE_CHECKING :
10- from logfire import Logfire
13+ from logfire import LevelName , Logfire
1114
1215
1316def instrument_mcp (logfire_instance : Logfire ):
1417 logfire_instance = logfire_instance .with_settings (custom_scope_suffix = 'mcp' )
1518
16- original = BaseSession .send_request # type: ignore
19+ original_send_request = BaseSession .send_request # type: ignore
1720
18- @functools .wraps (original ) # type: ignore
19- async def send_request (self , request : SendRequestT , * args , ** kwargs : Any ): # type: ignore
21+ @functools .wraps (original_send_request ) # type: ignore
22+ async def send_request (self : Any , request : Any , * args : Any , ** kwargs : Any ):
2023 attributes : dict [str , Any ] = {
2124 'request' : request ,
2225 # https://opentelemetry.io/docs/specs/semconv/rpc/json-rpc/
@@ -35,8 +38,28 @@ async def send_request(self, request: SendRequestT, *args, **kwargs: Any): # ty
3538 span_name += f' { root .params .name } '
3639
3740 with logfire_instance .span (span_name , ** attributes ) as span :
38- result = await original (self , request , * args , ** kwargs ) # type: ignore
41+ result = await original_send_request (self , request , * args , ** kwargs )
3942 span .set_attribute ('response' , result )
40- return result # type: ignore
43+ return result
4144
4245 BaseSession .send_request = send_request
46+
47+ original_received_notification = ClientSession ._received_notification # type: ignore
48+
49+ @functools .wraps (original_received_notification )
50+ async def _received_notification (self : Any , notification : Any , * args : Any , ** kwargs : Any ):
51+ with handle_internal_errors :
52+ if isinstance (notification .root , LoggingMessageNotification ): # pragma: no branch
53+ params = notification .root .params
54+ level : LevelName
55+ if params .level in ('critical' , 'alert' , 'emergency' ):
56+ level = 'fatal'
57+ else :
58+ level = params .level
59+ span_name = 'MCP server log'
60+ if params .logger :
61+ span_name += f' from { params .logger } '
62+ logfire_instance .log (level , span_name , attributes = dict (data = params .data ))
63+ await original_received_notification (self , notification , * args , ** kwargs )
64+
65+ ClientSession ._received_notification = _received_notification # type: ignore
0 commit comments