Skip to content

Commit 6786a5a

Browse files
carlkesselman and claude committed
Multi-user isolation, tool audit fixes, and new MCP tool parameters
- Add per-user connection isolation using contextvars for active connection tracking and user_id derivation from credentials at connect time
- Remove global _active_executions dict from execution.py; move per-session execution state into ConnectionInfo.active_tool_execution
- Remove _cached_user_id from background_tasks.py; derive user identity statelessly from connection info
- Fix broken _active_executions imports in feature.py and dataset.py
- Remove duplicate set_visible_columns tool from schema.py (annotation.py version is more capable and was silently overwriting it)
- Add dry_run parameter to create_execution MCP tool
- Add metadata parameter to create_feature MCP tool
- Fix pre-existing test failures for env var naming mismatch in TestMCPWorkflowInfo

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 98e3022 commit 6786a5a

File tree

7 files changed

+250
-196
lines changed

7 files changed

+250
-196
lines changed

src/deriva_ml_mcp/connection.py

Lines changed: 133 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,22 @@
33
This module handles Deriva catalog connections, maintaining
44
active connections and providing access to DerivaML instances.
55
6+
Multi-user isolation:
7+
- Each connection is keyed by (user_id, hostname, catalog_id) so two users
8+
connecting to the same catalog get separate ConnectionInfo objects.
9+
- The active connection is tracked per async context using contextvars,
10+
so concurrent requests in HTTP transport mode don't interfere.
11+
- User identity is derived from credentials at connect time and stored
12+
on ConnectionInfo for use by other modules (background tasks, executions).
13+
614
When connecting to a catalog, an MCP workflow and execution are automatically
715
created to track all operations performed through the MCP server.
816
"""
917

1018
from __future__ import annotations
1119

20+
import contextvars
21+
import hashlib
1222
import logging
1323
import os
1424
from dataclasses import dataclass
@@ -24,6 +34,34 @@
2434
# Workflow type for MCP server operations
2535
MCP_WORKFLOW_TYPE = "Deriva MCP"
2636

37+
# Per-request active connection tracking.
38+
# In HTTP transport, each async request gets its own contextvar scope,
39+
# preventing one user's connect() from overwriting another's active connection.
40+
# In stdio transport, there's only one context so this behaves like a simple variable.
41+
_active_connection_var: contextvars.ContextVar[str | None] = contextvars.ContextVar(
42+
"active_connection", default=None
43+
)
44+
45+
46+
def derive_user_id(credential: dict | None) -> str:
    """Derive a privacy-preserving user identifier from Deriva credentials.

    The identifier is the first 16 hex characters of the SHA-256 digest of
    the ``webauthn`` cookie value, so the raw session token itself is never
    used as a key. The result is deterministic for a given cookie value.

    Falls back to ``"default_user"`` (single-user / stdio mode) when no
    usable ``webauthn`` cookie is present: no credential, no ``cookie``
    entry, a non-string cookie, or an empty token.

    Args:
        credential: Credential dict (expected to carry a ``cookie`` entry
            such as ``"webauthn=<token>; ..."``), or None.

    Returns:
        A string identifying the user.
    """
    cookie = credential.get("cookie") if credential else None
    if not isinstance(cookie, str):
        # Key missing, explicitly None, or an unexpected type: there is no
        # identity to derive, and a substring test on it would raise.
        return "default_user"

    for part in cookie.split(";"):
        name, sep, value = part.partition("=")
        # Compare the cookie *name* exactly so that a different cookie whose
        # name merely ends in "webauthn" is not mistaken for the session
        # token. The value is left untouched so the hash is unchanged for
        # every input the previous implementation handled correctly.
        if sep and name.strip() == "webauthn" and value:
            return hashlib.sha256(value.encode()).hexdigest()[:16]

    # No (non-empty) webauthn token found. An empty token carries no identity
    # and would otherwise hash to one constant id shared by all such callers.
    return "default_user"
64+
2765

2866
def get_mcp_workflow_info() -> dict[str, str | bool]:
2967
"""Get workflow metadata from environment variables.
@@ -45,31 +83,51 @@ class ConnectionInfo:
4583
"""Information about an active DerivaML connection.
4684
4785
Each connection has an associated workflow and execution for tracking
48-
all MCP operations performed on the catalog.
86+
all MCP operations performed on the catalog. The user_id field
87+
identifies who owns this connection for multi-user isolation.
4988
"""
5089

5190
hostname: str
5291
catalog_id: str | int
5392
domain_schemas: set[str] | None
5493
ml_instance: DerivaML
94+
user_id: str = "default_user"
5595
workflow_rid: str | None = None
56-
execution: Any = None # Execution object from deriva_ml
96+
execution: Any = None # MCP session execution from deriva_ml
97+
active_tool_execution: Any = None # User-created execution via create_execution tool
5798

5899

59100
class ConnectionManager:
60-
"""Manages DerivaML catalog connections.
101+
"""Manages DerivaML catalog connections with multi-user isolation.
61102
62-
Maintains a registry of active connections and provides
63-
methods to connect, disconnect, and access DerivaML instances.
103+
Connections are keyed by (user_id, hostname, catalog_id) so different
104+
users connecting to the same catalog get separate state. The active
105+
connection is tracked per-request using contextvars, which prevents
106+
concurrent HTTP requests from interfering with each other.
107+
108+
In stdio mode (one process per client), there's only one user and
109+
one context, so this behaves identically to a simple instance variable.
64110
"""
65111

66112
def __init__(self) -> None:
67113
self._connections: dict[str, ConnectionInfo] = {}
68-
self._active_connection: str | None = None
69114

70-
def _connection_key(self, hostname: str, catalog_id: str | int) -> str:
71-
"""Generate a unique key for a connection."""
72-
return f"{hostname}:{catalog_id}"
115+
@property
116+
def _active_connection(self) -> str | None:
117+
"""Get the active connection key for the current request context."""
118+
return _active_connection_var.get()
119+
120+
@_active_connection.setter
121+
def _active_connection(self, value: str | None) -> None:
122+
"""Set the active connection key for the current request context."""
123+
_active_connection_var.set(value)
124+
125+
def _connection_key(self, hostname: str, catalog_id: str | int, user_id: str = "") -> str:
126+
"""Generate a unique key for a connection.
127+
128+
Includes user_id so two users on the same catalog get separate entries.
129+
"""
130+
return f"{user_id}:{hostname}:{catalog_id}"
73131

74132
def _ensure_mcp_workflow_type(self, ml: DerivaML) -> None:
75133
"""Ensure the 'DerivaML MCP' workflow type exists in the catalog.
@@ -164,15 +222,6 @@ def connect(
164222
Raises:
165223
DerivaMLException: If connection fails.
166224
"""
167-
key = self._connection_key(hostname, catalog_id)
168-
169-
# Return existing connection if available
170-
if key in self._connections:
171-
if set_active:
172-
self._active_connection = key
173-
logger.info(f"Reusing existing connection to {key}")
174-
return self._connections[key].ml_instance
175-
176225
# Create new connection
177226
logger.info(f"Connecting to {hostname}, catalog {catalog_id}")
178227
try:
@@ -183,6 +232,17 @@ def connect(
183232
check_auth=True,
184233
)
185234

235+
# Derive user identity from the connection's credentials
236+
user_id = derive_user_id(ml.credential)
237+
key = self._connection_key(hostname, catalog_id, user_id)
238+
239+
# Return existing connection if available for this user
240+
if key in self._connections:
241+
if set_active:
242+
self._active_connection = key
243+
logger.info(f"Reusing existing connection to {key}")
244+
return self._connections[key].ml_instance
245+
186246
# Create MCP workflow and execution for tracking operations
187247
workflow_rid, execution = self._create_mcp_execution(ml)
188248

@@ -191,15 +251,18 @@ def connect(
191251
catalog_id=catalog_id,
192252
domain_schemas=domain_schemas,
193253
ml_instance=ml,
254+
user_id=user_id,
194255
workflow_rid=workflow_rid,
195256
execution=execution,
196257
)
197258
if set_active:
198259
self._active_connection = key
199260
logger.info(f"Successfully connected to {key}")
200261
return ml
262+
except DerivaMLException:
263+
raise
201264
except Exception as e:
202-
logger.error(f"Failed to connect to {key}: {e}")
265+
logger.error(f"Failed to connect to {hostname}:{catalog_id}: {e}")
203266
raise DerivaMLException(f"Failed to connect to {hostname}:{catalog_id}: {e}")
204267

205268
def disconnect(
@@ -224,7 +287,14 @@ def disconnect(
224287
if hostname is None and catalog_id is None:
225288
key = self._active_connection
226289
else:
227-
key = self._connection_key(hostname or "", catalog_id or "")
290+
# Find the connection key for this user+host+catalog
291+
conn_info = self._find_connection(hostname or "", catalog_id or "")
292+
key = None
293+
if conn_info:
294+
for k, v in self._connections.items():
295+
if v is conn_info:
296+
key = k
297+
break
228298

229299
if key and key in self._connections:
230300
conn_info = self._connections[key]
@@ -252,6 +322,25 @@ def disconnect(
252322
return True
253323
return False
254324

325+
def _find_connection(self, hostname: str, catalog_id: str | int) -> ConnectionInfo | None:
    """Find a connection by hostname and catalog_id when the user_id is not known.

    Used by disconnect, get_connection and set_active, which receive only
    a hostname and catalog id. Candidates are tried in this order:

    1. The active connection of the current context, if it matches.
    2. Any other connection owned by the same user as the active connection.
    3. Any matching connection, regardless of owner.

    Catalog ids are compared as strings so ``1`` and ``"1"`` are equivalent.

    Args:
        hostname: Hostname of the catalog server.
        catalog_id: Catalog identifier (str or int).

    Returns:
        The matching ConnectionInfo, or None if there is no match.
    """
    wanted_catalog = str(catalog_id)

    def _matches(info: ConnectionInfo) -> bool:
        return info.hostname == hostname and str(info.catalog_id) == wanted_catalog

    active_key = self._active_connection
    active_info = self._connections.get(active_key) if active_key else None

    if active_info is not None:
        if _matches(active_info):
            return active_info
        # Prefer the active user's own connections before considering anyone
        # else's, so a lookup for a non-active catalog does not land on
        # another user's ConnectionInfo when this user has one.
        for info in self._connections.values():
            if info.user_id == active_info.user_id and _matches(info):
                return info

    # NOTE(review): this last-resort fallback can return a connection owned
    # by a different user. It is kept for backward compatibility with callers
    # that have no active connection in the current context, but it weakens
    # multi-user isolation; consider requiring an explicit user_id here.
    for info in self._connections.values():
        if _matches(info):
            return info
    return None
343+
255344
def get_active(self) -> DerivaML | None:
256345
"""Get the active DerivaML instance.
257346
@@ -314,6 +403,22 @@ def get_active_connection_info(self) -> ConnectionInfo | None:
314403
return self._connections[self._active_connection]
315404
return None
316405

406+
def get_active_connection_info_or_raise(self) -> ConnectionInfo:
407+
"""Get the active connection info or raise an error.
408+
409+
Returns:
410+
ConnectionInfo for the active connection.
411+
412+
Raises:
413+
DerivaMLException: If no active connection.
414+
"""
415+
info = self.get_active_connection_info()
416+
if info is None:
417+
raise DerivaMLException(
418+
"No active catalog connection. Use 'connect' tool to connect to a catalog first."
419+
)
420+
return info
421+
317422
def get_connection(self, hostname: str, catalog_id: str | int) -> DerivaML | None:
318423
"""Get a specific connection.
319424
@@ -324,10 +429,8 @@ def get_connection(self, hostname: str, catalog_id: str | int) -> DerivaML | Non
324429
Returns:
325430
DerivaML instance or None if not connected.
326431
"""
327-
key = self._connection_key(hostname, catalog_id)
328-
if key in self._connections:
329-
return self._connections[key].ml_instance
330-
return None
432+
conn = self._find_connection(hostname, catalog_id)
433+
return conn.ml_instance if conn else None
331434

332435
def list_connections(self) -> list[dict[str, Any]]:
333436
"""List all active connections.
@@ -357,8 +460,10 @@ def set_active(self, hostname: str, catalog_id: str | int) -> bool:
357460
Returns:
358461
True if connection exists and was set as active.
359462
"""
360-
key = self._connection_key(hostname, catalog_id)
361-
if key in self._connections:
362-
self._active_connection = key
363-
return True
463+
conn = self._find_connection(hostname, catalog_id)
464+
if conn:
465+
for key, info in self._connections.items():
466+
if info is conn:
467+
self._active_connection = key
468+
return True
364469
return False

src/deriva_ml_mcp/tools/background_tasks.py

Lines changed: 24 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
from deriva.core import get_credential
3636

37+
from deriva_ml_mcp.connection import derive_user_id
3738
from deriva_ml_mcp.tasks import (
3839
TaskProgress,
3940
TaskStatus,
@@ -69,63 +70,26 @@ def _resolve_hostname(hostname: str | None) -> str | None:
6970
return hostname
7071

7172

72-
# Cache for user ID to ensure consistency within a session
73-
_cached_user_id: str | None = None
73+
def _get_user_id_from_credential(hostname: str | None = None) -> str:
74+
"""Derive a user identifier from credentials for the given hostname.
7475
75-
76-
def _get_user_id(hostname: str | None = None) -> str:
77-
"""Get a user identifier from credentials.
78-
79-
For multi-user isolation, we use the credential identity as user_id.
80-
If no credentials available, fall back to a default (single-user mode).
81-
82-
IMPORTANT: This function now caches the user ID to ensure consistency
83-
between task creation and task lookup. The first call with a hostname
84-
sets the cached value which is then used for all subsequent calls.
76+
Stateless — derives fresh each call. For multi-user isolation,
77+
prefer using conn_manager.get_active_connection_info().user_id
78+
when an active connection is available.
8579
8680
Args:
8781
hostname: Optional hostname to get credentials for.
8882
8983
Returns:
9084
A string identifying the user.
9185
"""
92-
global _cached_user_id
93-
94-
# If we have a cached user ID, always use it for consistency
95-
if _cached_user_id is not None:
96-
return _cached_user_id
97-
9886
try:
9987
if hostname:
10088
cred = get_credential(hostname)
101-
if cred and "cookie" in cred:
102-
# Extract webauthn from cookie for user identity
103-
cookie = cred.get("cookie", "")
104-
if "webauthn=" in cookie:
105-
# Use a hash of the webauthn value for privacy
106-
import hashlib
107-
108-
webauthn = cookie.split("webauthn=")[1].split(";")[0]
109-
user_id = hashlib.sha256(webauthn.encode()).hexdigest()[:16]
110-
_cached_user_id = user_id
111-
logger.debug(f"Cached user ID from credentials: {user_id[:8]}...")
112-
return user_id
113-
# Fall back to checking any available credential
114-
# In single-user mode, this is fine
115-
_cached_user_id = "default_user"
116-
logger.debug("Using default_user for single-user mode")
117-
return "default_user"
89+
return derive_user_id(cred)
11890
except Exception:
119-
_cached_user_id = "default_user"
120-
return "default_user"
121-
122-
123-
async def _get_user_id_async(hostname: str | None = None) -> str:
124-
"""Async version of _get_user_id.
125-
126-
Runs credential lookup in a thread to avoid blocking the event loop.
127-
"""
128-
return await asyncio.to_thread(_get_user_id, hostname)
91+
pass
92+
return "default_user"
12993

13094

13195
def _clone_catalog_task(
@@ -370,8 +334,12 @@ async def clone_catalog_async(
370334
"""
371335
try:
372336
task_manager = get_task_manager()
373-
# Use async credential lookup to avoid blocking
374-
user_id = await _get_user_id_async(source_hostname)
337+
# Get user_id from active connection if available, otherwise from credentials
338+
conn_info = conn_manager.get_active_connection_info()
339+
if conn_info:
340+
user_id = conn_info.user_id
341+
else:
342+
user_id = await asyncio.to_thread(_get_user_id_from_credential, source_hostname)
375343

376344
# Store parameters for the task
377345
parameters = {
@@ -455,8 +423,9 @@ async def get_task_status(
455423
"""
456424
try:
457425
task_manager = get_task_manager()
458-
# Use cached user_id for consistency with task creation
459-
user_id = await _get_user_id_async()
426+
# Get user_id from active connection for consistent task lookup
427+
conn_info = conn_manager.get_active_connection_info()
428+
user_id = conn_info.user_id if conn_info else "default_user"
460429

461430
# Use async snapshot method to avoid blocking event loop and minimize lock contention
462431
task_snapshot = await task_manager.get_task_snapshot_async(task_id, user_id)
@@ -503,8 +472,9 @@ async def list_tasks(
503472
"""
504473
try:
505474
task_manager = get_task_manager()
506-
# Use cached user_id for consistency
507-
user_id = await _get_user_id_async()
475+
# Get user_id from active connection
476+
conn_info = conn_manager.get_active_connection_info()
477+
user_id = conn_info.user_id if conn_info else "default_user"
508478

509479
# Parse filters
510480
status_filter = TaskStatus(status) if status else None
@@ -551,8 +521,9 @@ async def cancel_task(task_id: str) -> str:
551521
"""
552522
try:
553523
task_manager = get_task_manager()
554-
# Use cached user_id for consistency
555-
user_id = await _get_user_id_async()
524+
# Get user_id from active connection
525+
conn_info = conn_manager.get_active_connection_info()
526+
user_id = conn_info.user_id if conn_info else "default_user"
556527

557528
# Run cancel in thread to avoid blocking
558529
cancelled = await asyncio.to_thread(task_manager.cancel_task, task_id, user_id)

0 commit comments

Comments
 (0)