Skip to content

Commit 0eea628

Browse files
saqadriroman-van-der-krogt
authored andcommitted
hotfix: OAuth token caching by user identity instead of by MCPApp session ID (lastmile-ai#570)
* Fix workflow resume issue * temp wip * Fix workflow resume issue * temp wip * add audience validation * add comfigured token support; add workflow_pre_auth; oauth example against github * working oauth example with workflow_pre_auth * fixes to oauth discovery; add dynamic oath example * full e2e workflow * improve how we're dealing with no oauth user * all tests passing * reformat * rework oauth flow * Update readme * Fix user conflation on temporal; start of moving cache to app side * cache tokens requested from temporal flow in app * Implement local loopback OAuth callback server for MCPApp client-only runs * Tests and more fixes for loopback, including browser launch * various fixes * additional fixes suggested by cursor * merge and format * Various updates to get the "interactive_tool" example working * All examples working * Regenerate schema and add docstrings * Remove TESTING_GUIDE.md * Fixes to make sure oauth identities are properly isolated across multiple users * remove * update org name --------- Co-authored-by: Roman van der Krogt <[email protected]>
1 parent a8cddc4 commit 0eea628

File tree

7 files changed

+224
-54
lines changed

7 files changed

+224
-54
lines changed

examples/oauth/interactive_tool/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ async def main() -> None:
8686
print("[client] Invoking github_org_search...")
8787
result = await connection.call_tool(
8888
"github_org_search",
89-
{"query": "lastmileai"},
89+
{"query": "lastmile-ai"},
9090
)
9191
print("[client] Result:")
9292
for item in result.content or []:

examples/oauth/interactive_tool/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959

6060
settings = Settings(
6161
execution_engine="asyncio",
62-
logger=LoggerSettings(level="info"),
62+
logger=LoggerSettings(level="debug"),
6363
oauth=OAuthSettings(
6464
callback_base_url=AnyHttpUrl("http://localhost:8000"),
6565
flow_timeout_seconds=300,

src/mcp_agent/core/context.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44

55
import asyncio
66
import concurrent.futures
7-
from typing import Any, List, Optional, TYPE_CHECKING, Literal
7+
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Literal
88
import warnings
99

10-
from pydantic import ConfigDict
10+
from pydantic import ConfigDict, Field
1111

1212
from mcp import ServerSession
1313
from mcp.server.fastmcp import FastMCP
@@ -34,6 +34,7 @@
3434
from mcp_agent.workflows.llm.llm_selector import ModelSelector
3535
from mcp_agent.logging.logger import get_logger
3636
from mcp_agent.tracing.token_counter import TokenCounter
37+
from mcp_agent.oauth.identity import OAuthUserIdentity
3738

3839

3940
if TYPE_CHECKING:
@@ -103,6 +104,7 @@ class Context(MCPContext):
103104
# OAuth helpers for downstream servers
104105
token_store: Optional[TokenStore] = None
105106
token_manager: Optional[TokenManager] = None
107+
identity_registry: Dict[str, OAuthUserIdentity] = Field(default_factory=dict)
106108

107109
model_config = ConfigDict(
108110
extra="allow",

src/mcp_agent/logging/logger.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,7 @@ class LoggingConfig:
370370

371371
_initialized: bool = False
372372
_event_filter_ref: EventFilter | None = None
373+
_upstream_event_filter_ref: EventFilter | None = None
373374

374375
@classmethod
375376
async def configure(
@@ -392,8 +393,11 @@ async def configure(
392393
"""
393394
bus = AsyncEventBus.get(transport=transport)
394395
# Keep a reference to the provided filter so we can update at runtime
395-
if event_filter is not None:
396-
cls._event_filter_ref = event_filter
396+
if event_filter is None:
397+
event_filter = EventFilter()
398+
399+
cls._event_filter_ref = event_filter
400+
cls._upstream_event_filter_ref = event_filter.model_copy(deep=True)
397401

398402
# If already initialized, ensure critical listeners exist and return
399403
if cls._initialized:
@@ -411,7 +415,9 @@ async def configure(
411415
MCP_UPSTREAM_LISTENER_NAME: _Final[str] = "mcp_upstream"
412416
bus.add_listener(
413417
MCP_UPSTREAM_LISTENER_NAME,
414-
MCPUpstreamLoggingListener(event_filter=cls._event_filter_ref),
418+
MCPUpstreamLoggingListener(
419+
event_filter=cls._upstream_event_filter_ref
420+
),
415421
)
416422
except Exception:
417423
pass
@@ -451,7 +457,9 @@ async def configure(
451457
MCP_UPSTREAM_LISTENER_NAME: Final[str] = "mcp_upstream"
452458
bus.add_listener(
453459
MCP_UPSTREAM_LISTENER_NAME,
454-
MCPUpstreamLoggingListener(event_filter=event_filter),
460+
MCPUpstreamLoggingListener(
461+
event_filter=cls._upstream_event_filter_ref
462+
),
455463
)
456464
except Exception:
457465
# Non-fatal if import fails
@@ -472,7 +480,7 @@ async def shutdown(cls):
472480
@classmethod
473481
def set_min_level(cls, level: EventType | str) -> None:
474482
"""Update the minimum logging level on the shared event filter, if available."""
475-
if cls._event_filter_ref is None:
483+
if cls._upstream_event_filter_ref is None:
476484
return
477485
# Normalize level
478486
normalized = str(level).lower()
@@ -488,7 +496,7 @@ def set_min_level(cls, level: EventType | str) -> None:
488496
"alert": "error",
489497
"emergency": "error",
490498
}
491-
cls._event_filter_ref.min_level = mapping.get(normalized, "info")
499+
cls._upstream_event_filter_ref.min_level = mapping.get(normalized, "info")
492500

493501
@classmethod
494502
def get_event_filter(cls) -> EventFilter | None:

src/mcp_agent/oauth/manager.py

Lines changed: 65 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,13 @@ async def get_access_token_if_present(
342342
self._default_identity,
343343
]
344344
identities = _dedupe(identity_candidates)
345+
logger.debug(
346+
"Resolved identity candidates for token acquisition",
347+
data={
348+
"server": server_name,
349+
"candidates": [candidate.cache_key for candidate in identities],
350+
},
351+
)
345352
if not identities:
346353
raise MissingUserIdentityError(
347354
"No authenticated user available for OAuth authorization"
@@ -364,6 +371,14 @@ async def get_access_token_if_present(
364371
async with lock:
365372
record = await self._token_store.get(key)
366373
if record and not record.is_expired(leeway_seconds=leeway):
374+
logger.debug(
375+
"Token cache hit",
376+
data={
377+
"server": server_name,
378+
"identity": identity.cache_key,
379+
"resource": resolved.resource,
380+
},
381+
)
367382
return record
368383

369384
if record and record.refresh_token:
@@ -546,6 +561,14 @@ async def ensure_access_token(
546561
)
547562

548563
await self._token_store.set(user_key, record)
564+
logger.debug(
565+
"Stored new access token via authorization flow",
566+
data={
567+
"server": server_name,
568+
"identity": flow_identity.cache_key,
569+
"resource": resolved.resource,
570+
},
571+
)
549572
return record
550573

551574
async def invalidate(
@@ -780,25 +803,56 @@ def _session_identity(self, context: "Context") -> OAuthUserIdentity | None:
780803
except Exception:
781804
in_temporal = False
782805

783-
session_id = None
806+
# Temporal workflows/activities carry their own execution identity.
784807
if in_temporal:
785-
# Base the identity on the Temporal workflow execution ID
786808
try:
787809
from mcp_agent.executor.temporal.temporal_context import (
788810
get_execution_id as _get_exec_id,
789811
)
812+
from mcp_agent.server import app_server
790813

791-
session_id = _get_exec_id()
814+
execution_id = _get_exec_id()
815+
if execution_id:
816+
identity = app_server._get_identity_for_execution(execution_id)
817+
if identity is not None:
818+
return identity
792819
except Exception:
793820
pass
794821

795-
if not session_id:
822+
session_id = None
823+
if in_temporal:
796824
session_id = getattr(context, "session_id", None)
797-
if not session_id:
798-
app = getattr(context, "app", None)
799-
if app is not None:
800-
session_id = getattr(app, "_session_id_override", None)
801825

802-
if not session_id:
803-
return None
804-
return OAuthUserIdentity(provider="mcp-session", subject=str(session_id))
826+
if session_id:
827+
try:
828+
from mcp_agent.server import app_server
829+
830+
identity = app_server.get_identity_for_session(session_id, context)
831+
if identity is not None:
832+
logger.debug(
833+
"Resolved session identity from registry",
834+
data={
835+
"session_id": session_id,
836+
"identity": identity.cache_key,
837+
},
838+
)
839+
return identity
840+
except Exception as exc:
841+
logger.debug(
842+
"Failed to resolve session identity from registry",
843+
data={"session_id": session_id, "error": repr(exc)},
844+
)
845+
fallback = OAuthUserIdentity(
846+
provider="mcp-session", subject=str(session_id)
847+
)
848+
logger.debug(
849+
"Falling back to synthetic session identity",
850+
data={"session_id": session_id, "identity": fallback.cache_key},
851+
)
852+
return fallback
853+
854+
logger.debug(
855+
"TokenManager no session identity resolved",
856+
data={"context_session_id": getattr(context, "session_id", None)},
857+
)
858+
return None

src/mcp_agent/server/app_server.py

Lines changed: 56 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -147,24 +147,54 @@ def _resolve_identity_for_request(
147147
identity = _CURRENT_IDENTITY.get()
148148
if identity is None and execution_id:
149149
identity = _get_identity_for_execution(execution_id)
150-
if identity is None and app_context is not None:
151-
session_id = getattr(app_context, "session_id", None)
152-
identity = _session_identity_from_value(session_id)
153-
if identity is None and ctx is not None:
154-
session_id = _extract_session_id_from_context(ctx)
155-
identity = _session_identity_from_value(session_id)
156-
if identity is None and app_context is None and ctx is not None:
150+
request_session_id: str | None = None
151+
if ctx is not None:
152+
request_session_id = _extract_session_id_from_context(ctx)
153+
if app_context is None and ctx is not None:
157154
app = _get_attached_app(ctx.fastmcp)
158155
if app is not None and getattr(app, "context", None) is not None:
159156
app_context = app.context
157+
if identity is None and request_session_id:
158+
resolved = get_identity_for_session(request_session_id, app_context)
159+
if resolved:
160+
logger.debug(
161+
"Resolved identity from session registry",
162+
data={
163+
"session_id": request_session_id,
164+
"identity": resolved.cache_key,
165+
},
166+
)
167+
identity = resolved
160168
if identity is None and app_context is not None:
161169
session_id = getattr(app_context, "session_id", None)
162-
identity = _session_identity_from_value(session_id)
170+
if session_id and session_id != request_session_id:
171+
identity = get_identity_for_session(session_id, app_context)
163172
if identity is None:
164173
identity = DEFAULT_PRECONFIGURED_IDENTITY
165174
return identity
166175

167176

177+
def get_identity_for_session(
178+
session_id: str | None, app_context: "Context" | None = None
179+
) -> OAuthUserIdentity | None:
180+
"""Lookup the cached identity for a given MCP session."""
181+
if not session_id:
182+
return None
183+
if app_context is not None:
184+
try:
185+
identity = app_context.identity_registry.get(session_id)
186+
if identity is not None:
187+
return identity
188+
except Exception:
189+
pass
190+
else:
191+
logger.debug(
192+
"No app context provided when resolving session identity",
193+
data={"session_id": session_id},
194+
)
195+
return _session_identity_from_value(session_id)
196+
197+
168198
class ServerContext(ContextDependent):
169199
"""Context object for the MCP App server."""
170200

@@ -250,14 +280,11 @@ def _set_upstream_from_request_ctx_if_available(ctx: MCPContext) -> None:
250280
# First, try to use the session property from the FastMCP Context
251281
session = None
252282
try:
253-
session = (
254-
ctx.session
255-
) # This accesses the property which returns ctx.request_context.session
283+
session = ctx.session
256284
except (AttributeError, ValueError):
257285
# ctx.session property might raise ValueError if context not available
258286
pass
259287

260-
# Capture authenticated user information if available
261288
session_id = _extract_session_id_from_context(ctx)
262289
identity: OAuthUserIdentity | None = None
263290
try:
@@ -268,7 +295,6 @@ def _set_upstream_from_request_ctx_if_available(ctx: MCPContext) -> None:
268295
if isinstance(auth_user, AuthenticatedUser):
269296
access_token = getattr(auth_user, "access_token", None)
270297
if access_token is not None:
271-
# Prefer enriched token instances but fall back to raw data if necessary
272298
try:
273299
from mcp_agent.oauth.access_token import MCPAccessToken
274300

@@ -277,38 +303,35 @@ def _set_upstream_from_request_ctx_if_available(ctx: MCPContext) -> None:
277303
else:
278304
token_dict = getattr(access_token, "model_dump", None)
279305
if callable(token_dict):
280-
maybe_token = MCPAccessToken.model_validate(
281-
access_token.model_dump()
282-
)
283-
identity = OAuthUserIdentity.from_access_token(maybe_token)
306+
maybe_token = MCPAccessToken.model_validate(token_dict())
307+
if maybe_token is not None:
308+
identity = OAuthUserIdentity.from_access_token(maybe_token)
284309
except Exception:
285310
identity = None
286311

312+
app_context: "Context" | None = None
287313
app: MCPApp | None = _get_attached_app(ctx.fastmcp)
288314
if app is not None and getattr(app, "context", None) is not None:
315+
app_context = app.context
289316
if session is not None:
290-
app.context.upstream_session = session
291-
if session_id and not getattr(app.context, "session_id", None):
292-
app.context.session_id = session_id
293-
294-
app_session_id = getattr(app.context, "session_id", None)
295-
else:
296-
app_session_id = None
297-
298-
if app_session_id:
299-
app_identity = _session_identity_from_value(app_session_id)
300-
if identity is None or (
301-
isinstance(identity, OAuthUserIdentity)
302-
and identity.provider == "mcp-session"
303-
):
304-
identity = app_identity
317+
app_context.upstream_session = session
305318

306-
if identity is None:
319+
if identity is None and session_id:
307320
identity = _session_identity_from_value(session_id)
308321

309322
if identity is None:
310323
identity = DEFAULT_PRECONFIGURED_IDENTITY
311324

325+
if app_context is not None and session_id and identity is not None:
326+
try:
327+
app_context.identity_registry[session_id] = identity
328+
logger.debug(
329+
"Registered identity for session",
330+
data={"session_id": session_id, "identity": identity.cache_key},
331+
)
332+
except Exception:
333+
pass
334+
312335
_set_current_identity(identity)
313336

314337

0 commit comments

Comments
 (0)