Skip to content

Commit 3f576c2

Browse files
committed
Added input/output types, further code cleanup, got rid of pylint disable self use
1 parent 4ea1757 commit 3f576c2

File tree

3 files changed

+152
-115
lines changed

3 files changed

+152
-115
lines changed
Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
# MCP Instrumentor
22

3-
OpenTelemetry MCP instrumentation package.
3+
OpenTelemetry MCP instrumentation package for AWS Distro.
44

55
## Installation
66

77
```bash
8-
pip install mcpinstrumentor
8+
pip install amazon-opentelemetry-distro-mcpinstrumentor
99
```
1010

1111
## Usage
@@ -14,4 +14,20 @@ pip install mcpinstrumentor
1414
from mcpinstrumentor import MCPInstrumentor
1515

1616
MCPInstrumentor().instrument()
17-
```
17+
```
18+
19+
## Configuration
20+
21+
### Environment Variables
22+
23+
- `MCP_SERVICE_NAME`: Sets the service name for MCP client spans. Defaults to "Generic MCP Server" if not set.
24+
25+
```bash
26+
export MCP_SERVICE_NAME="My Custom MCP Server"
27+
```
28+
29+
## Features
30+
31+
- Automatic instrumentation of MCP client and server requests
32+
- Distributed tracing support with trace context propagation
33+
- Configurable service naming via environment variables

aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py

Lines changed: 78 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# SPDX-License-Identifier: Apache-2.0
33
import logging
4-
from typing import Any, Collection
4+
from typing import Any, Callable, Collection, Dict, Tuple
55

6+
from mcp import ClientRequest
67
from wrapt import register_post_import_hook, wrap_function_wrapper
78

89
from opentelemetry import trace
@@ -29,12 +30,46 @@ class MCPInstrumentor(BaseInstrumentor):
2930
An instrumenter for MCP.
3031
"""
3132

32-
def __init__(self): # pylint: disable=no-self-use
33+
def __init__(self):
3334
super().__init__()
3435
self.tracer = None
3536

37+
@staticmethod
38+
def instrumentation_dependencies() -> Collection[str]:
39+
return _instruments
40+
41+
def _instrument(self, **kwargs: Any) -> None:
42+
tracer_provider = kwargs.get("tracer_provider")
43+
if tracer_provider:
44+
self.tracer = tracer_provider.get_tracer("mcp")
45+
else:
46+
self.tracer = trace.get_tracer("mcp")
47+
register_post_import_hook(
48+
lambda _: wrap_function_wrapper(
49+
"mcp.shared.session",
50+
"BaseSession.send_request",
51+
self._wrap_send_request,
52+
),
53+
"mcp.shared.session",
54+
)
55+
register_post_import_hook(
56+
lambda _: wrap_function_wrapper(
57+
"mcp.server.lowlevel.server",
58+
"Server._handle_request",
59+
self._wrap_handle_request,
60+
),
61+
"mcp.server.lowlevel.server",
62+
)
63+
64+
@staticmethod
65+
def _uninstrument(**kwargs: Any) -> None:
66+
unwrap("mcp.shared.session", "BaseSession.send_request")
67+
unwrap("mcp.server.lowlevel.server", "Server._handle_request")
68+
3669
# Send Request Wrapper
37-
def _wrap_send_request(self, wrapped, instance, args, kwargs): # pylint: disable=no-self-use
70+
def _wrap_send_request(
71+
self, wrapped: Callable, instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any]
72+
) -> Callable:
3873
"""
3974
Changes made:
4075
The wrapper intercepts the request before sending, injects distributed tracing context into the
@@ -43,14 +78,14 @@ def _wrap_send_request(self, wrapped, instance, args, kwargs): # pylint: disabl
4378
type and calling the original function with identical parameters.
4479
"""
4580

46-
async def async_wrapper(): # pylint: disable=no-self-use
81+
async def async_wrapper():
4782
with self.tracer.start_as_current_span("client.send_request", kind=trace.SpanKind.CLIENT) as span:
4883
span_ctx = span.get_span_context()
4984
request = args[0] if len(args) > 0 else kwargs.get("request")
5085
if request:
5186
req_root = request.root if hasattr(request, "root") else request
5287

53-
self.handle_attributes(span, req_root, True)
88+
self._generate_mcp_attributes(span, req_root, is_client=True)
5489
request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True)
5590
self._inject_trace_context(request_data, span_ctx)
5691
# Reconstruct request object with injected trace context
@@ -68,7 +103,9 @@ async def async_wrapper(): # pylint: disable=no-self-use
68103
return async_wrapper()
69104

70105
# Handle Request Wrapper
71-
async def _wrap_handle_request(self, wrapped, instance, args, kwargs): # pylint: disable=no-self-use
106+
async def _wrap_handle_request(
107+
self, wrapped: Callable, instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any]
108+
) -> Any:
72109
"""
73110
Changes made:
74111
This wrapper intercepts requests before processing, extracts distributed tracing context from
@@ -87,19 +124,35 @@ async def _wrap_handle_request(self, wrapped, instance, args, kwargs): # pylint
87124
traceparent = getattr(req.params.meta, "traceparent", None)
88125
span_context = self._extract_span_context_from_traceparent(traceparent) if traceparent else None
89126
if span_context:
90-
span_name = self._get_span_name(req)
127+
span_name = self._get_mcp_operation(req)
91128
with self.tracer.start_as_current_span(
92129
span_name,
93130
kind=trace.SpanKind.SERVER,
94131
context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
95132
) as span:
96-
self.handle_attributes(span, req, False)
133+
self._generate_mcp_attributes(span, req, False)
97134
result = await wrapped(*args, **kwargs)
98135
return result
99136
else:
100137
return await wrapped(*args, **kwargs)
101138

102-
def _inject_trace_context(self, request_data, span_ctx): # pylint: disable=no-self-use
139+
def _generate_mcp_attributes(self, span: trace.Span, request: ClientRequest, is_client: bool) -> None:
140+
import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import
141+
142+
operation = "UnknownOperation"
143+
if isinstance(request, types.ListToolsRequest):
144+
operation = "ListTool"
145+
span.set_attribute("mcp.list_tools", True)
146+
elif isinstance(request, types.CallToolRequest):
147+
operation = request.params.name
148+
span.set_attribute("mcp.call_tool", True)
149+
if is_client:
150+
self._add_client_attributes(span, operation, request)
151+
else:
152+
self._add_server_attributes(span, operation, request)
153+
154+
@staticmethod
155+
def _inject_trace_context(request_data: Dict[str, Any], span_ctx) -> None:
103156
if "params" not in request_data:
104157
request_data["params"] = {}
105158
if "_meta" not in request_data["params"]:
@@ -110,7 +163,8 @@ def _inject_trace_context(self, request_data, span_ctx): # pylint: disable=no-s
110163
traceparent = f"00-{trace_id_hex}-{span_id_hex}-{trace_flags}"
111164
request_data["params"]["_meta"]["traceparent"] = traceparent
112165

113-
def _extract_span_context_from_traceparent(self, traceparent): # pylint: disable=no-self-use
166+
@staticmethod
167+
def _extract_span_context_from_traceparent(traceparent: str):
114168
parts = traceparent.split("-")
115169
if len(parts) == 4:
116170
try:
@@ -127,72 +181,29 @@ def _extract_span_context_from_traceparent(self, traceparent): # pylint: disabl
127181
return None
128182
return None
129183

130-
def _get_span_name(self, req): # pylint: disable=no-self-use
131-
span_name = "unknown"
184+
@staticmethod
185+
def _get_mcp_operation(req: ClientRequest) -> str:
132186
import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import
133187

188+
span_name = "unknown"
189+
134190
if isinstance(req, types.ListToolsRequest):
135191
span_name = "tools/list"
136192
elif isinstance(req, types.CallToolRequest):
137-
if hasattr(req, "params") and hasattr(req.params, "name"):
138-
span_name = f"tools/{req.params.name}"
139-
else:
140-
span_name = "unknown"
193+
span_name = f"tools/{req.params.name}"
141194
return span_name
142195

143-
def handle_attributes(self, span, request, is_client=True): # pylint: disable=no-self-use
144-
import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import
145-
146-
operation = self._get_span_name(request)
147-
if isinstance(request, types.ListToolsRequest):
148-
operation = "ListTool"
149-
span.set_attribute("mcp.list_tools", True)
150-
elif isinstance(request, types.CallToolRequest):
151-
if hasattr(request, "params") and hasattr(request.params, "name"):
152-
operation = request.params.name
153-
span.set_attribute("mcp.call_tool", True)
154-
if is_client:
155-
self._add_client_attributes(span, operation, request)
156-
else:
157-
self._add_server_attributes(span, operation, request)
196+
@staticmethod
197+
def _add_client_attributes(span: trace.Span, operation: str, request: ClientRequest) -> None:
198+
import os # pylint: disable=import-outside-toplevel
158199

159-
def _add_client_attributes(self, span, operation, request): # pylint: disable=no-self-use
160-
span.set_attribute("aws.remote.service", "Appsignals MCP Server")
200+
service_name = os.environ.get("MCP_SERVICE_NAME", "Generic MCP Server")
201+
span.set_attribute("aws.remote.service", service_name)
161202
span.set_attribute("aws.remote.operation", operation)
162-
if hasattr(request, "params") and hasattr(request.params, "name"):
203+
if hasattr(request, "params") and request.params and hasattr(request.params, "name"):
163204
span.set_attribute("tool.name", request.params.name)
164205

165-
def _add_server_attributes(self, span, operation, request): # pylint: disable=no-self-use
166-
span.set_attribute("server_side", True)
167-
if hasattr(request, "params") and hasattr(request.params, "name"):
206+
@staticmethod
207+
def _add_server_attributes(span: trace.Span, operation: str, request: ClientRequest) -> None:
208+
if hasattr(request, "params") and request.params and hasattr(request.params, "name"):
168209
span.set_attribute("tool.name", request.params.name)
169-
170-
def instrumentation_dependencies(self) -> Collection[str]: # pylint: disable=no-self-use
171-
return _instruments
172-
173-
def _instrument(self, **kwargs: Any) -> None: # pylint: disable=no-self-use
174-
tracer_provider = kwargs.get("tracer_provider")
175-
if tracer_provider:
176-
self.tracer = tracer_provider.get_tracer("mcp")
177-
else:
178-
self.tracer = trace.get_tracer("mcp")
179-
register_post_import_hook(
180-
lambda _: wrap_function_wrapper(
181-
"mcp.shared.session",
182-
"BaseSession.send_request",
183-
self._wrap_send_request,
184-
),
185-
"mcp.shared.session",
186-
)
187-
register_post_import_hook(
188-
lambda _: wrap_function_wrapper(
189-
"mcp.server.lowlevel.server",
190-
"Server._handle_request",
191-
self._wrap_handle_request,
192-
),
193-
"mcp.server.lowlevel.server",
194-
)
195-
196-
def _uninstrument(self, **kwargs: Any) -> None: # pylint: disable=no-self-use
197-
unwrap("mcp.shared.session", "BaseSession.send_request")
198-
unwrap("mcp.server.lowlevel.server", "Server._handle_request")

0 commit comments

Comments
 (0)