Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "uipath-mcp"
version = "0.0.76"
version = "0.0.77"
description = "UiPath MCP SDK"
readme = { file = "README.md", content-type = "text/markdown" }
requires-python = ">=3.10"
Expand Down
82 changes: 66 additions & 16 deletions src/uipath_mcp/_cli/_runtime/_session.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import logging
import tempfile
from typing import Optional
from typing import Dict, Optional

import mcp.types as types
from mcp import StdioServerParameters
Expand All @@ -18,6 +18,7 @@
MAX_RETRIES = 3
RETRY_DELAY = 1


class SessionServer:
"""Manages a server process for a specific session."""

Expand All @@ -29,6 +30,8 @@ def __init__(self, server_config: McpServer, session_id: str):
self._mcp_session = None
self._run_task = None
self._message_queue = asyncio.Queue()
self._active_requests: Dict[str, str] = {}
self._last_request_id: None
self._uipath = UiPath()
self._mcp_tracer = McpTracer(tracer, logger)
self._server_stderr_output: Optional[str] = None
Expand Down Expand Up @@ -115,8 +118,27 @@ async def _run_server(self, server_params: StdioServerParameters) -> None:
# Process incoming messages from the local server
try:
while True:
# Get message from local server
message = await self._read_stream.receive()
await self._send_message(message, request_id=self._last_request_id)

# For responses, determine which request_id to use
if self._is_response(message):
message_id = self._get_message_id(message)
if message_id and message_id in self._active_requests:
# Use the stored request_id for this response
request_id = self._active_requests[message_id]
# Send with the matched request_id
await self._send_message(message, request_id)
# Clean up the mapping after use
del self._active_requests[message_id]
else:
# If no mapping found, use the last known request_id
await self._send_message(
message, self._last_request_id
)
else:
# For non-responses, use the last known request_id
await self._send_message(message, self._last_request_id)
finally:
# Cancel the consumer when we exit the loop
consumer_task.cancel()
Expand Down Expand Up @@ -153,13 +175,12 @@ async def _consume_messages(self):
"""Consume messages from the queue and send them to the local server."""
try:
while True:
message, request_id = await self._message_queue.get()
message = await self._message_queue.get()
try:
if self._write_stream:
logger.info(
f"Session {self._session_id} - processing queued message: {message}..."
)
self._last_request_id = request_id
await self._write_stream.send(message)
except Exception as e:
logger.error(
Expand All @@ -176,10 +197,15 @@ async def _consume_messages(self):
except asyncio.QueueEmpty:
break

async def _send_message(self, message: types.JSONRPCMessage, request_id: str) -> None:
async def _send_message(
self, message: types.JSONRPCMessage, request_id: str
) -> None:
"""Send new message to UiPath MCP Server."""
with self._mcp_tracer.create_span_for_message(
message, session_id=self._session_id, server_name=self._server_config.name
message,
session_id=self._session_id,
request_id=request_id,
server_name=self._server_config.name,
) as _:
for attempt in range(MAX_RETRIES + 1):
try:
Expand All @@ -198,38 +224,62 @@ async def _send_message(self, message: types.JSONRPCMessage, request_id: str) ->
)
raise

async def _send_message_internal(self, message: types.JSONRPCMessage, request_id: str) -> None:
async def _send_message_internal(
self, message: types.JSONRPCMessage, request_id: str
) -> None:
response = await self._uipath.api_client.request_async(
"POST",
f"mcp_/mcp/{self._server_config.name}/out/message?sessionId={self._session_id}&requestId={request_id}",
json=message.model_dump(),
)
if response.status_code == 202:
logger.info(
f"Outgoing message sent to UiPath MCP Server: {message}"
)
logger.info(f"Outgoing message sent to UiPath MCP Server: {message}")
elif 500 <= response.status_code < 600:
raise Exception(
f"{response.status_code} - {response.text}"
)
raise Exception(f"{response.status_code} - {response.text}")

async def _get_messages_internal(self, request_id: str) -> None:
response = await self._uipath.api_client.request_async(
"GET",
f"mcp_/mcp/{self._server_config.name}/in/messages?sessionId={self._session_id}&requestId={request_id}",
)
if response.status_code == 200:
self._last_request_id = request_id
messages = response.json()
for message in messages:
logger.info(f"Received message: {message}")
json_message = types.JSONRPCMessage.model_validate(message)
if self._is_request(json_message):
message_id = self._get_message_id(json_message)
if message_id:
self._active_requests[message_id] = request_id
with self._mcp_tracer.create_span_for_message(
json_message,
session_id=self._session_id,
request_id=request_id,
server_name=self._server_config.name,
) as _:
await self._message_queue.put((json_message, request_id))
await self._message_queue.put(json_message)
elif 500 <= response.status_code < 600:
raise Exception(
f"{response.status_code} - {response.text}"
raise Exception(f"{response.status_code} - {response.text}")

def _is_request(self, message: types.JSONRPCMessage) -> bool:
"""Check if a message is a JSONRPCRequest."""
if hasattr(message, "root"):
root = message.root
return isinstance(root, types.JSONRPCRequest)
return False

def _is_response(self, message: types.JSONRPCMessage) -> bool:
"""Check if a message is a JSONRPCResponse or JSONRPCError."""
if hasattr(message, "root"):
root = message.root
return isinstance(root, types.JSONRPCResponse) or isinstance(
root, types.JSONRPCError
)
return False

def _get_message_id(self, message: types.JSONRPCMessage) -> str:
"""Extract the message id from a JSONRPCMessage."""
if hasattr(message, "root") and hasattr(message.root, "id"):
return str(message.root.id)
return ""