diff --git a/pyproject.toml b/pyproject.toml index 78944c8..7bb9080 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "uipath-mcp" -version = "0.0.76" +version = "0.0.77" description = "UiPath MCP SDK" readme = { file = "README.md", content-type = "text/markdown" } requires-python = ">=3.10" diff --git a/src/uipath_mcp/_cli/_runtime/_session.py b/src/uipath_mcp/_cli/_runtime/_session.py index fe878fc..58cad5b 100644 --- a/src/uipath_mcp/_cli/_runtime/_session.py +++ b/src/uipath_mcp/_cli/_runtime/_session.py @@ -1,7 +1,7 @@ import asyncio import logging import tempfile -from typing import Optional +from typing import Dict, Optional import mcp.types as types from mcp import StdioServerParameters @@ -18,6 +18,7 @@ MAX_RETRIES = 3 RETRY_DELAY = 1 + class SessionServer: """Manages a server process for a specific session.""" @@ -29,6 +30,8 @@ def __init__(self, server_config: McpServer, session_id: str): self._mcp_session = None self._run_task = None self._message_queue = asyncio.Queue() + self._active_requests: Dict[str, str] = {} + self._last_request_id: None self._uipath = UiPath() self._mcp_tracer = McpTracer(tracer, logger) self._server_stderr_output: Optional[str] = None @@ -115,8 +118,27 @@ async def _run_server(self, server_params: StdioServerParameters) -> None: # Process incoming messages from the local server try: while True: + # Get message from local server message = await self._read_stream.receive() - await self._send_message(message, request_id=self._last_request_id) + + # For responses, determine which request_id to use + if self._is_response(message): + message_id = self._get_message_id(message) + if message_id and message_id in self._active_requests: + # Use the stored request_id for this response + request_id = self._active_requests[message_id] + # Send with the matched request_id + await self._send_message(message, request_id) + # Clean up the mapping after use + del self._active_requests[message_id] + else: + # If no mapping found, use the last known request_id + await self._send_message( + message, self._last_request_id + ) + else: + # For non-responses, use the last known request_id + await self._send_message(message, self._last_request_id) finally: # Cancel the consumer when we exit the loop consumer_task.cancel() @@ -153,13 +175,12 @@ async def _consume_messages(self): """Consume messages from the queue and send them to the local server.""" try: while True: - message, request_id = await self._message_queue.get() + message = await self._message_queue.get() try: if self._write_stream: logger.info( f"Session {self._session_id} - processing queued message: {message}..." ) - self._last_request_id = request_id await self._write_stream.send(message) except Exception as e: logger.error( @@ -176,10 +197,15 @@ async def _consume_messages(self): except asyncio.QueueEmpty: break - async def _send_message(self, message: types.JSONRPCMessage, request_id: str) -> None: + async def _send_message( + self, message: types.JSONRPCMessage, request_id: str + ) -> None: """Send new message to UiPath MCP Server.""" with self._mcp_tracer.create_span_for_message( - message, session_id=self._session_id, server_name=self._server_config.name + message, + session_id=self._session_id, + request_id=request_id, + server_name=self._server_config.name, ) as _: for attempt in range(MAX_RETRIES + 1): try: @@ -198,20 +224,18 @@ async def _send_message(self, message: types.JSONRPCMessage, request_id: str) -> ) raise - async def _send_message_internal(self, message: types.JSONRPCMessage, request_id: str) -> None: + async def _send_message_internal( + self, message: types.JSONRPCMessage, request_id: str + ) -> None: response = await self._uipath.api_client.request_async( "POST", f"mcp_/mcp/{self._server_config.name}/out/message?sessionId={self._session_id}&requestId={request_id}", json=message.model_dump(), ) if response.status_code == 202: - logger.info( - f"Outgoing message sent to UiPath MCP Server: {message}" - ) + logger.info(f"Outgoing message sent to UiPath MCP Server: {message}") elif 500 <= response.status_code < 600: - raise Exception( - f"{response.status_code} - {response.text}" - ) + raise Exception(f"{response.status_code} - {response.text}") async def _get_messages_internal(self, request_id: str) -> None: response = await self._uipath.api_client.request_async( @@ -219,17 +243,43 @@ async def _get_messages_internal(self, request_id: str) -> None: f"mcp_/mcp/{self._server_config.name}/in/messages?sessionId={self._session_id}&requestId={request_id}", ) if response.status_code == 200: + self._last_request_id = request_id messages = response.json() for message in messages: logger.info(f"Received message: {message}") json_message = types.JSONRPCMessage.model_validate(message) + if self._is_request(json_message): + message_id = self._get_message_id(json_message) + if message_id: + self._active_requests[message_id] = request_id with self._mcp_tracer.create_span_for_message( json_message, session_id=self._session_id, + request_id=request_id, server_name=self._server_config.name, ) as _: - await self._message_queue.put((json_message, request_id)) + await self._message_queue.put(json_message) elif 500 <= response.status_code < 600: - raise Exception( - f"{response.status_code} - {response.text}" + raise Exception(f"{response.status_code} - {response.text}") + + def _is_request(self, message: types.JSONRPCMessage) -> bool: + """Check if a message is a JSONRPCRequest.""" + if hasattr(message, "root"): + root = message.root + return isinstance(root, types.JSONRPCRequest) + return False + + def _is_response(self, message: types.JSONRPCMessage) -> bool: + """Check if a message is a JSONRPCResponse or JSONRPCError.""" + if hasattr(message, "root"): + root = message.root + return isinstance(root, types.JSONRPCResponse) or isinstance( + root, types.JSONRPCError ) + return False + + def _get_message_id(self, message: types.JSONRPCMessage) -> str: + """Extract the message id from a JSONRPCMessage.""" + if hasattr(message, "root") and hasattr(message.root, "id"): + return str(message.root.id) + return ""