Skip to content
Open
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 188 additions & 6 deletions src/mcp_agent/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@

import asyncio
import concurrent.futures
from typing import Any, List, Optional, TYPE_CHECKING
from typing import Any, List, Optional, TYPE_CHECKING, Literal
import warnings

from pydantic import BaseModel, ConfigDict
from pydantic import ConfigDict

from mcp import ServerSession
from mcp.server.fastmcp import FastMCP
from mcp.server.fastmcp import Context as MCPContext

from opentelemetry import trace

Expand All @@ -37,11 +38,12 @@

if TYPE_CHECKING:
from mcp_agent.agents.agent_spec import AgentSpec
from mcp_agent.human_input.types import HumanInputCallback
from mcp_agent.app import MCPApp
from mcp_agent.elicitation.types import ElicitationCallback
from mcp_agent.executor.workflow_signal import SignalWaitCallback
from mcp_agent.executor.workflow_registry import WorkflowRegistry
from mcp_agent.app import MCPApp
from mcp_agent.human_input.types import HumanInputCallback
from mcp_agent.logging.logger import Logger
else:
# Runtime placeholders for the types
AgentSpec = Any
Expand All @@ -50,11 +52,12 @@
SignalWaitCallback = Any
WorkflowRegistry = Any
MCPApp = Any
Logger = Any

logger = get_logger(__name__)


class Context(BaseModel):
class Context(MCPContext):
"""
Context that is passed around through the application.
This is a global context that is shared across the application.
Expand All @@ -65,7 +68,7 @@ class Context(BaseModel):
human_input_handler: Optional[HumanInputCallback] = None
elicitation_handler: Optional[ElicitationCallback] = None
signal_notification: Optional[SignalWaitCallback] = None
upstream_session: Optional[ServerSession] = None # TODO: saqadri - figure this out
upstream_session: Optional[ServerSession] = None
model_selector: Optional[ModelSelector] = None
session_id: str | None = None
app: Optional["MCPApp"] = None
Expand Down Expand Up @@ -102,6 +105,185 @@ class Context(BaseModel):
def mcp(self) -> FastMCP | None:
return self.app.mcp if self.app else None

@property
def fastmcp(self) -> FastMCP | None: # type: ignore[override]
"""Return the FastMCP instance if available.

Prefer the active request-bound FastMCP instance if present; otherwise
fall back to the app's configured FastMCP server. Returns None if neither
is available. This is more forgiving than the FastMCP Context default,
which raises outside of a request.
"""
try:
# Prefer a request-bound fastmcp if set by FastMCP during a request
if getattr(self, "_fastmcp", None) is not None:
return getattr(self, "_fastmcp", None)
except Exception:
pass
# Fall back to app-managed server instance (may be None in local scripts)
return self.mcp

@property
def session(self) -> ServerSession | None:
"""Best-effort ServerSession for upstream communication.

Priority:
- If explicitly provided, use `upstream_session`.
- If running within an active FastMCP request, use parent session.
- If an app FastMCP exists, use its current request context if any.

Returns None when no session can be resolved (e.g., local scripts).
"""
# 1) Explicit upstream session set by app/workflow
if getattr(self, "upstream_session", None) is not None:
return self.upstream_session

# 2) Try request-scoped session from FastMCP Context (may raise outside requests)
try:
return super().session # type: ignore[misc]
except Exception:
pass

# 3) Fall back to FastMCP server's current context if available
try:
mcp = self.mcp
if mcp is not None:
ctx = mcp.get_context()
# FastMCP.get_context returns a Context that raises outside a request;
# guard accordingly.
try:
return getattr(ctx, "session", None)
except Exception:
return None
except Exception:
pass

# No session available in this runtime mode
return None

@property
def logger(self) -> "Logger":
return self.app.logger if self.app else None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logger property's return type annotation is Logger, but it returns None when self.app is None. This type mismatch could lead to runtime errors when calling methods on what's expected to be a Logger object. Consider either:

  1. Updating the return type annotation to Logger | None to accurately reflect the possible return values, or
  2. Providing a fallback logger implementation when self.app is None

This would help prevent potential NoneType errors during execution.

Suggested change
def logger(self) -> "Logger":
return self.app.logger if self.app else None
def logger(self) -> "Logger | None":
return self.app.logger if self.app else None

Spotted by Diamond

Fix in Graphite


Is this helpful? React 👍 or 👎 to let us know.


# ---- FastMCP Context method fallbacks (safe outside requests) ---------

def bind_request(
self, request_context: Any, fastmcp: FastMCP | None = None
) -> "Context":
"""Return a shallow-copied Context bound to a specific FastMCP request.

- Shares app-wide state (config, registries, token counter, etc.) with the original Context
- Attaches `_request_context` and `_fastmcp` so FastMCP Context APIs work during the request
- Does not mutate the original Context (safe for concurrent requests)
"""
# Shallow copy to preserve references to registries/loggers while keeping isolation
bound: Context = self.model_copy(deep=False)
try:
setattr(bound, "_request_context", request_context)
except Exception:
pass
try:
if fastmcp is None:
fastmcp = getattr(self, "_fastmcp", None) or self.mcp
setattr(bound, "_fastmcp", fastmcp)
except Exception:
pass
return bound

@property
def client_id(self) -> str | None: # type: ignore[override]
try:
return super().client_id # type: ignore[misc]
except Exception:
return None

@property
def request_id(self) -> str: # type: ignore[override]
try:
return super().request_id # type: ignore[misc]
except Exception:
# Provide a stable-ish fallback based on app session if available
try:
return str(self.session_id) if getattr(self, "session_id", None) else ""
except Exception:
return ""

async def log(
self,
level: "Literal['debug', 'info', 'warning', 'error']",
message: str,
*,
logger_name: str | None = None,
) -> None: # type: ignore[override]
"""Send a log to the client if possible; otherwise, log locally.

Matches FastMCP Context API but avoids raising when no request context
is active by falling back to the app's logger.
"""
# If we have a live FastMCP request context, delegate to parent
try:
# will raise if request_context is not available
_ = self.request_context # type: ignore[attr-defined]
return await super().log(level, message, logger_name=logger_name) # type: ignore[misc]
except Exception:
pass

# Fall back to local logger if available
try:
_logger = self.logger
if _logger is not None:
if level == "debug":
_logger.debug(message)
elif level == "warning":
_logger.warning(message)
elif level == "error":
_logger.error(message)
else:
_logger.info(message)
except Exception:
# Swallow errors in fallback logging to avoid masking tool behavior
pass

async def report_progress(
self, progress: float, total: float | None = None, message: str | None = None
) -> None: # type: ignore[override]
"""Report progress to the client if a request is active.

Outside of a request (e.g., local scripts), this is a no-op to avoid
runtime errors as no progressToken exists.
"""
try:
_ = self.request_context # type: ignore[attr-defined]
return await super().report_progress(progress, total, message) # type: ignore[misc]
except Exception:
# No-op when no active request context
return None

async def read_resource(self, uri: Any) -> Any: # type: ignore[override]
"""Read a resource via FastMCP if possible; otherwise raise clearly.

This provides a friendlier error outside of a request and supports
fallback to the app's FastMCP instance if available.
"""
# Use the parent implementation if request-bound fastmcp is available
try:
if getattr(self, "_fastmcp", None) is not None:
return await super().read_resource(uri) # type: ignore[misc]
except Exception:
pass

# Fall back to app-managed FastMCP if present
try:
mcp = self.mcp
if mcp is not None:
return await mcp.read_resource(uri) # type: ignore[no-any-return]
except Exception:
pass

raise ValueError(
"read_resource is only available when an MCP server is active."
)


async def configure_otel(
config: "Settings", session_id: str | None = None
Expand Down
Loading