Skip to content

Commit f6bb6fb

Browse files
authored
Merge pull request #8 from UiPath/fix/samples
feat: handle close session
2 parents 136a9c7 + 115e3e2 commit f6bb6fb

File tree

3 files changed

+53
-6
lines changed

3 files changed

+53
-6
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "uipath-mcp"
3-
version = "0.0.3"
3+
version = "0.0.4"
44
description = "UiPath MCP SDK"
55
readme = { file = "README.md", content-type = "text/markdown" }
66
requires-python = ">=3.10"

src/uipath_mcp/_cli/_runtime/_runtime.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,14 @@ async def execute(self) -> Optional[UiPathRuntimeResult]:
5353

5454
# Set up SignalR client
5555
signalr_url = (
56-
f"{os.environ.get('UIPATH_URL')}/mcp_/wsstunnel?slug={self.server.name}"
56+
f"{os.environ.get('UIPATH_URL')}/mcp_/wsstunnel?slug={self.server.name}&jobKey={self.context.job_id}"
5757
)
5858

59+
self.cancel_event = asyncio.Event()
60+
5961
self.signalr_client = SignalRClient(signalr_url)
6062
self.signalr_client.on("MessageReceived", self.handle_signalr_message)
63+
self.signalr_client.on("SessionClosed", self.handle_signalr_session_closed)
6164
self.signalr_client.on_error(self.handle_signalr_error)
6265
self.signalr_client.on_open(self.handle_signalr_open)
6366
self.signalr_client.on_close(self.handle_signalr_close)
@@ -68,7 +71,21 @@ async def execute(self) -> Optional[UiPathRuntimeResult]:
6871
# Keep the runtime alive
6972
# Start SignalR client and keep it running (this is a blocking call)
7073
logger.info("Starting SignalR client...")
71-
await self.signalr_client.run()
74+
75+
run_task = asyncio.create_task(self.signalr_client.run())
76+
77+
# Set up a task to wait for cancellation
78+
cancel_task = asyncio.create_task(self.cancel_event.wait())
79+
80+
# Wait for either the run to complete or cancellation
81+
done, pending = await asyncio.wait(
82+
[run_task, cancel_task],
83+
return_when=asyncio.FIRST_COMPLETED
84+
)
85+
86+
# Cancel any pending tasks
87+
for task in pending:
88+
task.cancel()
7289

7390
return UiPathRuntimeResult()
7491

@@ -99,6 +116,26 @@ async def validate(self) -> None:
99116
UiPathErrorCategory.DEPLOYMENT,
100117
)
101118

119+
async def handle_signalr_session_closed(self, args: list) -> None:
120+
"""
121+
Handle session closed by server.
122+
"""
123+
if len(args) < 1:
124+
logger.error(f"Received invalid SignalR message arguments: {args}")
125+
return
126+
127+
session_id = args[0]
128+
129+
logger.info(f"Received closed signal for session {session_id}")
130+
131+
try:
132+
self.cancel_event.set()
133+
134+
except Exception as e:
135+
logger.error(
136+
f"Error terminating session {session_id}: {str(e)}"
137+
)
138+
102139
async def handle_signalr_message(self, args: list) -> None:
103140
"""
104141
Handle incoming SignalR messages.
@@ -214,8 +251,9 @@ async def cleanup(self) -> None:
214251

215252
self.session_servers.clear()
216253

217-
# Close SignalR connection
218-
# self.signalr_client
254+
if self.signalr_client:
255+
# Close the SignalR connection
256+
await self.signalr_client._transport._ws.close()
219257

220258
# Add a small delay to allow the server to shut down gracefully
221259
if sys.platform == "win32":

src/uipath_mcp/_cli/_utils/_config.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,16 @@ def get_servers(self) -> List[McpServer]:
6969
return list(self._servers.values())
7070

7171
def get_server(self, name: str) -> Optional[McpServer]:
72-
"""Get a server model by name."""
72+
"""
73+
Get a server model by name.
74+
If there's only one server available, return that one regardless of name.
75+
Otherwise, look up the server by the provided name.
76+
"""
77+
# If there's only one server, return it
78+
if len(self._servers) == 1:
79+
return next(iter(self._servers.values()))
80+
81+
# Otherwise, fall back to looking up by name
7382
return self._servers.get(name)
7483

7584
def get_server_names(self) -> List[str]:

0 commit comments

Comments
 (0)