diff --git a/src/mcp_agent/app.py b/src/mcp_agent/app.py index 2ada0483b..e60e9d2a7 100644 --- a/src/mcp_agent/app.py +++ b/src/mcp_agent/app.py @@ -586,6 +586,7 @@ def _create_workflow_from_function( async def _invoke_target(workflow_self, *args, **kwargs): # Inject app_ctx (AppContext) and shim ctx (FastMCP Context) if requested by the function import inspect as _inspect + import typing as _typing call_kwargs = dict(kwargs) @@ -622,24 +623,51 @@ async def _invoke_target(workflow_self, *args, **kwargs): except Exception: pass - # If the function expects a FastMCP Context (ctx/context), ensure it's present (None inside workflow) + # If the function expects a FastMCP Context (ctx/context), ensure it's present. try: from mcp.server.fastmcp import Context as _Ctx # type: ignore except Exception: _Ctx = None # type: ignore + def _is_fast_ctx_annotation(annotation) -> bool: + if _Ctx is None or annotation is _inspect._empty: + return False + if annotation is _Ctx: + return True + try: + origin = _typing.get_origin(annotation) + if origin is not None: + return any( + _is_fast_ctx_annotation(arg) + for arg in _typing.get_args(annotation) + ) + except Exception: + pass + try: + return "fastmcp" in str(annotation) + except Exception: + return False + try: sig = sig if "sig" in locals() else _inspect.signature(fn) for p in sig.parameters.values(): - if ( - p.annotation is not _inspect._empty - and _Ctx is not None - and p.annotation is _Ctx + needs_fast_ctx = False + if _is_fast_ctx_annotation(p.annotation): + needs_fast_ctx = True + elif p.annotation is _inspect._empty and p.name in ( + "ctx", + "context", ): - if p.name not in call_kwargs: - call_kwargs[p.name] = None - if p.name in ("ctx", "context") and p.name not in call_kwargs: - call_kwargs[p.name] = None + needs_fast_ctx = True + if needs_fast_ctx and p.name not in call_kwargs: + fast_ctx = getattr(workflow_self, "_mcp_request_context", None) + if fast_ctx is None and app_context_param_name: + fast_ctx = getattr( + call_kwargs.get(app_context_param_name, None), + "fastmcp", + None, + ) + call_kwargs[p.name] = fast_ctx except Exception: pass diff --git a/src/mcp_agent/cli/cloud/commands/deploy/main.py b/src/mcp_agent/cli/cloud/commands/deploy/main.py index 5634e1618..c1ed69cf9 100644 --- a/src/mcp_agent/cli/cloud/commands/deploy/main.py +++ b/src/mcp_agent/cli/cloud/commands/deploy/main.py @@ -173,9 +173,7 @@ def deploy_config( if app_name is None: if default_app_name: - print_info( - f"Using app name from config.yaml: '{default_app_name}'" - ) + print_info(f"Using app name from config.yaml: '{default_app_name}'") app_name = default_app_name else: app_name = "default" @@ -205,7 +203,7 @@ def deploy_config( " • Or use the --api-key flag with your key", retriable=False, ) - + if settings.VERBOSE: print_info(f"Using API at {effective_api_url}") @@ -231,9 +229,7 @@ def deploy_config( print_info(f"New app id: `{app_id}`") else: short_id = f"{app_id[:8]}…" - print_success( - f"Found existing app '{app_name}' (ID: `{short_id}`)" - ) + print_success(f"Found existing app '{app_name}' (ID: `{short_id}`)") if not non_interactive: use_existing = typer.confirm( f"Deploy an update to '{app_name}' (ID: `{short_id}`)?", diff --git a/src/mcp_agent/cli/cloud/commands/deploy/wrangler_wrapper.py b/src/mcp_agent/cli/cloud/commands/deploy/wrangler_wrapper.py index 60df32da6..50eee106c 100644 --- a/src/mcp_agent/cli/cloud/commands/deploy/wrangler_wrapper.py +++ b/src/mcp_agent/cli/cloud/commands/deploy/wrangler_wrapper.py @@ -296,7 +296,9 @@ def ignore_patterns(path_str, names): ) meta_vars.update({"MCP_DEPLOY_WORKSPACE_HASH": bundle_hash}) if settings.VERBOSE: - print_info(f"Deploying from non-git workspace (hash {bundle_hash[:12]}…)") + print_info( + f"Deploying from non-git workspace (hash {bundle_hash[:12]}…)" + ) # Write a breadcrumb file into the project so it ships with the bundle. # Use a Python file for guaranteed inclusion without renaming. diff --git a/src/mcp_agent/core/context.py b/src/mcp_agent/core/context.py index d449c938a..57c26a927 100644 --- a/src/mcp_agent/core/context.py +++ b/src/mcp_agent/core/context.py @@ -4,13 +4,14 @@ import asyncio import concurrent.futures -from typing import Any, List, Optional, TYPE_CHECKING +from typing import Any, List, Optional, TYPE_CHECKING, Literal import warnings -from pydantic import BaseModel, ConfigDict +from pydantic import ConfigDict from mcp import ServerSession from mcp.server.fastmcp import FastMCP +from mcp.server.fastmcp import Context as MCPContext from opentelemetry import trace @@ -37,11 +38,12 @@ if TYPE_CHECKING: from mcp_agent.agents.agent_spec import AgentSpec - from mcp_agent.human_input.types import HumanInputCallback + from mcp_agent.app import MCPApp from mcp_agent.elicitation.types import ElicitationCallback from mcp_agent.executor.workflow_signal import SignalWaitCallback from mcp_agent.executor.workflow_registry import WorkflowRegistry - from mcp_agent.app import MCPApp + from mcp_agent.human_input.types import HumanInputCallback + from mcp_agent.logging.logger import Logger else: # Runtime placeholders for the types AgentSpec = Any @@ -50,11 +52,12 @@ SignalWaitCallback = Any WorkflowRegistry = Any MCPApp = Any + Logger = Any logger = get_logger(__name__) -class Context(BaseModel): +class Context(MCPContext): """ Context that is passed around through the application. This is a global context that is shared across the application. @@ -65,7 +68,7 @@ class Context(BaseModel): human_input_handler: Optional[HumanInputCallback] = None elicitation_handler: Optional[ElicitationCallback] = None signal_notification: Optional[SignalWaitCallback] = None - upstream_session: Optional[ServerSession] = None # TODO: saqadri - figure this out + upstream_session: Optional[ServerSession] = None model_selector: Optional[ModelSelector] = None session_id: str | None = None app: Optional["MCPApp"] = None @@ -102,6 +105,213 @@ class Context(BaseModel): def mcp(self) -> FastMCP | None: return self.app.mcp if self.app else None + @property + def fastmcp(self) -> FastMCP | None: # type: ignore[override] + """Return the FastMCP instance if available. + + Prefer the active request-bound FastMCP instance if present; otherwise + fall back to the app's configured FastMCP server. Returns None if neither + is available. This is more forgiving than the FastMCP Context default, + which raises outside of a request. + """ + try: + # Prefer a request-bound fastmcp if set by FastMCP during a request + if getattr(self, "_fastmcp", None) is not None: + return getattr(self, "_fastmcp", None) + except Exception: + pass + # Fall back to app-managed server instance (may be None in local scripts) + return self.mcp + + @property + def session(self) -> ServerSession | None: + """Best-effort ServerSession for upstream communication. + + Priority: + - If explicitly provided, use `upstream_session`. + - If running within an active FastMCP request, use parent session. + - If an app FastMCP exists, use its current request context if any. + + Returns None when no session can be resolved (e.g., local scripts). + """ + # 1) Explicit upstream session set by app/workflow + if getattr(self, "upstream_session", None) is not None: + return self.upstream_session + + # 2) Try request-scoped session from FastMCP Context (may raise outside requests) + try: + return super().session # type: ignore[misc] + except Exception: + pass + + # 3) Fall back to FastMCP server's current context if available + try: + mcp = self.mcp + if mcp is not None: + ctx = mcp.get_context() + # FastMCP.get_context returns a Context that raises outside a request; + # guard accordingly. + try: + return getattr(ctx, "session", None) + except Exception: + return None + except Exception: + pass + + # No session available in this runtime mode + return None + + @property + def logger(self) -> "Logger": + if self.app: + return self.app.logger + namespace_components = ["mcp_agent", "context"] + try: + if getattr(self, "session_id", None): + namespace_components.append(str(self.session_id)) + except Exception: + pass + namespace = ".".join(namespace_components) + logger = get_logger( + namespace, session_id=getattr(self, "session_id", None), context=self + ) + try: + setattr(logger, "_bound_context", self) + except Exception: + pass + return logger + + @property + def name(self) -> str | None: + if self.app and getattr(self.app, "name", None): + return self.app.name + return None + + @property + def description(self) -> str | None: + if self.app and getattr(self.app, "description", None): + return self.app.description + return None + + # ---- FastMCP Context method fallbacks (safe outside requests) --------- + + def bind_request( + self, request_context: Any, fastmcp: FastMCP | None = None + ) -> "Context": + """Return a shallow-copied Context bound to a specific FastMCP request. + + - Shares app-wide state (config, registries, token counter, etc.) with the original Context + - Attaches `_request_context` and `_fastmcp` so FastMCP Context APIs work during the request + - Does not mutate the original Context (safe for concurrent requests) + """ + # Shallow copy to preserve references to registries/loggers while keeping isolation + bound: Context = self.model_copy(deep=False) + try: + setattr(bound, "_request_context", request_context) + except Exception: + pass + try: + if fastmcp is None: + fastmcp = getattr(self, "_fastmcp", None) or self.mcp + setattr(bound, "_fastmcp", fastmcp) + except Exception: + pass + return bound + + @property + def client_id(self) -> str | None: # type: ignore[override] + try: + return super().client_id # type: ignore[misc] + except Exception: + return None + + @property + def request_id(self) -> str: # type: ignore[override] + try: + return super().request_id # type: ignore[misc] + except Exception: + # Provide a stable-ish fallback based on app session if available + try: + return str(self.session_id) if getattr(self, "session_id", None) else "" + except Exception: + return "" + + async def log( + self, + level: "Literal['debug', 'info', 'warning', 'error']", + message: str, + *, + logger_name: str | None = None, + ) -> None: # type: ignore[override] + """Send a log to the client if possible; otherwise, log locally. + + Matches FastMCP Context API but avoids raising when no request context + is active by falling back to the app's logger. + """ + # If we have a live FastMCP request context, delegate to parent + try: + # will raise if request_context is not available + _ = self.request_context # type: ignore[attr-defined] + return await super().log(level, message, logger_name=logger_name) # type: ignore[misc] + except Exception: + pass + + # Fall back to local logger if available + try: + _logger = self.logger + if _logger is not None: + if level == "debug": + _logger.debug(message) + elif level == "warning": + _logger.warning(message) + elif level == "error": + _logger.error(message) + else: + _logger.info(message) + except Exception: + # Swallow errors in fallback logging to avoid masking tool behavior + pass + + async def report_progress( + self, progress: float, total: float | None = None, message: str | None = None + ) -> None: # type: ignore[override] + """Report progress to the client if a request is active. + + Outside of a request (e.g., local scripts), this is a no-op to avoid + runtime errors as no progressToken exists. + """ + try: + _ = self.request_context # type: ignore[attr-defined] + return await super().report_progress(progress, total, message) # type: ignore[misc] + except Exception: + # No-op when no active request context + return None + + async def read_resource(self, uri: Any) -> Any: # type: ignore[override] + """Read a resource via FastMCP if possible; otherwise raise clearly. + + This provides a friendlier error outside of a request and supports + fallback to the app's FastMCP instance if available. + """ + # Use the parent implementation if request-bound fastmcp is available + try: + if getattr(self, "_fastmcp", None) is not None: + return await super().read_resource(uri) # type: ignore[misc] + except Exception: + pass + + # Fall back to app-managed FastMCP if present + try: + mcp = self.mcp + if mcp is not None: + return await mcp.read_resource(uri) # type: ignore[no-any-return] + except Exception: + pass + + raise ValueError( + "read_resource is only available when an MCP server is active." + ) + async def configure_otel( config: "Settings", session_id: str | None = None diff --git a/src/mcp_agent/server/app_server.py b/src/mcp_agent/server/app_server.py index ec56574c1..524f99abe 100644 --- a/src/mcp_agent/server/app_server.py +++ b/src/mcp_agent/server/app_server.py @@ -1449,6 +1449,40 @@ def create_declared_function_tools(mcp: FastMCP, server_context: ServerContext): import inspect import asyncio import time + import typing as _typing + + try: + from mcp.server.fastmcp import Context as _Ctx + except Exception: + _Ctx = None # type: ignore + + def _annotation_is_fast_ctx(annotation) -> bool: + if _Ctx is None or annotation is inspect._empty: + return False + if annotation is _Ctx: + return True + try: + origin = _typing.get_origin(annotation) + if origin is not None: + return any( + _annotation_is_fast_ctx(arg) for arg in _typing.get_args(annotation) + ) + except Exception: + pass + try: + return "fastmcp" in str(annotation) + except Exception: + return False + + def _detect_context_param(signature: inspect.Signature) -> str | None: + for param in signature.parameters.values(): + if param.name == "app_ctx": + continue + if _annotation_is_fast_ctx(param.annotation): + return param.name + if param.annotation is inspect._empty and param.name in {"ctx", "context"}: + return param.name + return None async def _wait_for_completion( ctx: MCPContext, @@ -1564,23 +1598,29 @@ async def _wrapper(**kwargs): ann = dict(getattr(fn, "__annotations__", {})) ann.pop("app_ctx", None) - ctx_param_name = "ctx" - from mcp.server.fastmcp import Context as _Ctx + existing_ctx_param = _detect_context_param(sig) + ctx_param_name = existing_ctx_param or "ctx" - ann[ctx_param_name] = _Ctx + if _Ctx is not None: + ann[ctx_param_name] = _Ctx ann["return"] = getattr(fn, "__annotations__", {}).get("return", return_ann) _wrapper.__annotations__ = ann _wrapper.__name__ = name_local _wrapper.__doc__ = description or (fn.__doc__ or "") params = [p for p in sig.parameters.values() if p.name != "app_ctx"] - ctx_param = inspect.Parameter( - ctx_param_name, - kind=inspect.Parameter.KEYWORD_ONLY, - annotation=_Ctx, - ) + if existing_ctx_param is None: + ctx_param = inspect.Parameter( + ctx_param_name, + kind=inspect.Parameter.KEYWORD_ONLY, + annotation=_Ctx, + ) + signature_params = params + [ctx_param] + else: + signature_params = params + _wrapper.__signature__ = inspect.Signature( - parameters=params + [ctx_param], return_annotation=return_ann + parameters=signature_params, return_annotation=return_ann ) def _make_adapter(context_param_name: str, inner_wrapper): @@ -1620,13 +1660,16 @@ async def _async_wrapper(**kwargs): # Mirror original signature and annotations similar to sync path ann = dict(getattr(fn, "__annotations__", {})) ann.pop("app_ctx", None) + try: - from mcp.server.fastmcp import Context as _Ctx + sig_async = inspect.signature(fn) except Exception: - _Ctx = None # type: ignore + sig_async = None + existing_ctx_param = ( + _detect_context_param(sig_async) if sig_async else None + ) - # Choose context kw-only parameter - ctx_param_name = "ctx" + ctx_param_name = existing_ctx_param or "ctx" if _Ctx is not None: ann[ctx_param_name] = _Ctx @@ -1650,38 +1693,36 @@ async def _async_wrapper(**kwargs): # Build mirrored signature: drop app_ctx and any FastMCP Context params params = [] - try: - sig_async = inspect.signature(fn) + if sig_async is not None: for p in sig_async.parameters.values(): if p.name == "app_ctx": continue - if p.name in ("ctx", "context"): - continue - if ( - _Ctx is not None - and p.annotation is not inspect._empty - and p.annotation is _Ctx + if existing_ctx_param is None and ( + _annotation_is_fast_ctx(p.annotation) + or p.name in ("ctx", "context") ): continue params.append(p) - except Exception: - params = [] # Append kw-only context param - if _Ctx is not None: - ctx_param = inspect.Parameter( - ctx_param_name, - kind=inspect.Parameter.KEYWORD_ONLY, - annotation=_Ctx, - ) + if existing_ctx_param is None: + if _Ctx is not None: + ctx_param = inspect.Parameter( + ctx_param_name, + kind=inspect.Parameter.KEYWORD_ONLY, + annotation=_Ctx, + ) + else: + ctx_param = inspect.Parameter( + ctx_param_name, + kind=inspect.Parameter.KEYWORD_ONLY, + ) + signature_params = params + [ctx_param] else: - ctx_param = inspect.Parameter( - ctx_param_name, - kind=inspect.Parameter.KEYWORD_ONLY, - ) + signature_params = params _async_wrapper.__signature__ = inspect.Signature( - parameters=params + [ctx_param], return_annotation=ann.get("return") + parameters=signature_params, return_annotation=ann.get("return") ) # Adapter to map injected FastMCP context kwarg without additional propagation @@ -1829,6 +1870,33 @@ async def _workflow_run( if not workflows_dict or not app_context: raise ToolError("Server context not available for MCPApp Server.") + # Bind the app context to this FastMCP request so request-scoped methods + # (client_id, request_id, log/progress/resource reads) work seamlessly. + bound_app_context = app_context + try: + request_ctx = getattr(ctx, "request_context", None) + except Exception: + request_ctx = None + if request_ctx is not None and hasattr(app_context, "bind_request"): + try: + bound_app_context = app_context.bind_request( + request_ctx, + getattr(ctx, "fastmcp", None), + ) + # Preserve upstream_session if the copy drops it for any reason + if ( + getattr(bound_app_context, "upstream_session", None) is None + and getattr(app_context, "upstream_session", None) is not None + ): + bound_app_context.upstream_session = app_context.upstream_session + except Exception: + bound_app_context = app_context + # Expose the per-request bound context on the FastMCP context for adapters + try: + object.__setattr__(ctx, "bound_app_context", bound_app_context) + except Exception: + pass + if workflow_name not in workflows_dict: raise ToolError(f"Workflow '{workflow_name}' not found.") @@ -1842,14 +1910,20 @@ async def _workflow_run( if app is not None and getattr(app, "name", None): from mcp_agent.logging.logger import get_logger as _get_logger - _get_logger(f"mcp_agent.{app.name}", context=app_context) + _get_logger(f"mcp_agent.{app.name}", context=bound_app_context) except Exception: pass # Create and initialize the workflow instance using the factory method try: # Create workflow instance with context that has upstream_session - workflow = await workflow_cls.create(name=workflow_name, context=app_context) + workflow = await workflow_cls.create( + name=workflow_name, context=bound_app_context + ) + try: + setattr(workflow, "_mcp_request_context", ctx) + except Exception: + pass run_parameters = run_parameters or {} diff --git a/src/mcp_agent/server/tool_adapter.py b/src/mcp_agent/server/tool_adapter.py index 7295fb1b1..61ed37fd5 100644 --- a/src/mcp_agent/server/tool_adapter.py +++ b/src/mcp_agent/server/tool_adapter.py @@ -7,6 +7,7 @@ """ import inspect +import typing as _typing from typing import Any, Callable, Optional from mcp.server.fastmcp import Context as _Ctx @@ -36,26 +37,70 @@ def create_tool_adapter_signature( signature can be converted to JSON schema. """ sig = inspect.signature(fn) + + def _annotation_is_fast_ctx(annotation) -> bool: + if _Ctx is None or annotation is inspect._empty: + return False + if annotation is _Ctx: + return True + try: + origin = _typing.get_origin(annotation) + if origin is not None: + return any( + _annotation_is_fast_ctx(arg) for arg in _typing.get_args(annotation) + ) + except Exception: + pass + try: + return "fastmcp" in str(annotation) + except Exception: + return False + + existing_ctx_param = None + for param in sig.parameters.values(): + if param.name == "app_ctx": + continue + annotation = param.annotation + if annotation is inspect._empty and param.name in ("ctx", "context"): + existing_ctx_param = param.name + break + if _annotation_is_fast_ctx(annotation): + existing_ctx_param = param.name + break return_ann = sig.return_annotation # Copy annotations and remove app_ctx ann = dict(getattr(fn, "__annotations__", {})) ann.pop("app_ctx", None) - # Add ctx parameter annotation - ctx_param_name = "ctx" - ann[ctx_param_name] = _Ctx + # Determine context parameter name + ctx_param_name = existing_ctx_param or "ctx" + if _Ctx is not None: + ann[ctx_param_name] = _Ctx ann["return"] = getattr(fn, "__annotations__", {}).get("return", return_ann) - # Filter parameters to remove app_ctx - params = [p for p in sig.parameters.values() if p.name != "app_ctx"] - - # Create ctx parameter - ctx_param = inspect.Parameter( - ctx_param_name, - kind=inspect.Parameter.KEYWORD_ONLY, - annotation=_Ctx, - ) + # Filter parameters to remove app_ctx and, when needed, ctx/context placeholders + params = [] + for p in sig.parameters.values(): + if p.name == "app_ctx": + continue + if existing_ctx_param is None and ( + (p.annotation is inspect._empty and p.name in ("ctx", "context")) + or _annotation_is_fast_ctx(p.annotation) + ): + continue + params.append(p) + + # Create ctx parameter when not already present + if existing_ctx_param is None: + ctx_param = inspect.Parameter( + ctx_param_name, + kind=inspect.Parameter.KEYWORD_ONLY, + annotation=_Ctx, + ) + signature_params = params + [ctx_param] + else: + signature_params = params # Create a dummy function with the transformed signature async def _transformed(**kwargs): @@ -68,7 +113,7 @@ async def _transformed(**kwargs): # Create new signature with filtered params + ctx param _transformed.__signature__ = inspect.Signature( - parameters=params + [ctx_param], return_annotation=return_ann + parameters=signature_params, return_annotation=return_ann ) return _transformed diff --git a/tests/core/test_context.py b/tests/core/test_context.py new file mode 100644 index 000000000..038bd80d6 --- /dev/null +++ b/tests/core/test_context.py @@ -0,0 +1,123 @@ +import pytest +from types import SimpleNamespace + +from mcp_agent.core.context import Context +from mcp_agent.logging.logger import Logger as AgentLogger + + +class _DummyLogger: + def __init__(self): + self.messages = [] + + def debug(self, message: str): + self.messages.append(("debug", message)) + + def info(self, message: str): + self.messages.append(("info", message)) + + def warning(self, message: str): + self.messages.append(("warning", message)) + + def error(self, message: str): + self.messages.append(("error", message)) + + +class _DummyMCP: + def __init__(self): + self.last_uri = None + + async def read_resource(self, uri): + self.last_uri = uri + return [("text", uri)] + + +def _make_context(*, app: SimpleNamespace | None = None) -> Context: + ctx = Context() + if app is not None: + ctx.app = app + return ctx + + +def test_session_prefers_explicit_upstream(): + upstream = object() + ctx = _make_context() + ctx.upstream_session = upstream + + assert ctx.session is upstream + + +def test_fastmcp_fallback_to_app(): + dummy_mcp = object() + app = SimpleNamespace(mcp=dummy_mcp, logger=None) + ctx = _make_context(app=app) + + assert ctx.fastmcp is dummy_mcp + + bound = ctx.bind_request(SimpleNamespace(), fastmcp="request_mcp") + assert bound.fastmcp == "request_mcp" + # Original context remains unchanged + assert ctx.fastmcp is dummy_mcp + + +@pytest.mark.asyncio +async def test_log_falls_back_to_app_logger(): + dummy_logger = _DummyLogger() + app = SimpleNamespace(mcp=None, logger=dummy_logger) + ctx = _make_context(app=app) + + await ctx.log("info", "hello world") + + assert ("info", "hello world") in dummy_logger.messages + + +@pytest.mark.asyncio +async def test_read_resource_falls_back_to_app_mcp(): + dummy_mcp = _DummyMCP() + app = SimpleNamespace(mcp=dummy_mcp, logger=None) + ctx = _make_context(app=app) + + contents = await ctx.read_resource("resource://foo") + + assert dummy_mcp.last_uri == "resource://foo" + assert list(contents) == [("text", "resource://foo")] + + +@pytest.mark.asyncio +async def test_read_resource_without_mcp_raises(): + ctx = _make_context() + + with pytest.raises(ValueError): + await ctx.read_resource("resource://missing") + + +def test_logger_property_uses_app_logger(): + dummy_logger = _DummyLogger() + app = SimpleNamespace(mcp=None, logger=dummy_logger, name="demo-app") + ctx = _make_context(app=app) + + assert ctx.logger is dummy_logger + + +def test_logger_property_without_app_creates_logger(): + ctx = _make_context() + + logger = ctx.logger + + assert isinstance(logger, AgentLogger) + assert getattr(logger, "_bound_context", None) is ctx + + +def test_name_and_description_properties(): + app = SimpleNamespace( + mcp=None, logger=_DummyLogger(), name="app-name", description="app-desc" + ) + ctx = _make_context(app=app) + ctx.config = SimpleNamespace(name="config-name", description="config-desc") + + assert ctx.name == "app-name" + assert ctx.description == "app-desc" + + ctx_no_app = _make_context() + + assert ctx_no_app.name is None + assert ctx_no_app.description is None diff --git a/tests/server/test_app_server.py b/tests/server/test_app_server.py index 16d61528d..2b734b0f9 100644 --- a/tests/server/test_app_server.py +++ b/tests/server/test_app_server.py @@ -76,9 +76,13 @@ async def test_workflow_run_with_custom_workflow_id( ) # Verify the workflow was created - mock_workflow_class.create.assert_called_once_with( - name=workflow_name, - context=mock_server_context.request_context.lifespan_context.context, + mock_workflow_class.create.assert_called_once() + create_kwargs = mock_workflow_class.create.call_args.kwargs + assert create_kwargs["name"] == workflow_name + # Bound context should be derived from the original lifespan context + assert ( + create_kwargs["context"] + is not mock_server_context.request_context.lifespan_context.context ) # Verify run_async was called with the custom workflow_id diff --git a/tests/server/test_tool_decorators.py b/tests/server/test_tool_decorators.py index dec2d8fb4..5a61aa7f8 100644 --- a/tests/server/test_tool_decorators.py +++ b/tests/server/test_tool_decorators.py @@ -1,7 +1,11 @@ import asyncio +from typing import Any + import pytest from mcp_agent.app import MCPApp +from mcp_agent.core.context import Context +from mcp.server.fastmcp import Context as FastMCPContext from mcp_agent.server.app_server import ( create_workflow_tools, create_declared_function_tools, @@ -87,6 +91,10 @@ async def echo(text: str) -> str: ctx = _make_ctx(server_context) result = await sync_tool_fn(text="hi", ctx=ctx) assert result == "hi!" # unwrapped (not WorkflowResult) + bound_app_ctx = getattr(ctx, "bound_app_context", None) + assert bound_app_ctx is not None + assert bound_app_ctx is not server_context.context + assert bound_app_ctx.fastmcp == ctx.fastmcp # Also ensure the underlying workflow returned a WorkflowResult # Start via workflow_run to get run_id, then wait for completion and inspect @@ -171,3 +179,69 @@ async def wrapme(v: int) -> int: assert result_payload["value"] == 42 else: assert result_payload in (42, {"result": 42}) + + +@pytest.mark.asyncio +async def test_workflow_run_binds_app_context_per_request(): + app = MCPApp(name="test_request_binding") + await app.initialize() + + sentinel_session = object() + app.context.upstream_session = sentinel_session + + captured: dict[str, Any] = {} + + @app.async_tool(name="binding_tool") + async def binding_tool( + value: int, + app_ctx: Context | None = None, + ctx: FastMCPContext | None = None, + ) -> str: + captured["app_ctx"] = app_ctx + captured["ctx"] = ctx + if app_ctx is not None: + # Access session property to confirm fallback path works during execution + captured["session_property"] = app_ctx.session + captured["request_context"] = getattr(app_ctx, "_request_context", None) + captured["fastmcp"] = app_ctx.fastmcp + return f"done:{value}" + + server_context = type( + "SC", (), {"workflows": app.workflows, "context": app.context} + )() + + ctx = _make_ctx(server_context) + # Simulate FastMCP attaching the app to its server for lookup paths + ctx.fastmcp._mcp_agent_app = app # type: ignore[attr-defined] + + run_info = await _workflow_run(ctx, "binding_tool", {"value": 7}) + run_id = run_info["run_id"] + + # Workflow should have the FastMCP request context attached + workflow = await app.context.workflow_registry.get_workflow(run_id) + assert getattr(workflow, "_mcp_request_context", None) is ctx + + # Wait for completion so the tool function executes + for _ in range(200): + status = await app.context.workflow_registry.get_workflow_status(run_id) + if status.get("completed"): + break + await asyncio.sleep(0.01) + assert status.get("completed") is True + + bound_app_ctx = getattr(ctx, "bound_app_context", None) + assert bound_app_ctx is not None + # The tool received the per-request bound context + assert captured.get("app_ctx") is bound_app_ctx + # FastMCP context argument should be the original request context + assert captured.get("ctx") is ctx + assert getattr(captured.get("ctx"), "bound_app_context", None) is bound_app_ctx + assert bound_app_ctx is not app.context + # Upstream session should be preserved on the bound context + assert bound_app_ctx.upstream_session is sentinel_session + assert captured.get("session_property") is sentinel_session + # FastMCP instance and request context bridge through the bound context + assert captured.get("fastmcp") is ctx.fastmcp + assert captured.get("request_context") is ctx.request_context + # Accessing session on the bound context should prefer upstream_session + assert bound_app_ctx.session is sentinel_session