Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions src/mcpm/monitor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,37 @@
from enum import Enum, auto
from typing import Any, Dict, Optional, Union

from pydantic import BaseModel, Field


class Pagination(BaseModel):
total: int = Field(description="Total number of events")
page: int = Field(description="Page number")
limit: int = Field(description="Number of events per page")
total_pages: int = Field(description="Total number of pages")


class MCPEvent(BaseModel):
id: int = Field(description="Event ID")
event_type: str = Field(description="Event type")
server_id: str = Field(description="Server ID")
resource_id: str = Field(description="Resource ID")
client_id: Optional[str] = Field(description="Client ID")
timestamp: str = Field(description="Event timestamp")
duration_ms: Optional[int] = Field(description="Event duration in milliseconds")
request_size: Optional[int] = Field(description="Request size in bytes")
response_size: Optional[int] = Field(description="Response size in bytes")
success: bool = Field(description="Event success status")
error_message: Optional[str] = Field(description="Error message")
metadata: Optional[Dict[str, Any]] = Field(description="Event metadata")
raw_request: Optional[Union[str, Dict]] = Field(description="Raw request data")
raw_response: Optional[Union[str, Dict]] = Field(description="Raw response data")


class QueryEventResponse(BaseModel):
pagination: Pagination = Field(description="Pagination information")
events: list[MCPEvent] = Field(description="List of events")


class AccessEventType(Enum):
"""Type of MCP access event"""
Expand Down Expand Up @@ -75,3 +106,21 @@ async def close(self) -> None:
Close any open connections to the storage backend
"""
pass

@abstractmethod
async def query_events(
self, offset: str, page: int, limit: int, event_type: Optional[str] = None
) -> QueryEventResponse:
"""
Query events from the storage backend

Args:
offset: Time offset for the query
page: Page number
limit: Number of events per page
event_type: Type of events to query (optional)

Returns:
QueryEventResponse: List of events matching the query
"""
pass
129 changes: 128 additions & 1 deletion src/mcpm/monitor/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import duckdb

from mcpm.monitor.base import AccessEventType, AccessMonitor
from mcpm.monitor.base import AccessEventType, AccessMonitor, MCPEvent, Pagination, QueryEventResponse
from mcpm.utils.config import ConfigManager


Expand Down Expand Up @@ -234,6 +234,133 @@ def _track_event_impl(
print(f"Error tracking event: {e}")
return False

async def query_events(
self, offset: str, page: int, limit: int, event_type: Optional[str] = None
) -> QueryEventResponse:
"""
Query events from the database with pagination.

Args:
offset: Time offset pattern like "3h" for past 3 hours, "1d" for past day, etc.
page: Page number (1-based)
limit: Number of events per page
event_type: Type of events to filter by

Returns:
Dict containing events, total count, page, and limit
"""
if not self._initialized:
if not await self.initialize_storage():
return QueryEventResponse(pagination=Pagination(total=0, page=0, limit=0, total_pages=0), events=[])

async with self._lock:
response = await asyncio.to_thread(
self._query_events_impl,
offset,
page,
limit,
event_type,
)
return response

def _query_events_impl(
self,
offset: str,
page: int,
limit: int,
event_type: Optional[str],
) -> QueryEventResponse:
"""
Query events from the storage backend

Args:
offset: Time offset for the query
page: Page number
limit: Number of events per page
event_type: Type of events to query (optional)

Returns:
QueryEventResponse: List of events matching the query
"""
try:
# Build the base query and conditions
conditions = []
parameters = []

# handle time offset
time_value = 0
time_unit = ""

# Parse offset pattern like "3h", "1d", etc.
for i, char in enumerate(offset):
if char.isdigit():
time_value = time_value * 10 + int(char)
else:
time_unit = offset[i:]
break

if time_unit and time_value > 0:
# Convert to SQL interval format
interval_map = {"h": "HOUR", "d": "DAY", "w": "WEEK", "m": "MONTH"}

if time_unit.lower() in interval_map:
conditions.append(
f"timestamp >= TIMESTAMP '{datetime.now()}' - INTERVAL {time_value} {interval_map.get(time_unit.lower())}"
)
else:
return QueryEventResponse(pagination=Pagination(total=0, page=0, limit=0, total_pages=0), events=[])

if event_type:
conditions.append("event_type = ?")
parameters.append(event_type)

# Build the final query
where_clause = " AND ".join(conditions)
if where_clause:
where_clause = f"WHERE {where_clause}"

sql_offset = (page - 1) * limit
# Get total count
count_query = f"SELECT COUNT(*) FROM monitor_events {where_clause}"
total_result = self.connection.execute(count_query, parameters).fetchone()
total = total_result[0] if total_result else 0

# Get paginated results
query = f"""
SELECT * FROM monitor_events
{where_clause}
ORDER BY timestamp DESC
LIMIT ? OFFSET ?
"""
cursor = self.connection.execute(query, parameters + [limit, sql_offset])

# Convert result to dictionary
column_names = [desc[0] for desc in cursor.description]
events = []

for row in cursor.fetchall():
event_dict = dict(zip(column_names, row))

for field in ["metadata", "raw_request", "raw_response"]:
if event_dict[field] and isinstance(event_dict[field], str):
try:
event_dict[field] = json.loads(event_dict[field])
except Exception:
pass

event_dict["timestamp"] = datetime.strftime(event_dict["timestamp"], "%Y-%m-%d %H:%M:%S")
events.append(MCPEvent.model_validate(event_dict))

return QueryEventResponse(
pagination=Pagination(
total=total, page=page, limit=limit, total_pages=1 if limit == 0 else (total + limit - 1) // limit
),
events=events,
)
except Exception as e:
print(f"Error querying events: {e}")
return QueryEventResponse(pagination=Pagination(total=0, page=0, limit=0, total_pages=0), events=[])

async def close(self) -> None:
"""Close the database connection asynchronously."""
async with self._lock:
Expand Down
110 changes: 110 additions & 0 deletions src/mcpm/monitor/event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import datetime
import time
import typing
from functools import wraps

from mcp.types import (
CallToolRequest,
CallToolResult,
EmptyResult,
GetPromptRequest,
ReadResourceRequest,
Request,
ServerResult,
TextContent,
)

from mcpm.utils.config import PROMPT_SPLITOR, RESOURCE_SPLITOR, TOOL_SPLITOR

from .base import AccessEventType
from .duckdb import DuckDBAccessMonitor

monitor = DuckDBAccessMonitor()

RequestT = typing.TypeVar("RequestT", bound=Request)
MCPRequestHandler = typing.Callable[[RequestT], typing.Awaitable[ServerResult]]


class TraceIdentifier(typing.TypedDict):
client_id: str
server_id: str
resource_id: str


class ResponseStatus(typing.TypedDict):
success: bool
error_message: str


def get_trace_identifier(req: Request) -> TraceIdentifier:
resource_id = ""
if isinstance(req, CallToolRequest):
server_id = req.params.name.split(TOOL_SPLITOR, 1)[0]
elif isinstance(req, GetPromptRequest):
server_id = req.params.name.split(PROMPT_SPLITOR, 1)[0]
elif isinstance(req, ReadResourceRequest):
# resource uri is formatted as {server_id}:{protocol}://{resource_path}
server_id, resource_id = str(req.params.uri).split(RESOURCE_SPLITOR, 1)
else:
# currently only support call tool, get prompt and read resource
server_id = ""
resource_id = ""

return TraceIdentifier(client_id=req.params.meta.client_id, server_id=server_id, resource_id=resource_id) # type: ignore


def get_response_status(server_result: ServerResult) -> ResponseStatus:
result_root = server_result.root

if isinstance(result_root, EmptyResult):
return ResponseStatus(success=False, error_message="empty result")

if isinstance(result_root, CallToolResult):
if result_root.isError:
return ResponseStatus(
success=False,
error_message=typing.cast(TextContent, result_root.content[0]).text,
)
else:
return ResponseStatus(
success=True,
error_message="",
)

return ResponseStatus(success=True, error_message="")


def trace_event(event_type: AccessEventType):
def decorator(func: MCPRequestHandler):
@wraps(func)
async def wrapper(request: Request):
request_time = datetime.datetime.now().replace(microsecond=0)
start_time = time.perf_counter()
# parse client id, server id and resource id (optional) from request
trace_identifier = get_trace_identifier(request)

response: ServerResult = await func(request)

# empty results and call tool failures are treated as not success
response_status: ResponseStatus = get_response_status(response)

await monitor.track_event(
event_type=event_type,
server_id=trace_identifier["server_id"],
client_id=trace_identifier["client_id"],
resource_id=trace_identifier["resource_id"],
timestamp=request_time,
duration_ms=int((time.perf_counter() - start_time) * 1000),
request_size=len(request.params.model_dump_json().encode("utf-8")),
response_size=len(response.root.model_dump_json().encode("utf-8")),
success=response_status["success"],
error_message=response_status["error_message"],
metadata=None,
raw_request=request.model_dump_json(),
raw_response=response.root.model_dump_json(),
)
return response

return wrapper

return decorator
Loading