From 778c857d1e437d120e4f1c1861ba75f78ca4c1f2 Mon Sep 17 00:00:00 2001 From: Sarmad Qadri Date: Thu, 28 Aug 2025 23:10:08 -0400 Subject: [PATCH 01/24] Add @app.tool and @app.async_tool decorators --- src/mcp_agent/app.py | 151 ++++++++++ src/mcp_agent/server/app_server.py | 273 +++++++++++++++++- .../server/test_app_server_workflow_schema.py | 59 ++++ tests/server/test_tool_decorators.py | 173 +++++++++++ 4 files changed, 654 insertions(+), 2 deletions(-) create mode 100644 tests/server/test_app_server_workflow_schema.py create mode 100644 tests/server/test_tool_decorators.py diff --git a/src/mcp_agent/app.py b/src/mcp_agent/app.py index 4294a4012..c4fb8f6c2 100644 --- a/src/mcp_agent/app.py +++ b/src/mcp_agent/app.py @@ -115,6 +115,17 @@ def __init__( self._model_selector = model_selector self._workflows: Dict[str, Type["Workflow"]] = {} # id to workflow class + # Deferred tool declarations to register with MCP server when available + # Each entry: { + # "name": str, + # "mode": "sync" | "async", + # "workflow_name": str, + # "workflow_cls": Type[Workflow], + # "tool_wrapper": Callable | None, + # "structured_output": bool | None, + # "description": str | None, + # } + self._declared_tools: list[dict[str, Any]] = [] self._logger = None self._context: Optional[Context] = None @@ -512,6 +523,146 @@ async def wrapper(*args, **kwargs): return wrapper + def _create_workflow_from_function( + self, + fn: Callable[..., Any], + *, + workflow_name: str, + description: str | None = None, + mark_sync_tool: bool = False, + ) -> Type: + """ + Create a Workflow subclass dynamically from a plain function. + + The generated workflow class will: + - Have `run` implemented to call the provided function + - Be decorated with engine-specific run decorators via workflow_run + - Expose the original function for parameter schema generation + """ + + import asyncio as _asyncio + from mcp_agent.executor.workflow import Workflow as _Workflow + + async def _invoke_target(*args, **kwargs): + # Support both async and sync callables + res = fn(*args, **kwargs) + if _asyncio.iscoroutine(res): + res = await res + + # Ensure WorkflowResult return type + try: + from mcp_agent.executor.workflow import ( + WorkflowResult as _WorkflowResult, + ) + except Exception: + _WorkflowResult = None # type: ignore[assignment] + + if _WorkflowResult is not None and not isinstance(res, _WorkflowResult): + return _WorkflowResult(value=res) + return res + + async def _run(self, *args, **kwargs): # type: ignore[no-redef] + return await _invoke_target(*args, **kwargs) + + # Decorate run with engine-specific decorator + decorated_run = self.workflow_run(_run) + + # Build the Workflow subclass dynamically + cls_dict: Dict[str, Any] = { + "__doc__": description or (fn.__doc__ or ""), + "run": decorated_run, + "__mcp_agent_param_source_fn__": fn, + } + if mark_sync_tool: + cls_dict["__mcp_agent_sync_tool__"] = True + else: + cls_dict["__mcp_agent_async_tool__"] = True + + auto_cls = type(f"AutoWorkflow_{workflow_name}", (_Workflow,), cls_dict) + + # Register with app (and apply engine-specific workflow decorator) + self.workflow(auto_cls, workflow_id=workflow_name) + return auto_cls + + def tool( + self, + name: str | None = None, + *, + description: str | None = None, + structured_output: bool | None = None, + ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """ + Decorator to declare a synchronous MCP tool that runs via an auto-generated + Workflow and waits for completion before returning. 
+ + Also registers an async Workflow under the same name so that run/get_status + endpoints are available. + """ + + def decorator(fn: Callable[..., Any]) -> Callable[..., Any]: + tool_name = name or fn.__name__ + # Construct the workflow from function + workflow_cls = self._create_workflow_from_function( + fn, + workflow_name=tool_name, + description=description, + mark_sync_tool=True, + ) + + # Defer tool registration until the MCP server is created + self._declared_tools.append( + { + "name": tool_name, + "mode": "sync", + "workflow_name": tool_name, + "workflow_cls": workflow_cls, + "source_fn": fn, + "structured_output": structured_output, + "description": description or (fn.__doc__ or ""), + } + ) + + return fn + + return decorator + + def async_tool( + self, + name: str | None = None, + *, + description: str | None = None, + ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """ + Decorator to declare an asynchronous MCP tool. + + Creates a Workflow class from the function and registers it so that + the standard per-workflow tools (run/get_status) are exposed by the server. + """ + + def decorator(fn: Callable[..., Any]) -> Callable[..., Any]: + workflow_name = name or fn.__name__ + workflow_cls = self._create_workflow_from_function( + fn, + workflow_name=workflow_name, + description=description, + mark_sync_tool=False, + ) + # Defer alias tool registration for run/get_status + self._declared_tools.append( + { + "name": workflow_name, + "mode": "async", + "workflow_name": workflow_name, + "workflow_cls": workflow_cls, + "source_fn": fn, + "structured_output": None, + "description": description or (fn.__doc__ or ""), + } + ) + return fn + + return decorator + def workflow_task( self, name: str | None = None, diff --git a/src/mcp_agent/server/app_server.py b/src/mcp_agent/server/app_server.py index c3b5a6bb8..1615f3543 100644 --- a/src/mcp_agent/server/app_server.py +++ b/src/mcp_agent/server/app_server.py @@ -132,6 +132,12 @@ def _resolve_workflows_and_context( def _resolve_workflow_registry(ctx: MCPContext) -> WorkflowRegistry | None: """Resolve the workflow registry regardless of startup mode.""" lifespan_ctx = getattr(ctx.request_context, "lifespan_context", None) + # Prefer the underlying app context's registry if available + if lifespan_ctx is not None and hasattr(lifespan_ctx, "context"): + ctx_inner = getattr(lifespan_ctx, "context", None) + if ctx_inner is not None and hasattr(ctx_inner, "workflow_registry"): + return ctx_inner.workflow_registry + # Fallback: top-level lifespan registry if present if lifespan_ctx is not None and hasattr(lifespan_ctx, "workflow_registry"): return lifespan_ctx.workflow_registry @@ -142,6 +148,43 @@ def _resolve_workflow_registry(ctx: MCPContext) -> WorkflowRegistry | None: return None +def _get_param_source_function_from_workflow(workflow_cls: Type["Workflow"]): + """Return the function to use for parameter schema for a workflow's run. + + For auto-generated workflows from @app.tool/@app.async_tool, prefer the original + function that defined the parameters if available; fall back to the class run. 
+ """ + return getattr(workflow_cls, "__mcp_agent_param_source_fn__", None) or getattr( + workflow_cls, "run" + ) + + +def _build_run_param_tool(workflow_cls: Type["Workflow"]) -> FastTool: + """Return a FastTool built from the proper parameter source, skipping 'self'.""" + param_source = _get_param_source_function_from_workflow(workflow_cls) + import inspect as _inspect + + if param_source is getattr(workflow_cls, "run"): + + def _schema_fn_proxy(*args, **kwargs): + return None + + sig = _inspect.signature(param_source) + params = list(sig.parameters.values()) + if params and params[0].name == "self": + params = params[1:] + _schema_fn_proxy.__annotations__ = dict( + getattr(param_source, "__annotations__", {}) + ) + if "self" in _schema_fn_proxy.__annotations__: + _schema_fn_proxy.__annotations__.pop("self", None) + _schema_fn_proxy.__signature__ = _inspect.Signature( + parameters=params, return_annotation=sig.return_annotation + ) + return FastTool.from_function(_schema_fn_proxy) + return FastTool.from_function(param_source) + + def create_mcp_server_for_app(app: MCPApp, **kwargs: Any) -> FastMCP: """ Create an MCP server for a given MCPApp instance. @@ -166,6 +209,8 @@ async def app_specific_lifespan(mcp: FastMCP) -> AsyncIterator[ServerContext]: # Register initial workflow tools when running with our managed lifespan create_workflow_tools(mcp, server_context) + # Register function-declared tools (from @app.tool/@app.async_tool) + create_declared_function_tools(mcp, server_context) try: yield server_context @@ -189,6 +234,8 @@ async def app_specific_lifespan(mcp: FastMCP) -> AsyncIterator[ServerContext]: # Register per-workflow tools create_workflow_tools(mcp, server_context) + # Register function-declared tools (from @app.tool/@app.async_tool) + create_declared_function_tools(mcp, server_context) else: mcp = FastMCP( name=app.name or "mcp_agent_server", @@ -403,6 +450,11 @@ def create_workflow_tools(mcp: FastMCP, server_context: ServerContext): registered_workflow_tools = _get_registered_workflow_tools(mcp) for workflow_name, workflow_cls in server_context.workflows.items(): + # Skip creating generic workflows-* tools for sync/async auto tools + if getattr(workflow_cls, "__mcp_agent_sync_tool__", False): + continue + if getattr(workflow_cls, "__mcp_agent_async_tool__", False): + continue if workflow_name not in registered_workflow_tools: create_workflow_specific_tools(mcp, workflow_name, workflow_cls) registered_workflow_tools.add(workflow_name) @@ -410,12 +462,229 @@ def create_workflow_tools(mcp: FastMCP, server_context: ServerContext): setattr(mcp, "_registered_workflow_tools", registered_workflow_tools) +def _get_registered_function_tools(mcp: FastMCP) -> Set[str]: + return getattr(mcp, "_registered_function_tools", set()) + + +def _set_registered_function_tools(mcp: FastMCP, tools: Set[str]): + setattr(mcp, "_registered_function_tools", tools) + + +def create_declared_function_tools(mcp: FastMCP, server_context: ServerContext): + """ + Register tools declared via @app.tool/@app.async_tool on the attached app. + - @app.tool registers a synchronous tool with the same signature as the function + that runs the auto-generated workflow and waits for completion. + - @app.async_tool registers alias tools -run and -get_status + that proxy to the workflow run/status utilities. 
+ """ + app = _get_attached_app(mcp) + if app is None: + # Fallbacks for tests or externally provided contexts + app = getattr(server_context, "app", None) + if app is None: + ctx = getattr(server_context, "context", None) + if ctx is not None: + app = getattr(ctx, "app", None) + if app is None: + return + + declared = getattr(app, "_declared_tools", []) or [] + if not declared: + return + + registered = _get_registered_function_tools(mcp) + + # Utility: build a wrapper function with the same signature and return annotation + import inspect + import asyncio + + async def _wait_for_completion( + ctx: MCPContext, run_id: str, timeout: float | None = None + ): + registry = _resolve_workflow_registry(ctx) + if not registry: + raise ToolError("Workflow registry not found for MCPApp Server.") + # Try to get the workflow and wait on its task if available + start = asyncio.get_event_loop().time() + # Ensure the workflow is registered locally to retrieve the task + try: + wf = await registry.get_workflow(run_id) + if wf is None and hasattr(registry, "register"): + # Best-effort: some registries need explicit register; try to find by status + # and skip if unavailable. This is a no-op for InMemory which registers at run_async. + pass + except Exception: + pass + while True: + wf = await registry.get_workflow(run_id) + if wf is not None: + task = getattr(wf, "_run_task", None) + if isinstance(task, asyncio.Task): + return await asyncio.wait_for(task, timeout=timeout) + # Fallback to polling the status + status = await wf.get_status() + if status.get("completed"): + return status.get("result") + if ( + timeout is not None + and (asyncio.get_event_loop().time() - start) > timeout + ): + raise ToolError("Timed out waiting for workflow completion") + await asyncio.sleep(0.1) + + for decl in declared: + name = decl["name"] + if name in registered: + continue + mode = decl["mode"] + workflow_name = decl["workflow_name"] + fn = decl.get("source_fn") + description = decl.get("description") + structured_output = decl.get("structured_output") + + if mode == "sync" and fn is not None: + sig = inspect.signature(fn) + return_ann = sig.return_annotation + + async def _wrapper(**kwargs): + # Context will be injected by FastMCP using the special annotation below + ctx: MCPContext = kwargs.pop( + "__context__" + ) # placeholder, reassigned below via signature name + # Start workflow and wait for completion + result_ids = await _workflow_run(ctx, workflow_name, kwargs) + run_id = result_ids["run_id"] + result = await _wait_for_completion(ctx, run_id) + # Unwrap WorkflowResult to match the original function's return type + try: + from mcp_agent.executor.workflow import WorkflowResult as _WFRes + except Exception: + _WFRes = None # type: ignore + if _WFRes is not None and isinstance(result, _WFRes): + return getattr(result, "value", None) + # If get_status returned dict/str, pass through; otherwise return model + return result + + # Attach introspection metadata to match the original function + ann = dict(getattr(fn, "__annotations__", {})) + + # Choose a context kwarg name unlikely to clash with user params + ctx_param_name = "ctx" + from mcp.server.fastmcp import Context as _Ctx + + ann[ctx_param_name] = _Ctx + ann["return"] = getattr(fn, "__annotations__", {}).get("return", return_ann) + _wrapper.__annotations__ = ann + _wrapper.__name__ = name + _wrapper.__doc__ = description or (fn.__doc__ or "") + + # Build a fake signature containing original params plus context kwarg + params = list(sig.parameters.values()) + 
ctx_param = inspect.Parameter( + ctx_param_name, + kind=inspect.Parameter.KEYWORD_ONLY, + annotation=_Ctx, + ) + _wrapper.__signature__ = inspect.Signature( + parameters=params + [ctx_param], return_annotation=return_ann + ) + + # FastMCP expects the actual kwarg name for context; it detects it by annotation + # We need to map the injected kwarg inside the wrapper body. Achieve this by + # creating a thin adapter that renames the injected context kwarg. + async def _adapter(**kw): + # Receive validated args plus injected context kwarg + if ctx_param_name not in kw: + raise ToolError("Context not provided") + # Rename to the placeholder expected by _wrapper + kw["__context__"] = kw.pop(ctx_param_name) + return await _wrapper(**kw) + + # Copy the visible signature/annotations to adapter for correct schema + _adapter.__annotations__ = _wrapper.__annotations__ + _adapter.__name__ = _wrapper.__name__ + _adapter.__doc__ = _wrapper.__doc__ + _adapter.__signature__ = _wrapper.__signature__ + + # Register the main tool with the same signature as original + mcp.add_tool( + _adapter, + name=name, + description=description or (fn.__doc__ or ""), + structured_output=structured_output, + ) + registered.add(name) + + # Also register a per-run status tool: -get_status + status_tool_name = f"{name}-get_status" + if status_tool_name not in registered: + + @mcp.tool(name=status_tool_name) + async def _sync_status(ctx: MCPContext, run_id: str) -> Dict[str, Any]: + return await _workflow_status( + ctx, run_id=run_id, workflow_name=workflow_name + ) + + registered.add(status_tool_name) + + elif mode == "async": + # Create only named aliases for async: -async-run and -async-get_status + run_tool_name = f"{name}-async-run" + status_tool_name = f"{name}-async-get_status" + + if run_tool_name not in registered: + + @mcp.tool(name=run_tool_name) + async def _alias_run( + ctx: MCPContext, run_parameters: Dict[str, Any] | None = None + ) -> Dict[str, str]: + return await _workflow_run(ctx, workflow_name, run_parameters or {}) + + registered.add(run_tool_name) + + if status_tool_name not in registered: + + @mcp.tool(name=status_tool_name) + async def _alias_status(ctx: MCPContext, run_id: str) -> Dict[str, Any]: + return await _workflow_status( + ctx, run_id=run_id, workflow_name=workflow_name + ) + + registered.add(status_tool_name) + + _set_registered_function_tools(mcp, registered) + + def create_workflow_specific_tools( mcp: FastMCP, workflow_name: str, workflow_cls: Type["Workflow"] ): """Create specific tools for a given workflow.""" - - run_fn_tool = FastTool.from_function(workflow_cls.run) + param_source = _get_param_source_function_from_workflow(workflow_cls) + # Ensure we don't include 'self' in tool schema; FastMCP will ignore Context but not 'self' + import inspect as _inspect + + if param_source is getattr(workflow_cls, "run"): + # Wrap to drop the first positional param (self) for schema purposes + def _schema_fn_proxy(*args, **kwargs): + return None + + sig = _inspect.signature(param_source) + params = list(sig.parameters.values()) + # remove leading 'self' if present + if params and params[0].name == "self": + params = params[1:] + _schema_fn_proxy.__annotations__ = dict( + getattr(param_source, "__annotations__", {}) + ) + if "self" in _schema_fn_proxy.__annotations__: + _schema_fn_proxy.__annotations__.pop("self", None) + _schema_fn_proxy.__signature__ = _inspect.Signature( + parameters=params, return_annotation=sig.return_annotation + ) + run_fn_tool = FastTool.from_function(_schema_fn_proxy) + else: 
+ run_fn_tool = FastTool.from_function(param_source) run_fn_tool_params = json.dumps(run_fn_tool.parameters, indent=2) @mcp.tool( diff --git a/tests/server/test_app_server_workflow_schema.py b/tests/server/test_app_server_workflow_schema.py new file mode 100644 index 000000000..2b5f452ef --- /dev/null +++ b/tests/server/test_app_server_workflow_schema.py @@ -0,0 +1,59 @@ +import json +import pytest +from types import SimpleNamespace + +from mcp_agent.app import MCPApp +from mcp_agent.executor.workflow import Workflow, WorkflowResult +from mcp_agent.server.app_server import create_workflow_tools + + +class _ToolRecorder: + def __init__(self): + self.decorated = [] + + def tool(self, *args, **kwargs): + name = kwargs.get("name", args[0] if args else None) + + def _decorator(func): + self.decorated.append((name, func, kwargs)) + return func + + return _decorator + + +@pytest.mark.asyncio +async def test_workflow_run_schema_strips_self_and_uses_param_annotations(): + app = MCPApp(name="schema_app") + await app.initialize() + + @app.workflow + class MyWF(Workflow[str]): + """Doc for MyWF""" + + @app.workflow_run + async def run(self, q: int, flag: bool = False) -> WorkflowResult[str]: + return WorkflowResult(value=f"{q}:{flag}") + + mcp = _ToolRecorder() + server_context = SimpleNamespace(workflows=app.workflows, context=app.context) + + # This should create per-workflow tools; run tool must be built from run signature + create_workflow_tools(mcp, server_context) + + # Find the "workflows-MyWF-run" tool and inspect its parameters schema via FastMCP + names = [name for name, *_ in mcp.decorated] + assert f"workflows-MyWF-run" in names + + # We can’t call FastTool.from_function here since the tool is already created inside create_workflow_tools, + # but we can at least ensure that the schema text embedded in the description JSON includes our parameters (q, flag) + # Description contains a pretty-printed JSON of parameters; locate and parse it + run_entry = next( + (entry for entry in mcp.decorated if entry[0] == "workflows-MyWF-run"), None + ) + assert run_entry is not None + _, _, kwargs = run_entry + desc = kwargs.get("description", "") + # The description embeds the JSON schema; assert basic fields are referenced + assert "q" in desc + assert "flag" in desc + assert "self" not in desc diff --git a/tests/server/test_tool_decorators.py b/tests/server/test_tool_decorators.py new file mode 100644 index 000000000..1a7bfcc78 --- /dev/null +++ b/tests/server/test_tool_decorators.py @@ -0,0 +1,173 @@ +import asyncio +import pytest + +from mcp_agent.app import MCPApp +from mcp_agent.server.app_server import ( + create_workflow_tools, + create_declared_function_tools, + _workflow_run, + _workflow_status, +) + + +class _ToolRecorder: + """Helper to record tools registered via FastMCP-like interface.""" + + def __init__(self): + self.decorated_tools = [] # via mcp.tool decorator (workflow endpoints) + self.added_tools = [] # via mcp.add_tool (sync @app.tool) + + def tool(self, *args, **kwargs): + name = kwargs.get("name", args[0] if args else None) + + def _decorator(func): + self.decorated_tools.append((name, func)) + return func + + return _decorator + + def add_tool( + self, + fn, + *, + name=None, + title=None, + description=None, + annotations=None, + structured_output=None, + ): + self.added_tools.append((name, fn, description, structured_output)) + + +def _make_ctx(server_context): + # Minimal fake MCPContext with request_context.lifespan_context + from types import SimpleNamespace + + ctx = 
SimpleNamespace() + # Ensure a workflow registry is available for status waits + if not hasattr(server_context, "workflow_registry"): + from mcp_agent.executor.workflow_registry import InMemoryWorkflowRegistry + + server_context.workflow_registry = InMemoryWorkflowRegistry() + + req = SimpleNamespace(lifespan_context=server_context) + ctx.request_context = req + ctx.fastmcp = SimpleNamespace(_mcp_agent_app=None) + return ctx + + +@pytest.mark.asyncio +async def test_app_tool_registers_and_executes_sync_tool(): + app = MCPApp(name="test_app_tool") + await app.initialize() + + @app.tool(name="echo", description="Echo input") + async def echo(text: str) -> str: + return text + "!" + + # Prepare mock FastMCP and server context + mcp = _ToolRecorder() + server_context = type( + "SC", (), {"workflows": app.workflows, "context": app.context} + )() + + # Register generated per-workflow tools and function-declared tools + create_workflow_tools(mcp, server_context) + create_declared_function_tools(mcp, server_context) + + # Verify tool names: sync tool and its status tool + decorated_names = {name for name, _ in mcp.decorated_tools} + added_names = {name for name, *_ in mcp.added_tools} + + # No workflows-* for sync tools; check echo and echo-get_status + assert "echo" in added_names # synchronous tool + assert "echo-get_status" in decorated_names + + # Execute the synchronous tool function and ensure it returns unwrapped value + # Find the registered sync tool function + sync_tool_fn = next(fn for name, fn, *_ in mcp.added_tools if name == "echo") + ctx = _make_ctx(server_context) + result = await sync_tool_fn(text="hi", ctx=ctx) + assert result == "hi!" # unwrapped (not WorkflowResult) + + # Also ensure the underlying workflow returned a WorkflowResult + # Start via workflow_run to get run_id, then wait for completion and inspect + run_info = await _workflow_run(ctx, "echo", {"text": "ok"}) + run_id = run_info["run_id"] + # Poll status until completed (bounded wait) + for _ in range(200): + status = await _workflow_status(ctx, run_id, "echo") + if status.get("completed"): + break + await asyncio.sleep(0.01) + assert status.get("completed") is True + # The recorded result is a WorkflowResult model dump; check value field + result_payload = status.get("result") + if isinstance(result_payload, dict) and "value" in result_payload: + assert result_payload["value"] == "ok!" 
+ else: + assert result_payload in ("ok!", {"result": "ok!"}) + + +@pytest.mark.asyncio +async def test_app_async_tool_registers_aliases_and_workflow_tools(): + app = MCPApp(name="test_app_async_tool") + await app.initialize() + + @app.async_tool(name="long") + async def long_task(x: int) -> str: + return f"done:{x}" + + mcp = _ToolRecorder() + server_context = type( + "SC", (), {"workflows": app.workflows, "context": app.context} + )() + + create_workflow_tools(mcp, server_context) + create_declared_function_tools(mcp, server_context) + + decorated_names = {name for name, _ in mcp.decorated_tools} + + # async aliases only (we suppress workflows-* for async auto tools) + assert "long-async-run" in decorated_names + assert "long-async-get_status" in decorated_names + + +@pytest.mark.asyncio +async def test_auto_workflow_wraps_plain_return_in_workflowresult(): + app = MCPApp(name="test_wrap") + await app.initialize() + + @app.async_tool(name="wrapme") + async def wrapme(v: int) -> int: + # plain int, should be wrapped inside WorkflowResult internally + return v + 1 + + mcp = _ToolRecorder() + server_context = type( + "SC", (), {"workflows": app.workflows, "context": app.context} + )() + create_workflow_tools(mcp, server_context) + create_declared_function_tools(mcp, server_context) + + ctx = _make_ctx(server_context) + run_info = await _workflow_run(ctx, "wrapme", {"v": 41}) + run_id = run_info["run_id"] + + # Inspect workflow's task result type by polling status for completion + for _ in range(100): + status = await _workflow_status(ctx, run_id, "wrapme") + if status.get("completed"): + break + await asyncio.sleep(0.01) + assert status.get("completed") is True + + # Cross-check that the underlying run returned a WorkflowResult by re-running via registry path + # We can't import the internal task here; assert observable effect: result equals expected and no exceptions + assert status.get("error") in (None, "") + # And the computed value was correct + result_payload = status.get("result") + if isinstance(result_payload, dict) and "value" in result_payload: + assert result_payload["value"] == 42 + else: + assert result_payload in (42, {"result": 42}) From bccedcd625b32bffdca4706c59e13fc1d4bd7dbd Mon Sep 17 00:00:00 2001 From: Sarmad Qadri Date: Thu, 28 Aug 2025 23:17:51 -0400 Subject: [PATCH 02/24] lint and format --- src/mcp_agent/cli/cloud/main.py | 6 +++- src/mcp_agent/server/app_server.py | 4 +-- tests/cli/commands/test_app_delete.py | 22 ++++++-------- tests/cli/commands/test_app_status.py | 29 ++++++++++++------- tests/cli/commands/test_app_workflows.py | 29 ++++++++++++------- tests/cli/commands/test_cli_secrets.py | 15 +++++----- tests/cli/commands/test_configure.py | 3 +- tests/cli/commands/test_deploy_command.py | 18 ++++++++---- tests/cli/utils/jwt_generator.py | 1 + .../server/test_app_server_workflow_schema.py | 3 +- tests/server/test_tool_decorators.py | 2 +- 11 files changed, 76 insertions(+), 56 deletions(-) diff --git a/src/mcp_agent/cli/cloud/main.py b/src/mcp_agent/cli/cloud/main.py index 4de6dd7a3..a883c7c46 100644 --- a/src/mcp_agent/cli/cloud/main.py +++ b/src/mcp_agent/cli/cloud/main.py @@ -14,7 +14,11 @@ from typer.core import TyperGroup from mcp_agent.cli.cloud.commands import configure_app, deploy_config, login -from mcp_agent.cli.cloud.commands.app import delete_app, get_app_status, list_app_workflows +from mcp_agent.cli.cloud.commands.app import ( + delete_app, + get_app_status, + list_app_workflows, +) from mcp_agent.cli.cloud.commands.apps import list_apps from 
mcp_agent.cli.cloud.commands.workflow import get_workflow_status from mcp_agent.cli.exceptions import CLIError diff --git a/src/mcp_agent/server/app_server.py b/src/mcp_agent/server/app_server.py index 1615f3543..3751cbe1a 100644 --- a/src/mcp_agent/server/app_server.py +++ b/src/mcp_agent/server/app_server.py @@ -629,9 +629,9 @@ async def _sync_status(ctx: MCPContext, run_id: str) -> Dict[str, Any]: registered.add(status_tool_name) elif mode == "async": - # Create only named aliases for async: -async-run and -async-get_status + # Create named aliases for async: -async-run and -get_status run_tool_name = f"{name}-async-run" - status_tool_name = f"{name}-async-get_status" + status_tool_name = f"{name}-get_status" if run_tool_name not in registered: diff --git a/tests/cli/commands/test_app_delete.py b/tests/cli/commands/test_app_delete.py index e2dc544cd..4d4dfa463 100644 --- a/tests/cli/commands/test_app_delete.py +++ b/tests/cli/commands/test_app_delete.py @@ -1,4 +1,5 @@ """Tests for the configure command.""" + import datetime from unittest.mock import AsyncMock, MagicMock, patch @@ -8,7 +9,8 @@ from mcp_agent.cli.mcp_app.api_client import MCPApp, MCPAppConfiguration from mcp_agent.cli.mcp_app.mock_client import ( MOCK_APP_CONFIG_ID, - MOCK_APP_ID, MockMCPAppClient, + MOCK_APP_ID, + MockMCPAppClient, ) @@ -72,17 +74,13 @@ def test_delete_app(patched_delete_app, mock_mcp_client): app_id_or_url=MOCK_APP_ID, ) - patched_delete_app( - app_id_or_url=MOCK_APP_ID, - dry_run=False - ) + patched_delete_app(app_id_or_url=MOCK_APP_ID, dry_run=False) mock_mcp_client.delete_app.assert_called_once_with(MOCK_APP_ID) def test_delete_app_config(patched_delete_app, mock_mcp_client): app_config = MCPAppConfiguration( - appConfigurationId=MOCK_APP_CONFIG_ID, - creatorId="creator" + appConfigurationId=MOCK_APP_CONFIG_ID, creatorId="creator" ) mock_mcp_client.get_app_or_config = AsyncMock(return_value=app_config) @@ -91,12 +89,10 @@ def test_delete_app_config(patched_delete_app, mock_mcp_client): app_id_or_url=MOCK_APP_ID, ) - patched_delete_app( - app_id_or_url=MOCK_APP_ID, - dry_run=False - ) + patched_delete_app(app_id_or_url=MOCK_APP_ID, dry_run=False) mock_mcp_client.delete_app_configuration.assert_called_once_with(MOCK_APP_CONFIG_ID) + def test_missing_app_id(patched_delete_app): """Test with missing app_id.""" @@ -122,8 +118,8 @@ def test_missing_api_key(patched_delete_app): # Patch load_api_key_credentials to return None with patch( - "mcp_agent.cli.cloud.commands.configure.main.load_api_key_credentials", - return_value=None, + "mcp_agent.cli.cloud.commands.configure.main.load_api_key_credentials", + return_value=None, ): with pytest.raises(CLIError): patched_delete_app( diff --git a/tests/cli/commands/test_app_status.py b/tests/cli/commands/test_app_status.py index 70c570728..0d10f39da 100644 --- a/tests/cli/commands/test_app_status.py +++ b/tests/cli/commands/test_app_status.py @@ -1,4 +1,5 @@ """Tests for the configure command.""" + import datetime from unittest.mock import AsyncMock, MagicMock, patch, Mock @@ -10,7 +11,8 @@ from mcp_agent.cli.mcp_app.api_client import MCPApp, MCPAppConfiguration, AppServerInfo from mcp_agent.cli.mcp_app.mock_client import ( MOCK_APP_CONFIG_ID, - MOCK_APP_ID, MockMCPAppClient, + MOCK_APP_ID, + MockMCPAppClient, ) @@ -68,14 +70,14 @@ def test_status_app(patched_status_app, mock_mcp_client): creatorId="creatorId", createdAt=datetime.datetime.now(), updatedAt=datetime.datetime.now(), - appServerInfo=app_server_info + appServerInfo=app_server_info, ) 
mock_mcp_client.get_app_or_config = AsyncMock(return_value=app) mock_mcp_print_server_details = Mock() with patch( - "mcp_agent.cli.cloud.commands.app.status.main.print_mcp_server_details", - side_effect=mock_mcp_print_server_details + "mcp_agent.cli.cloud.commands.app.status.main.print_mcp_server_details", + side_effect=mock_mcp_print_server_details, ) as mocked_function: mock_mcp_print_server_details.return_value = None @@ -85,7 +87,9 @@ def test_status_app(patched_status_app, mock_mcp_client): api_key=settings.API_KEY, ) - mocked_function.assert_called_once_with(server_url=server_url, api_key=settings.API_KEY) + mocked_function.assert_called_once_with( + server_url=server_url, api_key=settings.API_KEY + ) def test_status_app_config(patched_status_app, mock_mcp_client): @@ -97,14 +101,14 @@ def test_status_app_config(patched_status_app, mock_mcp_client): app_config = MCPAppConfiguration( appConfigurationId=MOCK_APP_CONFIG_ID, creatorId="creator", - appServerInfo=app_server_info + appServerInfo=app_server_info, ) mock_mcp_client.get_app_or_config = AsyncMock(return_value=app_config) mock_mcp_print_server_details = Mock() with patch( - "mcp_agent.cli.cloud.commands.app.status.main.print_mcp_server_details", - side_effect=mock_mcp_print_server_details + "mcp_agent.cli.cloud.commands.app.status.main.print_mcp_server_details", + side_effect=mock_mcp_print_server_details, ) as mocked_function: mock_mcp_print_server_details.return_value = None @@ -114,7 +118,10 @@ def test_status_app_config(patched_status_app, mock_mcp_client): api_key=settings.API_KEY, ) - mocked_function.assert_called_once_with(server_url=server_url, api_key=settings.API_KEY) + mocked_function.assert_called_once_with( + server_url=server_url, api_key=settings.API_KEY + ) + def test_missing_app_id(patched_status_app): """Test with missing app_id.""" @@ -141,8 +148,8 @@ def test_missing_api_key(patched_status_app): # Patch load_api_key_credentials to return None with patch( - "mcp_agent.cli.cloud.commands.configure.main.load_api_key_credentials", - return_value=None, + "mcp_agent.cli.cloud.commands.configure.main.load_api_key_credentials", + return_value=None, ): with pytest.raises(CLIError): patched_status_app( diff --git a/tests/cli/commands/test_app_workflows.py b/tests/cli/commands/test_app_workflows.py index 5ad850b1b..15c578225 100644 --- a/tests/cli/commands/test_app_workflows.py +++ b/tests/cli/commands/test_app_workflows.py @@ -1,4 +1,5 @@ """Tests for the configure command.""" + import datetime from unittest.mock import AsyncMock, MagicMock, patch, Mock @@ -10,7 +11,8 @@ from mcp_agent.cli.mcp_app.api_client import MCPApp, MCPAppConfiguration, AppServerInfo from mcp_agent.cli.mcp_app.mock_client import ( MOCK_APP_CONFIG_ID, - MOCK_APP_ID, MockMCPAppClient, + MOCK_APP_ID, + MockMCPAppClient, ) @@ -68,14 +70,14 @@ def test_status_app(patched_workflows_app, mock_mcp_client): creatorId="creatorId", createdAt=datetime.datetime.now(), updatedAt=datetime.datetime.now(), - appServerInfo=app_server_info + appServerInfo=app_server_info, ) mock_mcp_client.get_app_or_config = AsyncMock(return_value=app) mock_mcp_print_mcp_server_workflow_details = Mock() with patch( - "mcp_agent.cli.cloud.commands.app.workflows.main.print_mcp_server_workflow_details", - side_effect=mock_mcp_print_mcp_server_workflow_details + "mcp_agent.cli.cloud.commands.app.workflows.main.print_mcp_server_workflow_details", + side_effect=mock_mcp_print_mcp_server_workflow_details, ) as mocked_function: mock_mcp_print_mcp_server_workflow_details.return_value = 
None @@ -85,7 +87,9 @@ def test_status_app(patched_workflows_app, mock_mcp_client): api_key=settings.API_KEY, ) - mocked_function.assert_called_once_with(server_url=server_url, api_key=settings.API_KEY) + mocked_function.assert_called_once_with( + server_url=server_url, api_key=settings.API_KEY + ) def test_status_app_config(patched_workflows_app, mock_mcp_client): @@ -97,14 +101,14 @@ def test_status_app_config(patched_workflows_app, mock_mcp_client): app_config = MCPAppConfiguration( appConfigurationId=MOCK_APP_CONFIG_ID, creatorId="creator", - appServerInfo=app_server_info + appServerInfo=app_server_info, ) mock_mcp_client.get_app_or_config = AsyncMock(return_value=app_config) mock_mcp_print_mcp_server_workflow_details = Mock() with patch( - "mcp_agent.cli.cloud.commands.app.workflows.main.print_mcp_server_workflow_details", - side_effect=mock_mcp_print_mcp_server_workflow_details + "mcp_agent.cli.cloud.commands.app.workflows.main.print_mcp_server_workflow_details", + side_effect=mock_mcp_print_mcp_server_workflow_details, ) as mocked_function: mock_mcp_print_mcp_server_workflow_details.return_value = None @@ -114,7 +118,10 @@ def test_status_app_config(patched_workflows_app, mock_mcp_client): api_key=settings.API_KEY, ) - mocked_function.assert_called_once_with(server_url=server_url, api_key=settings.API_KEY) + mocked_function.assert_called_once_with( + server_url=server_url, api_key=settings.API_KEY + ) + def test_missing_app_id(patched_workflows_app): """Test with missing app_id.""" @@ -141,8 +148,8 @@ def test_missing_api_key(patched_workflows_app): # Patch load_api_key_credentials to return None with patch( - "mcp_agent.cli.cloud.commands.configure.main.load_api_key_credentials", - return_value=None, + "mcp_agent.cli.cloud.commands.configure.main.load_api_key_credentials", + return_value=None, ): with pytest.raises(CLIError): patched_workflows_app( diff --git a/tests/cli/commands/test_cli_secrets.py b/tests/cli/commands/test_cli_secrets.py index 4e32bd7f4..6df6edf43 100644 --- a/tests/cli/commands/test_cli_secrets.py +++ b/tests/cli/commands/test_cli_secrets.py @@ -431,15 +431,12 @@ def test_cli_error_handling(mock_api_credentials): # Error message should mention the file doesn't exist combined_output = result.stderr + result.stdout # remove all lines, dashes, etc - ascii_text = re.sub(r'[^A-z0-9 .,-]+', ' ', combined_output) + ascii_text = re.sub(r"[^A-z0-9 .,-]+", " ", combined_output) # remove any remnants of colour codes - without_escape_codes = re.sub(r'\[\d+m', ' ', ascii_text) + without_escape_codes = re.sub(r"\[\d+m", " ", ascii_text) # normalize spaces and convert to lower case - clean_text = ' '.join(without_escape_codes.split()).lower() - assert ( - "does not exist" in clean_text - or "no such file" in clean_text - ) + clean_text = " ".join(without_escape_codes.split()).lower() + assert "does not exist" in clean_text or "no such file" in clean_text # Test with the secret value not having a tag cmd = [ @@ -464,7 +461,9 @@ def test_cli_error_handling(mock_api_credentials): # It should mention using the tags combined_output = result.stderr + result.stdout - clean_text = ' '.join(re.sub(r'[^\x00-\x7F]+', ' ', combined_output).split()).lower() + clean_text = " ".join( + re.sub(r"[^\x00-\x7F]+", " ", combined_output).split() + ).lower() assert ( "secrets must be tagged with !developer_secret or !user_secret" in clean_text diff --git a/tests/cli/commands/test_configure.py b/tests/cli/commands/test_configure.py index 95c13a4f3..90b9be93c 100644 --- 
a/tests/cli/commands/test_configure.py +++ b/tests/cli/commands/test_configure.py @@ -7,7 +7,8 @@ from mcp_agent.cli.exceptions import CLIError from mcp_agent.cli.mcp_app.mock_client import ( MOCK_APP_CONFIG_ID, - MOCK_APP_ID, MockMCPAppClient, + MOCK_APP_ID, + MockMCPAppClient, ) diff --git a/tests/cli/commands/test_deploy_command.py b/tests/cli/commands/test_deploy_command.py index ac1bbb208..2a28b3873 100644 --- a/tests/cli/commands/test_deploy_command.py +++ b/tests/cli/commands/test_deploy_command.py @@ -62,11 +62,11 @@ def test_deploy_command_help(runner): assert result.exit_code == 0 # remove all lines, dashes, etc - ascii_text = re.sub(r'[^A-z0-9.,-]+', '', result.stdout) + ascii_text = re.sub(r"[^A-z0-9.,-]+", "", result.stdout) # remove any remnants of colour codes - without_escape_codes = re.sub(r'\[[0-9 ]+m', '', ascii_text) + without_escape_codes = re.sub(r"\[[0-9 ]+m", "", ascii_text) # normalize spaces and convert to lower case - clean_text = ' '.join(without_escape_codes.split()).lower() + clean_text = " ".join(without_escape_codes.split()).lower() # Expected options from the updated CLAUDE.md spec assert "--config-dir" in clean_text or "-c" in clean_text @@ -131,7 +131,9 @@ async def mock_process_secrets(*args, **kwargs): def test_deploy_command_no_secrets(runner, temp_config_dir): """Test deploy command with --no-secrets flag when a secrets file DOES NOT exist.""" # Run with --no-secrets flag and --dry-run to avoid real deployment - with patch("mcp_agent.cli.cloud.commands.deploy.main.wrangler_deploy") as mock_deploy: + with patch( + "mcp_agent.cli.cloud.commands.deploy.main.wrangler_deploy" + ) as mock_deploy: # Mock the wrangler deployment mock_deploy.return_value = None @@ -162,7 +164,9 @@ def test_deploy_command_no_secrets(runner, temp_config_dir): def test_deploy_command_no_secrets_with_existing_secrets(runner, temp_config_dir): """Test deploy command with --no-secrets flag when a secrets file DOES exist.""" # Run with --no-secrets flag and --dry-run to avoid real deployment - with patch("mcp_agent.cli.cloud.commands.deploy.main.wrangler_deploy") as mock_deploy: + with patch( + "mcp_agent.cli.cloud.commands.deploy.main.wrangler_deploy" + ) as mock_deploy: # Mock the wrangler deployment mock_deploy.return_value = None @@ -292,7 +296,9 @@ def test_rollback_secrets_file(temp_config_dir): pre_deploy_secrets_content = f.read() # Call deploy_config with wrangler_deploy mocked - with patch("mcp_agent.cli.cloud.commands.deploy.main.wrangler_deploy") as mock_deploy: + with patch( + "mcp_agent.cli.cloud.commands.deploy.main.wrangler_deploy" + ) as mock_deploy: # Mock wrangler_deploy to prevent actual deployment mock_deploy.side_effect = Exception("Deployment failed") diff --git a/tests/cli/utils/jwt_generator.py b/tests/cli/utils/jwt_generator.py index 82c5c4e17..6202321e0 100644 --- a/tests/cli/utils/jwt_generator.py +++ b/tests/cli/utils/jwt_generator.py @@ -206,5 +206,6 @@ def generate_test_token(): expiry_days=365, ) + if __name__ == "__main__": main() diff --git a/tests/server/test_app_server_workflow_schema.py b/tests/server/test_app_server_workflow_schema.py index 2b5f452ef..05f387127 100644 --- a/tests/server/test_app_server_workflow_schema.py +++ b/tests/server/test_app_server_workflow_schema.py @@ -1,4 +1,3 @@ -import json import pytest from types import SimpleNamespace @@ -42,7 +41,7 @@ async def run(self, q: int, flag: bool = False) -> WorkflowResult[str]: # Find the "workflows-MyWF-run" tool and inspect its parameters schema via FastMCP names = [name for name, *_ 
in mcp.decorated] - assert f"workflows-MyWF-run" in names + assert "workflows-MyWF-run" in names # We can’t call FastTool.from_function here since the tool is already created inside create_workflow_tools, # but we can at least ensure that the schema text embedded in the description JSON includes our parameters (q, flag) diff --git a/tests/server/test_tool_decorators.py b/tests/server/test_tool_decorators.py index 1a7bfcc78..f39143fad 100644 --- a/tests/server/test_tool_decorators.py +++ b/tests/server/test_tool_decorators.py @@ -130,7 +130,7 @@ async def long_task(x: int) -> str: # async aliases only (we suppress workflows-* for async auto tools) assert "long-async-run" in decorated_names - assert "long-async-get_status" in decorated_names + assert "long-get_status" in decorated_names @pytest.mark.asyncio From 874898e68cbe9d11c0314d2a83af38ea3747afdc Mon Sep 17 00:00:00 2001 From: Sarmad Qadri Date: Fri, 29 Aug 2025 11:44:40 -0400 Subject: [PATCH 03/24] mcp logging notifications --- src/mcp_agent/logging/listeners.py | 92 +++++++++++++++++++++++++- src/mcp_agent/logging/logger.py | 51 +++++++++++++- src/mcp_agent/server/app_server.py | 20 ++++++ tests/logging/test_upstream_logging.py | 79 ++++++++++++++++++++++ 4 files changed, 239 insertions(+), 3 deletions(-) create mode 100644 tests/logging/test_upstream_logging.py diff --git a/src/mcp_agent/logging/listeners.py b/src/mcp_agent/logging/listeners.py index 8d66b6862..993c2329a 100644 --- a/src/mcp_agent/logging/listeners.py +++ b/src/mcp_agent/logging/listeners.py @@ -7,11 +7,24 @@ import time from abc import ABC, abstractmethod -from typing import Dict, List +from typing import Any, Dict, List, Optional, Protocol, TYPE_CHECKING from mcp_agent.logging.events import Event, EventFilter, EventType from mcp_agent.logging.event_progress import convert_log_event +if TYPE_CHECKING: # pragma: no cover - for type checking only + from mcp.types import LoggingLevel + + +class UpstreamServerSessionProtocol(Protocol): + async def send_log_message( + self, + level: "LoggingLevel", + data: Dict[str, Any], + logger: str | None = None, + related_request_id: str | None = None, + ) -> None: ... + class EventListener(ABC): """Base async listener that processes events.""" @@ -217,3 +230,80 @@ async def flush(self): async def _process_batch(self, events: List[Event]): pass + + +class MCPUpstreamLoggingListener(FilteredListener): + """ + Sends matched log events to the connected MCP client via upstream_session + using notifications/message. + + This relies on a globally available Context (see get_current_context()) + which is established when the app initializes. If no upstream_session is + present, events are skipped. 
+ """ + + def __init__(self, event_filter: EventFilter | None = None) -> None: + super().__init__(event_filter=event_filter) + + async def handle_matched_event(self, event: Event) -> None: + # Resolve upstream session from the global/app context when available + try: + # Import inline to avoid import cycles at module load + from mcp_agent.core.context import get_current_context + except Exception: + return + + try: + ctx = get_current_context() + except Exception: + return + + upstream_session: Optional[UpstreamServerSessionProtocol] = getattr( + ctx, "upstream_session", None + ) + if upstream_session is None: + return + + # Map our EventType to MCP LoggingLevel; fold progress -> info + mcp_level_map: Dict[str, str] = { + "debug": "debug", + "info": "info", + "warning": "warning", + "error": "error", + "progress": "info", + } + # Use string type to avoid hard dependency; annotated for type checkers + mcp_level: "LoggingLevel" = mcp_level_map.get(event.type, "info") # type: ignore[assignment] + + # Build structured data payload + data: Dict[str, Any] = { + "message": event.message, + "namespace": event.namespace, + "name": event.name, + "timestamp": event.timestamp.isoformat(), + } + if event.data: + # Merge user-provided event data under 'data' + data["data"] = event.data + if event.trace_id or event.span_id: + data["trace"] = {"trace_id": event.trace_id, "span_id": event.span_id} + if event.context is not None: + try: + data["context"] = event.context.dict() + except Exception: + pass + + # Determine logger name (namespace + optional name) + logger_name: str = ( + event.namespace if not event.name else f"{event.namespace}.{event.name}" + ) + + try: + await upstream_session.send_log_message( + level=mcp_level, # type: ignore[arg-type] + data=data, + logger=logger_name, + ) + except Exception: + # Avoid raising inside listener; best-effort delivery + pass diff --git a/src/mcp_agent/logging/logger.py b/src/mcp_agent/logging/logger.py index dca6940e1..bb5522b0a 100644 --- a/src/mcp_agent/logging/logger.py +++ b/src/mcp_agent/logging/logger.py @@ -11,7 +11,7 @@ import threading import time -from typing import Any, Dict +from typing import Any, Dict, Final from contextlib import asynccontextmanager, contextmanager @@ -212,7 +212,8 @@ async def async_event_context( class LoggingConfig: """Global configuration for the logging system.""" - _initialized = False + _initialized: bool = False + _event_filter_ref: EventFilter | None = None @classmethod async def configure( @@ -237,6 +238,8 @@ async def configure( return bus = AsyncEventBus.get(transport=transport) + # Keep a reference to the provided filter so we can update at runtime + cls._event_filter_ref = event_filter # Add standard listeners if "logging" not in bus.listeners: @@ -259,6 +262,25 @@ async def configure( ), ) + # Forward logs upstream via MCP notifications if upstream_session is configured + # Avoid duplicate registration by checking existing instances, not key name. 
+ try: + from mcp_agent.logging.listeners import MCPUpstreamLoggingListener + + has_upstream_listener = any( + isinstance(listener, MCPUpstreamLoggingListener) + for listener in bus.listeners.values() + ) + if not has_upstream_listener: + MCP_UPSTREAM_LISTENER_NAME: Final[str] = "mcp_upstream" + bus.add_listener( + MCP_UPSTREAM_LISTENER_NAME, + MCPUpstreamLoggingListener(event_filter=event_filter), + ) + except Exception: + # Non-fatal if import fails + pass + await bus.start() cls._initialized = True @@ -271,6 +293,31 @@ async def shutdown(cls): await bus.stop() cls._initialized = False + @classmethod + def set_min_level(cls, level: EventType | str) -> None: + """Update the minimum logging level on the shared event filter, if available.""" + if cls._event_filter_ref is None: + return + # Normalize level + normalized = str(level).lower() + # Map synonyms to our EventType scale + mapping: Dict[str, EventType] = { + "debug": "debug", + "info": "info", + "notice": "info", + "warning": "warning", + "warn": "warning", + "error": "error", + "critical": "error", + "alert": "error", + "emergency": "error", + } + cls._event_filter_ref.min_level = mapping.get(normalized, "info") + + @classmethod + def get_event_filter(cls) -> EventFilter | None: + return cls._event_filter_ref + @classmethod @asynccontextmanager async def managed(cls, **config_kwargs): diff --git a/src/mcp_agent/server/app_server.py b/src/mcp_agent/server/app_server.py index 3751cbe1a..691cf4065 100644 --- a/src/mcp_agent/server/app_server.py +++ b/src/mcp_agent/server/app_server.py @@ -21,6 +21,7 @@ InMemoryWorkflowRegistry, ) from mcp_agent.logging.logger import get_logger +from mcp_agent.logging.logger import LoggingConfig from mcp_agent.mcp.mcp_server_registry import ServerRegistry if TYPE_CHECKING: @@ -250,6 +251,25 @@ async def app_specific_lifespan(mcp: FastMCP) -> AsyncIterator[ServerContext]: app.mcp = mcp setattr(mcp, "_mcp_agent_app", app) + # Register logging/setLevel handler so client can adjust verbosity dynamically + # This enables MCP logging capability in InitializeResult.capabilities.logging + lowlevel_server = getattr(mcp, "_mcp_server", None) + try: + if lowlevel_server is not None: + + @lowlevel_server.set_logging_level() + async def _set_level( + level: str, + ) -> None: # mcp.types.LoggingLevel is a Literal[str] + try: + LoggingConfig.set_min_level(level) + except Exception: + # Best-effort, do not crash server on invalid level + pass + except Exception: + # If handler registration fails, continue without dynamic level updates + pass + # region Workflow Tools @mcp.tool(name="workflows-list") diff --git a/tests/logging/test_upstream_logging.py b/tests/logging/test_upstream_logging.py new file mode 100644 index 000000000..96184ec99 --- /dev/null +++ b/tests/logging/test_upstream_logging.py @@ -0,0 +1,79 @@ +import asyncio +import pytest + +from types import SimpleNamespace + +from mcp_agent.logging.logger import LoggingConfig, get_logger +from mcp_agent.logging.events import EventFilter +from mcp_agent.logging.transport import AsyncEventBus + + +class DummyUpstreamSession: + def __init__(self): + self.calls = [] + + async def send_log_message(self, level, data, logger, related_request_id=None): + self.calls.append( + { + "level": level, + "data": data, + "logger": logger, + "related_request_id": related_request_id, + } + ) + + +@pytest.mark.asyncio +async def test_upstream_logging_listener_sends_notifications(monkeypatch): + # Ensure clean bus state + AsyncEventBus.reset() + + dummy_session = 
DummyUpstreamSession() + + # Monkeypatch get_current_context to return an object with upstream_session + def _fake_get_current_context(): + return SimpleNamespace(upstream_session=dummy_session) + + monkeypatch.setattr( + "mcp_agent.core.context.get_current_context", _fake_get_current_context + ) + + # Configure logging with low threshold so our event passes + await LoggingConfig.configure(event_filter=EventFilter(min_level="debug")) + + try: + logger = get_logger("tests.logging") + logger.info("hello world", name="unit", foo="bar") + + # Give the async bus a moment to process + await asyncio.sleep(0.05) + + assert len(dummy_session.calls) >= 1 + call = dummy_session.calls[-1] + assert call["level"] in ("info", "debug", "warning", "error") + assert call["logger"].startswith("tests.logging") + # Ensure our message and custom data are included + data = call["data"] + assert data.get("message") == "hello world" + assert data.get("data", {}).get("foo") == "bar" + finally: + await LoggingConfig.shutdown() + AsyncEventBus.reset() + + +@pytest.mark.asyncio +async def test_logging_capability_registered_in_fastmcp(): + # Import here to avoid heavy imports at module import time + from mcp_agent.app import MCPApp + from mcp_agent.server.app_server import create_mcp_server_for_app + import mcp.types as types + + app = MCPApp(name="test_app") + mcp = create_mcp_server_for_app(app) + + low = getattr(mcp, "_mcp_server", None) + assert low is not None + + # The presence of a SetLevelRequest handler indicates logging capability will be advertised + assert types.SetLevelRequest in low.request_handlers + From 27513ad9dfa205baba03603356a1a25c22a11502 Mon Sep 17 00:00:00 2001 From: Sarmad Qadri Date: Fri, 29 Aug 2025 13:06:40 -0400 Subject: [PATCH 04/24] Fixes for logger --- src/mcp_agent/app.py | 12 +++++++++ src/mcp_agent/logging/listeners.py | 42 +++++++++++++++++++++++------- src/mcp_agent/logging/logger.py | 17 ++++++++++++ 3 files changed, 61 insertions(+), 10 deletions(-) diff --git a/src/mcp_agent/app.py b/src/mcp_agent/app.py index c4fb8f6c2..d05a01d60 100644 --- a/src/mcp_agent/app.py +++ b/src/mcp_agent/app.py @@ -624,6 +624,12 @@ def decorator(fn: Callable[..., Any]) -> Callable[..., Any]: return fn + # Support bare usage: @app.tool without parentheses + if callable(name) and description is None and structured_output is None: + fn = name # type: ignore[assignment] + name = None + return decorator(fn) # type: ignore[arg-type] + return decorator def async_tool( @@ -661,6 +667,12 @@ def decorator(fn: Callable[..., Any]) -> Callable[..., Any]: ) return fn + # Support bare usage: @app.async_tool without parentheses + if callable(name) and description is None: + fn = name # type: ignore[assignment] + name = None + return decorator(fn) # type: ignore[arg-type] + return decorator def workflow_task( diff --git a/src/mcp_agent/logging/listeners.py b/src/mcp_agent/logging/listeners.py index 993c2329a..4099350e9 100644 --- a/src/mcp_agent/logging/listeners.py +++ b/src/mcp_agent/logging/listeners.py @@ -247,20 +247,40 @@ def __init__(self, event_filter: EventFilter | None = None) -> None: async def handle_matched_event(self, event: Event) -> None: # Resolve upstream session from the global/app context when available + upstream_session: Optional[UpstreamServerSessionProtocol] = None try: # Import inline to avoid import cycles at module load from mcp_agent.core.context import get_current_context - except Exception: - return - try: - ctx = get_current_context() + try: + ctx = get_current_context() + 
upstream_session = getattr(ctx, "upstream_session", None) + except Exception: + upstream_session = None except Exception: - return + upstream_session = None + + # First fallback: Event may carry upstream_session injected by logger + if upstream_session is None: + candidate = getattr(event, "upstream_session", None) + if candidate is not None: + upstream_session = candidate # type: ignore[assignment] + + # Second fallback: if within an MCP request handling context, use the low-level server RequestContext + if upstream_session is None: + try: + from mcp.server.lowlevel.server import ( + request_ctx as _lowlevel_request_ctx, + ) + + try: + req_ctx = _lowlevel_request_ctx.get() + upstream_session = getattr(req_ctx, "session", None) + except LookupError: + upstream_session = None + except Exception: + upstream_session = None - upstream_session: Optional[UpstreamServerSessionProtocol] = getattr( - ctx, "upstream_session", None - ) if upstream_session is None: return @@ -304,6 +324,8 @@ async def handle_matched_event(self, event: Event) -> None: data=data, logger=logger_name, ) - except Exception: + except Exception as e: # Avoid raising inside listener; best-effort delivery - pass + import sys + + print(f"[mcp_agent] upstream log send failed: {e}", file=sys.stderr) diff --git a/src/mcp_agent/logging/logger.py b/src/mcp_agent/logging/logger.py index bb5522b0a..309ea46fa 100644 --- a/src/mcp_agent/logging/logger.py +++ b/src/mcp_agent/logging/logger.py @@ -92,6 +92,22 @@ def event( elif context.session_id is None: context.session_id = self.session_id + # Attach upstream_session from active MCP request context if available so + # listeners can forward logs upstream even from background tasks + extra_event_fields: Dict[str, Any] = {} + try: + from mcp.server.lowlevel.server import request_ctx as _lowlevel_request_ctx # type: ignore + + try: + req_ctx = _lowlevel_request_ctx.get() + extra_event_fields["upstream_session"] = getattr( + req_ctx, "session", None + ) + except LookupError: + pass + except Exception: + pass + evt = Event( type=etype, name=ename, @@ -99,6 +115,7 @@ def event( message=message, context=context, data=data, + **extra_event_fields, ) self._emit_event(evt) From 29893d1832f0358048f2b748f9d5a131ab476806 Mon Sep 17 00:00:00 2001 From: Sarmad Qadri Date: Sat, 30 Aug 2025 17:48:33 -0400 Subject: [PATCH 05/24] checkpoint --- examples/mcp_agent_server/asyncio/README.md | 134 +++--- .../asyncio/basic_agent_server.py | 183 +++++--- examples/mcp_agent_server/asyncio/client.py | 143 ++++++- src/mcp_agent/app.py | 184 +++++++- src/mcp_agent/executor/workflow.py | 4 +- src/mcp_agent/logging/events.py | 4 + src/mcp_agent/logging/listeners.py | 39 +- src/mcp_agent/logging/logger.py | 100 ++++- src/mcp_agent/logging/transport.py | 10 +- src/mcp_agent/server/app_server.py | 396 ++++++++++++++---- 10 files changed, 947 insertions(+), 250 deletions(-) diff --git a/examples/mcp_agent_server/asyncio/README.md b/examples/mcp_agent_server/asyncio/README.md index 8b5b59965..1de776c3a 100644 --- a/examples/mcp_agent_server/asyncio/README.md +++ b/examples/mcp_agent_server/asyncio/README.md @@ -55,17 +55,18 @@ Before running the example, you'll need to configure the necessary paths and API 1. Copy the example secrets file: - ```bash - cp mcp_agent.secrets.yaml.example mcp_agent.secrets.yaml - ``` +``` +cp mcp_agent.secrets.yaml.example mcp_agent.secrets.yaml +``` 2. 
Edit `mcp_agent.secrets.yaml` to add your API keys: - ```yaml - anthropic: - api_key: "your-anthropic-api-key" - openai: - api_key: "your-openai-api-key" - ``` + +``` +anthropic: + api_key: "your-anthropic-api-key" +openai: + api_key: "your-openai-api-key" +``` ## How to Run @@ -73,7 +74,7 @@ Before running the example, you'll need to configure the necessary paths and API The simplest way to run the example is using the provided client script: -```bash +``` # Make sure you're in the mcp_agent_server/asyncio directory uv run client.py ``` @@ -91,21 +92,52 @@ You can also run the server and client separately: 1. In one terminal, start the server: - ```bash - uv run basic_agent_server.py +``` +uv run basic_agent_server.py - # Optionally, run with the example custom FastMCP settings - uv run basic_agent_server.py --custom-fastmcp-settings - ``` +# Optionally, run with the example custom FastMCP settings +uv run basic_agent_server.py --custom-fastmcp-settings +``` 2. In another terminal, run the client: - ```bash - uv run client.py +``` +uv run client.py + +# Optionally, run with the example custom FastMCP settings +uv run client.py --custom-fastmcp-settings +``` + +## Receiving Server Logs in the Client - # Optionally, run with the example custom FastMCP settings - uv run client.py --custom-fastmcp-settings - ``` +The server advertises the `logging` capability (via `logging/setLevel`) and forwards its structured logs upstream using `notifications/message`. To receive these logs in a client session, pass a `logging_callback` when constructing the client session and set the desired level: + +```python +from datetime import timedelta +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from mcp import ClientSession +from mcp.types import LoggingMessageNotificationParams +from mcp_agent.mcp.mcp_agent_client_session import MCPAgentClientSession + +async def on_server_log(params: LoggingMessageNotificationParams) -> None: + print(f"[SERVER LOG] [{params.level.upper()}] [{params.logger}] {params.data}") + +def make_session(read_stream: MemoryObjectReceiveStream, + write_stream: MemoryObjectSendStream, + read_timeout_seconds: timedelta | None) -> ClientSession: + return MCPAgentClientSession( + read_stream=read_stream, + write_stream=write_stream, + read_timeout_seconds=read_timeout_seconds, + logging_callback=on_server_log, + ) + +# Later, when connecting via gen_client(..., client_session_factory=make_session) +# you can request the minimum server log level: +# await server.set_logging_level("info") +``` + +The example client (`client.py`) demonstrates this end-to-end: it registers a logging callback and calls `set_logging_level("info")` so logs from the server appear in the client's console. ## MCP Clients @@ -116,7 +148,7 @@ like any other MCP server. You can inspect and test the server using [MCP Inspector](https://github.com/modelcontextprotocol/inspector): -```bash +``` npx @modelcontextprotocol/inspector \ uv \ --directory /path/to/mcp-agent/examples/mcp_agent_server/asyncio \ @@ -138,41 +170,41 @@ To use this server with Claude Desktop: 2. 
Add a new server configuration: - ```json - "basic-agent-server": { - "command": "/path/to/uv", - "args": [ - "--directory", - "/path/to/mcp-agent/examples/mcp_agent_server/asyncio", - "run", - "basic_agent_server.py" - ] - } - ``` +```json +"basic-agent-server": { + "command": "/path/to/uv", + "args": [ + "--directory", + "/path/to/mcp-agent/examples/mcp_agent_server/asyncio", + "run", + "basic_agent_server.py" + ] +} +``` 3. Restart Claude Desktop, and you'll see the server available in the tool drawer 4. (**claude desktop workaround**) Update `mcp_agent.config.yaml` file with the full paths to npx/uvx on your system: - Find the full paths to `uvx` and `npx` on your system: - - ```bash - which uvx - which npx - ``` - - Update the `mcp_agent.config.yaml` file with these paths: - - ```yaml - mcp: - servers: - fetch: - command: "/full/path/to/uvx" # Replace with your path - args: ["mcp-server-fetch"] - filesystem: - command: "/full/path/to/npx" # Replace with your path - args: ["-y", "@modelcontextprotocol/server-filesystem"] - ``` +Find the full paths to `uvx` and `npx` on your system: + +``` +which uvx +which npx +``` + +Update the `mcp_agent.config.yaml` file with these paths: + +```yaml +mcp: + servers: + fetch: + command: "/full/path/to/uvx" # Replace with your path + args: ["mcp-server-fetch"] + filesystem: + command: "/full/path/to/npx" # Replace with your path + args: ["-y", "@modelcontextprotocol/server-filesystem"] +``` ## Code Structure diff --git a/examples/mcp_agent_server/asyncio/basic_agent_server.py b/examples/mcp_agent_server/asyncio/basic_agent_server.py index f54f171d9..9a8476005 100644 --- a/examples/mcp_agent_server/asyncio/basic_agent_server.py +++ b/examples/mcp_agent_server/asyncio/basic_agent_server.py @@ -11,9 +11,10 @@ import asyncio import os import logging -from typing import Dict, Any +from typing import Dict, Any, Optional -from mcp.server.fastmcp import FastMCP +from mcp.server.fastmcp import FastMCP, Context as MCPContext +from mcp_agent.core.context import Context as AppContext from mcp_agent.app import MCPApp from mcp_agent.server.app_server import create_mcp_server_for_app @@ -26,15 +27,12 @@ from mcp_agent.executor.workflow import Workflow, WorkflowResult from mcp_agent.tracing.token_counter import TokenNode -# Initialize logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - # Note: This is purely optional: # if not provided, a default FastMCP server will be created by MCPApp using create_mcp_server_for_app() mcp = FastMCP(name="basic_agent_server", description="My basic agent server example.") -# Define the MCPApp instance +# Define the MCPApp instance. The server created for this app will advertise the +# MCP logging capability and forward structured logs upstream to connected clients. app = MCPApp( name="basic_agent_server", description="Basic agent server example", @@ -112,71 +110,146 @@ async def run(self, input: str) -> WorkflowResult[str]: return WorkflowResult(value=result) -@app.workflow -class ParallelWorkflow(Workflow[str]): +@app.tool +async def grade_story(story: str, app_ctx: Optional[AppContext] = None) -> str: """ - This workflow can be used to grade a student's short story submission and generate a report. + This tool can be used to grade a student's short story submission and generate a report. It uses multiple agents to perform different tasks in parallel. The agents include: - Proofreader: Reviews the story for grammar, spelling, and punctuation errors. 
- Fact Checker: Verifies the factual consistency within the story. - Style Enforcer: Analyzes the story for adherence to style guidelines. - Grader: Compiles the feedback from the other agents into a structured report. + + Args: + story: The student's short story to grade + app_ctx: Optional MCPApp context for accessing app resources and logging """ + # Use the context's app if available for proper logging with upstream_session + _app = app_ctx.app if app_ctx else app + # Ensure the app's logger is bound to the current context with upstream_session + if _app._logger and hasattr(_app._logger, '_bound_context'): + _app._logger._bound_context = app_ctx + logger = _app.logger + logger.info(f"grade_story: Received input: {story}") + + proofreader = Agent( + name="proofreader", + instruction=""""Review the short story for grammar, spelling, and punctuation errors. + Identify any awkward phrasing or structural issues that could improve clarity. + Provide detailed feedback on corrections.""", + ) - @app.workflow_run - async def run(self, input: str) -> WorkflowResult[str]: - """ - Run the workflow, processing the input data. + fact_checker = Agent( + name="fact_checker", + instruction="""Verify the factual consistency within the story. Identify any contradictions, + logical inconsistencies, or inaccuracies in the plot, character actions, or setting. + Highlight potential issues with reasoning or coherence.""", + ) - Args: - input_data: The data to process + style_enforcer = Agent( + name="style_enforcer", + instruction="""Analyze the story for adherence to style guidelines. + Evaluate the narrative flow, clarity of expression, and tone. Suggest improvements to + enhance storytelling, readability, and engagement.""", + ) - Returns: - A WorkflowResult containing the processed data - """ + grader = Agent( + name="grader", + instruction="""Compile the feedback from the Proofreader, Fact Checker, and Style Enforcer + into a structured report. Summarize key issues and categorize them by type. + Provide actionable recommendations for improving the story, + and give an overall grade based on the feedback.""", + ) - proofreader = Agent( - name="proofreader", - instruction=""""Review the short story for grammar, spelling, and punctuation errors. - Identify any awkward phrasing or structural issues that could improve clarity. - Provide detailed feedback on corrections.""", - ) + parallel = ParallelLLM( + fan_in_agent=grader, + fan_out_agents=[proofreader, fact_checker, style_enforcer], + llm_factory=OpenAIAugmentedLLM, + context=app_ctx if app_ctx else app.context, + ) - fact_checker = Agent( - name="fact_checker", - instruction="""Verify the factual consistency within the story. Identify any contradictions, - logical inconsistencies, or inaccuracies in the plot, character actions, or setting. - Highlight potential issues with reasoning or coherence.""", + try: + result = await parallel.generate_str( + message=f"Student short story submission: {story}", ) + except Exception as e: + logger.error(f"grade_story: Error generating result: {e}") + return None - style_enforcer = Agent( - name="style_enforcer", - instruction="""Analyze the story for adherence to style guidelines. - Evaluate the narrative flow, clarity of expression, and tone. 
Suggest improvements to - enhance storytelling, readability, and engagement.""", - ) + if not result: + logger.error("grade_story: No result from parallel LLM") + else: + logger.info(f"grade_story: Result: {result}") - grader = Agent( - name="grader", - instruction="""Compile the feedback from the Proofreader, Fact Checker, and Style Enforcer - into a structured report. Summarize key issues and categorize them by type. - Provide actionable recommendations for improving the story, - and give an overall grade based on the feedback.""", - ) + return result - parallel = ParallelLLM( - fan_in_agent=grader, - fan_out_agents=[proofreader, fact_checker, style_enforcer], - llm_factory=OpenAIAugmentedLLM, - context=app.context, - ) +@app.async_tool(name="grade_story_async") +async def grade_story_async(story: str, app_ctx: Optional[AppContext] = None) -> str: + """ + Async variant of grade_story that starts a workflow run and returns IDs. + + Args: + story: The student's short story to grade + app_ctx: Optional MCPApp context for accessing app resources and logging + """ + + # Use the context's app if available for proper logging with upstream_session + _app = app_ctx.app if app_ctx else app + # Ensure the app's logger is bound to the current context with upstream_session + if _app._logger and hasattr(_app._logger, '_bound_context'): + _app._logger._bound_context = app_ctx + logger = _app.logger + logger.info(f"grade_story_async: Received input: {story}") + + proofreader = Agent( + name="proofreader", + instruction="""Review the short story for grammar, spelling, and punctuation errors. + Identify any awkward phrasing or structural issues that could improve clarity. + Provide detailed feedback on corrections.""", + ) + + fact_checker = Agent( + name="fact_checker", + instruction="""Verify the factual consistency within the story. Identify any contradictions, + logical inconsistencies, or inaccuracies in the plot, character actions, or setting. + Highlight potential issues with reasoning or coherence.""", + ) + + style_enforcer = Agent( + name="style_enforcer", + instruction="""Analyze the story for adherence to style guidelines. + Evaluate the narrative flow, clarity of expression, and tone. Suggest improvements to + enhance storytelling, readability, and engagement.""", + ) + + grader = Agent( + name="grader", + instruction="""Compile the feedback from the Proofreader, Fact Checker, and Style Enforcer + into a structured report. Summarize key issues and categorize them by type. 
+ Provide actionable recommendations for improving the story, + and give an overall grade based on the feedback.""", + ) + + parallel = ParallelLLM( + fan_in_agent=grader, + fan_out_agents=[proofreader, fact_checker, style_enforcer], + llm_factory=OpenAIAugmentedLLM, + context=app_ctx if app_ctx else app.context, + ) + + logger.info("grade_story_async: Starting parallel LLM") + + try: result = await parallel.generate_str( - message=f"Student short story submission: {input}", + message=f"Student short story submission: {story}", ) + except Exception as e: + logger.error(f"grade_story_async: Error generating result: {e}") + return None - return WorkflowResult(value=result) + return result # Add custom tool to get token usage for a workflow @@ -313,11 +386,11 @@ async def main(): context.config.mcp.servers["filesystem"].args.extend([os.getcwd()]) # Log registered workflows and agent configurations - logger.info(f"Creating MCP server for {agent_app.name}") + agent_app.logger.info(f"Creating MCP server for {agent_app.name}") - logger.info("Registered workflows:") + agent_app.logger.info("Registered workflows:") for workflow_id in agent_app.workflows: - logger.info(f" - {workflow_id}") + agent_app.logger.info(f" - {workflow_id}") # Create the MCP server that exposes both workflows and agent configurations, # optionally using custom FastMCP settings @@ -327,7 +400,7 @@ async def main(): else None ) mcp_server = create_mcp_server_for_app(agent_app, **(fast_mcp_settings or {})) - logger.info(f"MCP Server settings: {mcp_server.settings}") + agent_app.logger.info(f"MCP Server settings: {mcp_server.settings}") # Run the server await mcp_server.run_stdio_async() diff --git a/examples/mcp_agent_server/asyncio/client.py b/examples/mcp_agent_server/asyncio/client.py index 27ea442ed..1cbfafc27 100644 --- a/examples/mcp_agent_server/asyncio/client.py +++ b/examples/mcp_agent_server/asyncio/client.py @@ -2,11 +2,15 @@ import asyncio import json import time -from mcp.types import CallToolResult +from datetime import timedelta +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from mcp import ClientSession +from mcp.types import CallToolResult, LoggingMessageNotificationParams from mcp_agent.app import MCPApp from mcp_agent.config import MCPServerSettings from mcp_agent.executor.workflow import WorkflowExecution from mcp_agent.mcp.gen_client import gen_client +from mcp_agent.mcp.mcp_agent_client_session import MCPAgentClientSession from rich import print @@ -18,6 +22,12 @@ async def main(): action="store_true", help="Enable custom FastMCP settings for the server", ) + parser.add_argument( + "--server-log-level", + type=str, + default=None, + help="Set initial server logging level (debug, info, notice, warning, error, critical, alert, emergency)", + ) args = parser.parse_args() use_custom_fastmcp_settings = args.custom_fastmcp_settings @@ -45,7 +55,41 @@ async def main(): ) # Connect to the workflow server - async with gen_client("basic_agent_server", context.server_registry) as server: + # Define a logging callback to receive server-side log notifications + async def on_server_log(params: LoggingMessageNotificationParams) -> None: + # Pretty-print server logs locally for demonstration + level = params.level.upper() + name = params.logger or "server" + # params.data can be any JSON-serializable data + print(f"[SERVER LOG] [{level}] [{name}] {params.data}") + + # Provide a client session factory that installs our logging callback + def make_session( + read_stream: 
MemoryObjectReceiveStream, + write_stream: MemoryObjectSendStream, + read_timeout_seconds: timedelta | None, + ) -> ClientSession: + return MCPAgentClientSession( + read_stream=read_stream, + write_stream=write_stream, + read_timeout_seconds=read_timeout_seconds, + logging_callback=on_server_log, + ) + + async with gen_client( + "basic_agent_server", + context.server_registry, + client_session_factory=make_session, + ) as server: + # Ask server to send logs at the requested level (default info) + level = (args.server_log_level or "info").lower() + print(f"[client] Setting server logging level to: {level}") + try: + await server.set_logging_level(level) + except Exception: + # Older servers may not support logging capability + print("[client] Server does not support logging/setLevel") + # List available tools tools_result = await server.list_tools() logger.info( @@ -61,7 +105,7 @@ async def main(): data=_tool_result_to_json(workflows_response) or workflows_response, ) - # Call the BasicAgentWorkflow + # Call the BasicAgentWorkflow (run + status) run_result = await server.call_tool( "workflows-BasicAgentWorkflow-run", arguments={ @@ -71,7 +115,22 @@ async def main(): }, ) - execution = WorkflowExecution(**json.loads(run_result.content[0].text)) + # Tolerant parsing of run IDs from tool result + run_payload = _tool_result_to_json(run_result) + if not run_payload: + sc = getattr(run_result, "structuredContent", None) + if isinstance(sc, dict): + run_payload = sc.get("result") or sc + if not run_payload: + # Last resort: parse unstructured content if present and non-empty + if getattr(run_result, "content", None) and run_result.content[0].text: + run_payload = json.loads(run_result.content[0].text) + else: + raise RuntimeError( + "Unable to extract workflow run IDs from tool result" + ) + + execution = WorkflowExecution(**run_payload) run_id = execution.run_id logger.info( f"Started BasicAgentWorkflow-run. workflow ID={execution.workflow_id}, run ID={run_id}" @@ -84,7 +143,12 @@ async def main(): arguments={"run_id": run_id}, ) + # Tolerant parsing of get_status result workflow_status = _tool_result_to_json(get_status_result) + if workflow_status is None: + sc = getattr(get_status_result, "structuredContent", None) + if isinstance(sc, dict): + workflow_status = sc.get("result") or sc if workflow_status is None: logger.error( f"Failed to parse workflow status response: {get_status_result}" @@ -108,7 +172,6 @@ async def main(): f"Workflow run {run_id} completed successfully! 
Result:", data=workflow_status.get("result"), ) - break elif workflow_status.get("status") == "error": logger.error( @@ -135,12 +198,6 @@ async def main(): await asyncio.sleep(5) - # TODO: UNCOMMENT ME to try out cancellation: - # await server.call_tool( - # "workflows-cancel", - # arguments={"workflow_id": "BasicAgentWorkflow", "run_id": run_id}, - # ) - # Get the token usage summary logger.info("Fetching token usage summary...") token_usage_result = await server.call_tool( @@ -159,6 +216,70 @@ async def main(): # Display the token usage summary print(token_usage_result.structuredContent) + # Call the sync tool 'grade_story' separately (no run/status loop) + try: + grade_result = await server.call_tool( + "grade_story", + arguments={"story": "This is a test story."}, + ) + grade_payload = _tool_result_to_json(grade_result) or ( + ( + grade_result.structuredContent.get("result") + if getattr(grade_result, "structuredContent", None) + else None + ) + or (grade_result.content[0].text if grade_result.content else None) + ) + logger.info("grade_story result:", data=grade_payload) + except Exception as e: + logger.error("grade_story call failed", data=str(e)) + + # Call the async tool 'grade_story_async': start then poll status + try: + async_run_result = await server.call_tool( + "grade_story_async-async-run", + arguments={"run_parameters": {"story": "This is a test story."}}, + ) + async_ids = ( + (getattr(async_run_result, "structuredContent", {}) or {}).get( + "result" + ) + or _tool_result_to_json(async_run_result) + or json.loads(async_run_result.content[0].text) + ) + async_run_id = async_ids["run_id"] + logger.info( + f"Started grade_story_async. run ID={async_run_id}", + ) + + # Poll status until completion + while True: + async_status = await server.call_tool( + "grade_story_async-get_status", + arguments={"run_id": async_run_id}, + ) + async_status_json = ( + getattr(async_status, "structuredContent", {}) or {} + ).get("result") or _tool_result_to_json(async_status) + if async_status_json is None: + logger.error( + "grade_story_async: failed to parse status", + data=async_status, + ) + break + logger.info("grade_story_async status:", data=async_status_json) + if async_status_json.get("status") in ( + "completed", + "error", + "cancelled", + ): + break + await asyncio.sleep(2) + except Exception as e: + logger.error("grade_story_async call failed", data=str(e)) + + await asyncio.sleep(5) + def _tool_result_to_json(tool_result: CallToolResult): if tool_result.content and len(tool_result.content) > 0: diff --git a/src/mcp_agent/app.py b/src/mcp_agent/app.py index d05a01d60..d9b059705 100644 --- a/src/mcp_agent/app.py +++ b/src/mcp_agent/app.py @@ -189,7 +189,13 @@ def session_id(self): def logger(self): if self._logger is None: session_id = self._context.session_id if self._context else None - self._logger = get_logger(f"mcp_agent.{self.name}", session_id=session_id) + self._logger = get_logger( + f"mcp_agent.{self.name}", session_id=session_id, context=self._context + ) + else: + # Update the logger's bound context in case upstream_session was set after logger creation + if self._context and hasattr(self._logger, '_bound_context'): + self._logger._bound_context = self._context return self._logger async def initialize(self): @@ -541,11 +547,112 @@ def _create_workflow_from_function( """ import asyncio as _asyncio + import inspect as _inspect from mcp_agent.executor.workflow import Workflow as _Workflow - async def _invoke_target(*args, **kwargs): + # Determine if the function requests a 
FastMCP Context param; if so, shim it out + _has_ctx_param = False + try: + sig = _inspect.signature(fn) + params = list(sig.parameters.values()) + # A param is considered a context param if it's annotated as FastMCP Context + try: + from mcp.server.fastmcp import Context as _Ctx # type: ignore + except Exception: + _Ctx = None # type: ignore + for p in params: + if ( + p.annotation is not _inspect._empty + and _Ctx is not None + and p.annotation is _Ctx + ): + _has_ctx_param = True + break + if p.name == "ctx": + _has_ctx_param = True + break + except Exception: + _has_ctx_param = False + + async def _invoke_target(workflow_self, *args, **kwargs): + # The workflow_self is the fresh instance created by Workflow.create() # Support both async and sync callables - res = fn(*args, **kwargs) + call_kwargs = dict(kwargs) + + # Check if the function expects an app context parameter + # Look for parameters named 'app_ctx' or with mcp_agent.core.context.Context type annotation + import inspect as _inspect + + sig = _inspect.signature(fn) + app_context_param_name = None + + for param_name, param in sig.parameters.items(): + # Check if parameter is named app_ctx + if param_name == "app_ctx": + app_context_param_name = param_name + break + # Check if parameter has mcp_agent Context type annotation (not FastMCP Context) + if param.annotation != _inspect.Parameter.empty: + annotation_str = str(param.annotation) + # Look for mcp_agent.core.context.Context specifically + if "mcp_agent.core.context.Context" in annotation_str or ( + "Context" in annotation_str and param_name == "app_ctx" + ): + app_context_param_name = param_name + break + + # If function expects app context, provide the workflow's context + if app_context_param_name and workflow_self._context: + import sys + + print( + f"[DEBUG] _invoke_target: Passing workflow context to parameter '{app_context_param_name}'", + file=sys.stderr, + ) + print( + f"[DEBUG] _invoke_target: Context has upstream_session: {workflow_self._context.upstream_session is not None}", + file=sys.stderr, + ) + print( + f"[DEBUG] _invoke_target: Context.app: {workflow_self._context.app}", + file=sys.stderr, + ) + if workflow_self._context.app: + print( + f"[DEBUG] _invoke_target: App.name: {workflow_self._context.app.name}", + file=sys.stderr, + ) + print( + f"[DEBUG] _invoke_target: App.context.upstream_session: {workflow_self._context.app.context.upstream_session is not None}", + file=sys.stderr, + ) + call_kwargs[app_context_param_name] = workflow_self._context + + # Handle FastMCP ctx parameter separately (for backward compatibility) + if _has_ctx_param and "ctx" not in call_kwargs: + # FastMCP ctx parameter (set to None during workflow run) + call_kwargs["ctx"] = None + + # Emit a high-signal log from inside AutoWorkflow before calling the target + try: + _app_for_log = getattr(workflow_self, "_context", None) + _app_for_log = getattr(_app_for_log, "app", None) + if _app_for_log and getattr(_app_for_log, "logger", None): + _app_for_log.logger.info( + f"AutoWorkflow[{workflow_name}]: invoking tool function", + data={ + "has_upstream_session": bool( + _app_for_log.context.upstream_session + ) + if getattr(_app_for_log, "context", None) + else None + }, + ) + except Exception: + pass + + res = fn(*args, **call_kwargs) + if _asyncio.iscoroutine(res): res = await res @@ -562,16 +669,83 @@ async def _invoke_target(*args, **kwargs): return res async def _run(self, *args, **kwargs): # type: ignore[no-redef] - return await _invoke_target(*args, **kwargs) + # The 'self' here is 
the fresh workflow instance created by Workflow.create() + # For AutoWorkflow to behave like BasicAgentWorkflow, we need to ensure + # the workflow instance's context is properly set up + + # Debug: log the workflow instance and its context + import sys + + print( + f"[DEBUG AutoWorkflow] Running AutoWorkflow_{workflow_name}", + file=sys.stderr, + ) + print( + f"[DEBUG AutoWorkflow] self._context: {self._context}", file=sys.stderr + ) + if self._context: + print( + f"[DEBUG AutoWorkflow] Context has upstream_session: {self._context.upstream_session is not None}", + file=sys.stderr, + ) + print( + f"[DEBUG AutoWorkflow] Context.app: {self._context.app}", + file=sys.stderr, + ) + + # Emit a high-signal log from AutoWorkflow.run start + try: + _app_for_log = getattr(self, "_context", None) + _app_for_log = getattr(_app_for_log, "app", None) + if _app_for_log and getattr(_app_for_log, "logger", None): + _app_for_log.logger.info( + f"AutoWorkflow[{workflow_name}]: run starting", + data={ + "has_upstream_session": bool( + _app_for_log.context.upstream_session + ) + if getattr(_app_for_log, "context", None) + else None + }, + ) + except Exception: + pass + + # Now invoke the original function with the workflow instance + return await _invoke_target(self, *args, **kwargs) # Decorate run with engine-specific decorator decorated_run = self.workflow_run(_run) # Build the Workflow subclass dynamically + # Build a param-source proxy for schema generation that removes any FastMCP Context parameter + def _make_param_source_proxy(original): + try: + sig = _inspect.signature(original) + params = [p for p in sig.parameters.values() if p.name != "ctx"] + proxy_params = params + + def _proxy(*args, **kwargs): + return None + + # Copy annotations sans ctx + ann = dict(getattr(original, "__annotations__", {})) + if "ctx" in ann: + ann.pop("ctx", None) + _proxy.__annotations__ = ann + _proxy.__signature__ = _inspect.Signature( + parameters=proxy_params, return_annotation=sig.return_annotation + ) + return _proxy + except Exception: + return original + + param_source_proxy = _make_param_source_proxy(fn) + cls_dict: Dict[str, Any] = { "__doc__": description or (fn.__doc__ or ""), "run": decorated_run, - "__mcp_agent_param_source_fn__": fn, + "__mcp_agent_param_source_fn__": param_source_proxy, } if mark_sync_tool: cls_dict["__mcp_agent_sync_tool__"] = True diff --git a/src/mcp_agent/executor/workflow.py b/src/mcp_agent/executor/workflow.py index a7b2c6ba0..b3a384d3c 100644 --- a/src/mcp_agent/executor/workflow.py +++ b/src/mcp_agent/executor/workflow.py @@ -94,7 +94,9 @@ def __init__( ContextDependent.__init__(self, context=context) self.name = name or self.__class__.__name__ - self._logger = get_logger(f"workflow.{self.name}") + # Bind workflow logger to the provided context so events can carry + # the current upstream_session even when emitted from background tasks. + self._logger = get_logger(f"workflow.{self.name}", context=context) self._initialized = False self._workflow_id = None # Will be set during run_async self._run_id = None # Will be set during run_async diff --git a/src/mcp_agent/logging/events.py b/src/mcp_agent/logging/events.py index 3f934d427..1244d7e54 100644 --- a/src/mcp_agent/logging/events.py +++ b/src/mcp_agent/logging/events.py @@ -50,6 +50,10 @@ class Event(BaseModel): data: Dict[str, Any] = Field(default_factory=dict) context: EventContext | None = None + # Runtime-only handle for upstream forwarding. Present for listeners to + # use, explicitly excluded from any serialization/dumps. 
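+    # (With pydantic's Field(exclude=True) below, the handle is dropped from
+    # model_dump()/model_dump_json(), so a live session object is never
+    # serialized into emitted logs or persisted workflow state.)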
+ upstream_session: Any | None = Field(default=None, exclude=True) + # For distributed tracing span_id: str | None = None trace_id: str | None = None diff --git a/src/mcp_agent/logging/listeners.py b/src/mcp_agent/logging/listeners.py index 4099350e9..2b47621a5 100644 --- a/src/mcp_agent/logging/listeners.py +++ b/src/mcp_agent/logging/listeners.py @@ -246,27 +246,17 @@ def __init__(self, event_filter: EventFilter | None = None) -> None: super().__init__(event_filter=event_filter) async def handle_matched_event(self, event: Event) -> None: - # Resolve upstream session from the global/app context when available + # Prefer an upstream session bound directly to this event (most precise) upstream_session: Optional[UpstreamServerSessionProtocol] = None - try: - # Import inline to avoid import cycles at module load - from mcp_agent.core.context import get_current_context - - try: - ctx = get_current_context() - upstream_session = getattr(ctx, "upstream_session", None) - except Exception: - upstream_session = None - except Exception: - upstream_session = None + resolved_via: str | None = None + candidate = getattr(event, "upstream_session", None) + if candidate is not None: + upstream_session = candidate # type: ignore[assignment] + resolved_via = "event" - # First fallback: Event may carry upstream_session injected by logger - if upstream_session is None: - candidate = getattr(event, "upstream_session", None) - if candidate is not None: - upstream_session = candidate # type: ignore[assignment] + # No fallback to global get_current_context() to avoid unsafe globals - # Second fallback: if within an MCP request handling context, use the low-level server RequestContext + # Finally, try the low-level request contextvar if in a request handler if upstream_session is None: try: from mcp.server.lowlevel.server import ( @@ -276,12 +266,15 @@ async def handle_matched_event(self, event: Event) -> None: try: req_ctx = _lowlevel_request_ctx.get() upstream_session = getattr(req_ctx, "session", None) + if upstream_session is not None: + resolved_via = "request_ctx" except LookupError: upstream_session = None except Exception: upstream_session = None if upstream_session is None: + # No upstream_session available, event cannot be forwarded return # Map our EventType to MCP LoggingLevel; fold progress -> info @@ -324,6 +317,16 @@ async def handle_matched_event(self, event: Event) -> None: data=data, logger=logger_name, ) + # Diagnostic path for success + try: + import sys + + print( + f"[mcp_agent] forwarded event via {resolved_via or 'unknown'} (namespace={event.namespace}, name={event.name})", + file=sys.stderr, + ) + except Exception: + pass except Exception as e: # Avoid raising inside listener; best-effort delivery import sys diff --git a/src/mcp_agent/logging/logger.py b/src/mcp_agent/logging/logger.py index 309ea46fa..c16c7346b 100644 --- a/src/mcp_agent/logging/logger.py +++ b/src/mcp_agent/logging/logger.py @@ -15,7 +15,12 @@ from contextlib import asynccontextmanager, contextmanager -from mcp_agent.logging.events import Event, EventContext, EventFilter, EventType +from mcp_agent.logging.events import ( + Event, + EventContext, + EventFilter, + EventType, +) from mcp_agent.logging.listeners import ( BatchingListener, LoggingListener, @@ -31,10 +36,16 @@ class Logger: - `name` can be a custom domain-specific event name, e.g. "ORDER_PLACED". 
""" - def __init__(self, namespace: str, session_id: str | None = None): + def __init__( + self, namespace: str, session_id: str | None = None, bound_context=None + ): self.namespace = namespace self.session_id = session_id self.event_bus = AsyncEventBus.get() + # Optional reference to an application/context object that may carry + # an "upstream_session" attribute. This allows cached loggers to + # observe the current upstream session without relying on globals. + self._bound_context = bound_context def _ensure_event_loop(self): """Ensure we have an event loop we can use.""" @@ -92,21 +103,37 @@ def event( elif context.session_id is None: context.session_id = self.session_id - # Attach upstream_session from active MCP request context if available so - # listeners can forward logs upstream even from background tasks + # Attach upstream_session to the event so the upstream listener + # can forward reliably, regardless of the current task context. + # 1) Prefer logger-bound app context (set at creation or refreshed by caller) extra_event_fields: Dict[str, Any] = {} try: - from mcp.server.lowlevel.server import request_ctx as _lowlevel_request_ctx # type: ignore + upstream = ( + getattr(self._bound_context, "upstream_session", None) + if getattr(self, "_bound_context", None) is not None + else None + ) + if upstream is not None: + extra_event_fields["upstream_session"] = upstream + except Exception: + pass + # 2) Fallback to low-level request_ctx if available (in-request logs) + if "upstream_session" not in extra_event_fields: try: - req_ctx = _lowlevel_request_ctx.get() - extra_event_fields["upstream_session"] = getattr( - req_ctx, "session", None - ) - except LookupError: + from mcp.server.lowlevel.server import ( + request_ctx as _lowlevel_request_ctx, + ) # type: ignore + + try: + req_ctx = _lowlevel_request_ctx.get() + extra_event_fields["upstream_session"] = getattr( + req_ctx, "session", None + ) + except LookupError: + pass + except Exception: pass - except Exception: - pass evt = Event( type=etype, @@ -251,12 +278,32 @@ async def configure( flush_interval: Default flush interval for batching listener **kwargs: Additional configuration options """ - if cls._initialized: - return - bus = AsyncEventBus.get(transport=transport) # Keep a reference to the provided filter so we can update at runtime - cls._event_filter_ref = event_filter + if event_filter is not None: + cls._event_filter_ref = event_filter + + # If already initialized, ensure critical listeners exist and return + if cls._initialized: + # Forward logs upstream via MCP notifications if upstream_session is configured + try: + from mcp_agent.logging.listeners import MCPUpstreamLoggingListener + + has_upstream_listener = any( + isinstance(listener, MCPUpstreamLoggingListener) + for listener in bus.listeners.values() + ) + if not has_upstream_listener: + from typing import Final as _Final + + MCP_UPSTREAM_LISTENER_NAME: _Final[str] = "mcp_upstream" + bus.add_listener( + MCP_UPSTREAM_LISTENER_NAME, + MCPUpstreamLoggingListener(event_filter=cls._event_filter_ref), + ) + except Exception: + pass + return # Add standard listeners if "logging" not in bus.listeners: @@ -350,7 +397,7 @@ async def managed(cls, **config_kwargs): _loggers: Dict[str, Logger] = {} -def get_logger(namespace: str, session_id: str | None = None) -> Logger: +def get_logger(namespace: str, session_id: str | None = None, context=None) -> Logger: """ Get a logger instance for a given namespace. Creates a new logger if one doesn't exist for this namespace. 
@@ -358,13 +405,24 @@ def get_logger(namespace: str, session_id: str | None = None) -> Logger: Args: namespace: The namespace for the logger (e.g. "agent.helper", "workflow.demo") session_id: Optional session ID to associate with all events from this logger + context: Deprecated/ignored. Present for backwards compatibility. Returns: A Logger instance for the given namespace """ with _logger_lock: - # Create a new logger if one doesn't exist - if namespace not in _loggers: - _loggers[namespace] = Logger(namespace, session_id) - return _loggers[namespace] + existing = _loggers.get(namespace) + if existing is None: + logger = Logger(namespace, session_id, bound_context=context) + _loggers[namespace] = logger + return logger + # Update session_id/bound context if caller provides them + if session_id is not None: + existing.session_id = session_id + if context is not None: + try: + existing._bound_context = context + except Exception: + pass + return existing diff --git a/src/mcp_agent/logging/transport.py b/src/mcp_agent/logging/transport.py index 1aa732992..2bf78a968 100644 --- a/src/mcp_agent/logging/transport.py +++ b/src/mcp_agent/logging/transport.py @@ -325,7 +325,15 @@ def reset(cls) -> None: # Signal shutdown cls._instance._running = False if hasattr(cls._instance, "_stop_event"): - cls._instance._stop_event.set() + try: + # _stop_event.set() schedules on the event's loop; this can fail if + # the loop is already closed in test teardown. Swallow to ensure + # reset never raises in those cases. + cls._instance._stop_event.set() + except RuntimeError: + pass + except Exception: + pass # Clear the singleton instance cls._instance = None diff --git a/src/mcp_agent/server/app_server.py b/src/mcp_agent/server/app_server.py index 691cf4065..bca2d2fd6 100644 --- a/src/mcp_agent/server/app_server.py +++ b/src/mcp_agent/server/app_server.py @@ -106,6 +106,103 @@ def _get_attached_server_context(mcp: FastMCP) -> ServerContext | None: return getattr(mcp, "_mcp_agent_server_context", None) +def _set_upstream_from_request_ctx_if_available(ctx: MCPContext) -> None: + """Attach the low-level server session to the app context for upstream log forwarding. + + This ensures logs emitted from background workflow tasks are forwarded to the client + even when the low-level request contextvar is not available in those tasks. 
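+
+    Rough flow (a sketch of this module's wiring, not a normative spec):
+        ctx (FastMCP Context) -> ctx.session -> app.context.upstream_session,
+        after which Logger events bound to the app context can be forwarded
+        to the connected client as notifications/message.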
+ """ + # First, try to use the session property from the FastMCP Context + session = None + try: + session = ( + ctx.session + ) # This accesses the property which returns ctx.request_context.session + except (AttributeError, ValueError): + # ctx.session property might raise ValueError if context not available + pass + + if session is not None: + app: MCPApp | None = _get_attached_app(ctx.fastmcp) + if app is not None and getattr(app, "context", None) is not None: + # Set on global app context so the logger can access it + old_session = app.context.upstream_session + # Use direct assignment for Pydantic model + app.context.upstream_session = session + import sys + + # Verify it was actually set + actual_value = app.context.upstream_session + # Check via model_dump to see what Pydantic thinks + if hasattr(app.context, "model_dump"): + model_data = app.context.model_dump(exclude_none=False) + us_from_dump = model_data.get("upstream_session") + print( + f"[DEBUG] Set upstream_session on app.context for {app.name}, old={old_session is not None}, new={session is not None}, actual_after_set={actual_value is not None}, from_model_dump={us_from_dump is not None}, context_id={id(app.context)}", + file=sys.stderr, + ) + else: + print( + f"[DEBUG] Set upstream_session on app.context for {app.name}, old={old_session is not None}, new={session is not None}, actual_after_set={actual_value is not None}, context_id={id(app.context)}", + file=sys.stderr, + ) + try: + logger.debug( + f"Set upstream_session on app.context for {app.name}", + data={ + "server": getattr(ctx.fastmcp, "name", None), + "has_session": session is not None, + "app_name": app.name, + "old_session": old_session is not None, + }, + ) + except Exception: + pass + return + else: + import sys + + print( + f"[DEBUG] Could not set upstream_session - app={app is not None}, context={app.context is not None if app else False}", + file=sys.stderr, + ) + logger.debug( + "Could not set upstream_session - no app or context", + data={ + "has_app": app is not None, + "has_context": app.context if app else None, + }, + ) + + # Fallback: try the low-level request_ctx contextvar + try: + from mcp.server.lowlevel.server import request_ctx as _lowlevel_request_ctx # type: ignore + + try: + req_ctx = _lowlevel_request_ctx.get() + except LookupError: + return + + session = getattr(req_ctx, "session", None) + if session is None: + return + + app: MCPApp | None = _get_attached_app(ctx.fastmcp) + if app is not None and getattr(app, "context", None) is not None: + # Set on global app context; listeners read from get_current_context() + app.context.upstream_session = session + try: + logger.debug( + "Attached upstream_session via lowlevel request_ctx", + data={"server": getattr(ctx.fastmcp, "name", None)}, + ) + except Exception: + pass + except Exception: + # Best-effort only + pass + + def _resolve_workflows_and_context( ctx: MCPContext, ) -> Tuple[Dict[str, Type["Workflow"]] | None, Optional["Context"]]: @@ -124,7 +221,13 @@ def _resolve_workflows_and_context( # Fall back to app attached to FastMCP app: MCPApp | None = _get_attached_app(ctx.fastmcp) + if app is not None: + # Ensure the app context has the current request's session set so background logs forward + try: + _set_upstream_from_request_ctx_if_available(ctx) + except Exception: + pass return app.workflows, app.context return None, None @@ -279,18 +382,34 @@ def list_workflows(ctx: MCPContext) -> Dict[str, Dict[str, Any]]: Returns information about each workflow type including name, description, and 
parameters. This helps in making an informed decision about which workflow to run. """ + # Ensure upstream session is set for any logs emitted during this call + try: + _set_upstream_from_request_ctx_if_available(ctx) + except Exception: + pass result: Dict[str, Dict[str, Any]] = {} workflows, _ = _resolve_workflows_and_context(ctx) workflows = workflows or {} for workflow_name, workflow_cls in workflows.items(): - # Get workflow documentation - run_fn_tool = FastTool.from_function(workflow_cls.run) - - # Define common endpoints for all workflows - endpoints = [ - f"workflows-{workflow_name}-run", - f"workflows-{workflow_name}-get_status", - ] + # Determine parameter schema (strip self / prefer original function) + run_fn_tool = _build_run_param_tool(workflow_cls) + + # Determine endpoints based on whether this is an auto sync/async tool + if getattr(workflow_cls, "__mcp_agent_sync_tool__", False): + endpoints = [ + f"{workflow_name}", + f"{workflow_name}-get_status", + ] + elif getattr(workflow_cls, "__mcp_agent_async_tool__", False): + endpoints = [ + f"{workflow_name}-async-run", + f"{workflow_name}-get_status", + ] + else: + endpoints = [ + f"workflows-{workflow_name}-run", + f"workflows-{workflow_name}-get_status", + ] result[workflow_name] = { "name": workflow_name, @@ -314,6 +433,12 @@ async def list_workflow_runs(ctx: MCPContext) -> List[Dict[str, Any]]: Returns: A dictionary mapping workflow instance IDs to their detailed status information. """ + # Ensure upstream session is set for any logs emitted during this call + try: + _set_upstream_from_request_ctx_if_available(ctx) + except Exception: + pass + server_context = getattr( ctx.request_context, "lifespan_context", None ) or _get_attached_server_context(ctx.fastmcp) @@ -346,6 +471,11 @@ async def run_workflow( A dict with workflow_id and run_id for the started workflow run, can be passed to workflows/get_status, workflows/resume, and workflows/cancel. """ + # Ensure upstream session is set before starting the workflow + try: + _set_upstream_from_request_ctx_if_available(ctx) + except Exception: + pass return await _workflow_run(ctx, workflow_name, run_parameters, **kwargs) @mcp.tool(name="workflows-get_status") @@ -366,6 +496,11 @@ async def get_workflow_status( Returns: A dictionary with comprehensive information about the workflow status. """ + # Ensure upstream session is available for any status-related logs + try: + _set_upstream_from_request_ctx_if_available(ctx) + except Exception: + pass return await _workflow_status(ctx, run_id, workflow_name) @mcp.tool(name="workflows-resume") @@ -393,6 +528,11 @@ async def resume_workflow( Returns: True if the workflow was resumed, False otherwise. """ + # Ensure upstream session is available for any status-related logs + try: + _set_upstream_from_request_ctx_if_available(ctx) + except Exception: + pass server_context: ServerContext = ctx.request_context.lifespan_context workflow_registry = server_context.workflow_registry @@ -435,6 +575,11 @@ async def cancel_workflow( Returns: True if the workflow was cancelled, False otherwise. 
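
        Example call from a client session (tool and argument names as
        registered by this server; the IDs are illustrative):
            await server.call_tool(
                "workflows-cancel",
                arguments={"workflow_id": "BasicAgentWorkflow", "run_id": run_id},
            )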
""" + # Ensure upstream session is available for any status-related logs + try: + _set_upstream_from_request_ctx_if_available(ctx) + except Exception: + pass server_context: ServerContext = ctx.request_context.lifespan_context workflow_registry = server_context.workflow_registry @@ -563,114 +708,172 @@ async def _wait_for_completion( description = decl.get("description") structured_output = decl.get("structured_output") + # Capture loop variables for closures (avoid late-binding bugs) + _wname = workflow_name + _tname = name + if mode == "sync" and fn is not None: sig = inspect.signature(fn) return_ann = sig.return_annotation - async def _wrapper(**kwargs): - # Context will be injected by FastMCP using the special annotation below - ctx: MCPContext = kwargs.pop( - "__context__" - ) # placeholder, reassigned below via signature name - # Start workflow and wait for completion - result_ids = await _workflow_run(ctx, workflow_name, kwargs) - run_id = result_ids["run_id"] - result = await _wait_for_completion(ctx, run_id) - # Unwrap WorkflowResult to match the original function's return type - try: - from mcp_agent.executor.workflow import WorkflowResult as _WFRes - except Exception: - _WFRes = None # type: ignore - if _WFRes is not None and isinstance(result, _WFRes): - return getattr(result, "value", None) - # If get_status returned dict/str, pass through; otherwise return model - return result - - # Attach introspection metadata to match the original function - ann = dict(getattr(fn, "__annotations__", {})) - - # Choose a context kwarg name unlikely to clash with user params - ctx_param_name = "ctx" - from mcp.server.fastmcp import Context as _Ctx - - ann[ctx_param_name] = _Ctx - ann["return"] = getattr(fn, "__annotations__", {}).get("return", return_ann) - _wrapper.__annotations__ = ann - _wrapper.__name__ = name - _wrapper.__doc__ = description or (fn.__doc__ or "") - - # Build a fake signature containing original params plus context kwarg - params = list(sig.parameters.values()) - ctx_param = inspect.Parameter( - ctx_param_name, - kind=inspect.Parameter.KEYWORD_ONLY, - annotation=_Ctx, - ) - _wrapper.__signature__ = inspect.Signature( - parameters=params + [ctx_param], return_annotation=return_ann + # Build a per-tool wrapper bound to this workflow name + def _make_wrapper(bound_wname: str): + async def _wrapper(**kwargs): + # Context will be injected by FastMCP using the special annotation below + ctx: MCPContext = kwargs.pop("__context__") + # Start workflow and wait for completion + result_ids = await _workflow_run(ctx, bound_wname, kwargs) + run_id = result_ids["run_id"] + result = await _wait_for_completion(ctx, run_id) + # Unwrap WorkflowResult to match the original function's return type + try: + from mcp_agent.executor.workflow import WorkflowResult as _WFRes + except Exception: + _WFRes = None # type: ignore + if _WFRes is not None and isinstance(result, _WFRes): + return getattr(result, "value", None) + # If get_status returned dict/str, pass through; otherwise return model + return result + + return _wrapper + + _wrapper = _make_wrapper(_wname) + + # Create adapter that removes app_ctx from the exposed signature + # but still passes it through when the workflow runs + + # Filter out app_ctx from the signature since it's internal + filtered_params = [] + filtered_annotations = {} + orig_annotations = getattr(fn, "__annotations__", {}) + + for param in sig.parameters.values(): + # Skip app_ctx parameter - it will be injected by the workflow + if param.name == "app_ctx": + 
continue + filtered_params.append(param) + if param.name in orig_annotations: + filtered_annotations[param.name] = orig_annotations[param.name] + + # Add return annotation if present + if "return" in orig_annotations: + filtered_annotations["return"] = orig_annotations["return"] + + # Create filtered signature for FastMCP, but include a keyword-only + # 'ctx' param annotated as MCPContext so FastMCP can inject it. + try: + ctx_param = inspect.Parameter( + name="ctx", + kind=inspect.Parameter.KEYWORD_ONLY, + default=None, + annotation=MCPContext, + ) + params_for_schema = filtered_params + [ctx_param] + filtered_annotations["ctx"] = MCPContext + except Exception: + params_for_schema = filtered_params + filtered_sig = inspect.Signature( + parameters=params_for_schema, return_annotation=sig.return_annotation ) - # FastMCP expects the actual kwarg name for context; it detects it by annotation - # We need to map the injected kwarg inside the wrapper body. Achieve this by - # creating a thin adapter that renames the injected context kwarg. - async def _adapter(**kw): - # Receive validated args plus injected context kwarg - if ctx_param_name not in kw: - raise ToolError("Context not provided") - # Rename to the placeholder expected by _wrapper - kw["__context__"] = kw.pop(ctx_param_name) + async def _adapter( + ctx: MCPContext | None = None, + context: MCPContext | None = None, + **kw, + ): + # Prefer explicit ctx; some FastMCP versions may use 'context' + ctx_obj = ctx or context + if ctx_obj is not None: + try: + _set_upstream_from_request_ctx_if_available(ctx_obj) + except Exception: + pass + kw["__context__"] = ctx_obj + else: + kw["__context__"] = None return await _wrapper(**kw) - # Copy the visible signature/annotations to adapter for correct schema - _adapter.__annotations__ = _wrapper.__annotations__ - _adapter.__name__ = _wrapper.__name__ - _adapter.__doc__ = _wrapper.__doc__ - _adapter.__signature__ = _wrapper.__signature__ + # Expose a filtered signature (no app_ctx/ctx/context/**kw) for schema, + # but keep actual parameters (ctx/context/**kw) to receive FastMCP Context. 
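+        # (FastMCP derives the client-facing input schema by inspecting
+        # __signature__/__annotations__, so overriding them here hides the
+        # internal app_ctx parameter while the adapter still receives the
+        # injected MCPContext at call time.)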
+ _adapter.__name__ = _tname + _adapter.__doc__ = description or (fn.__doc__ or "") + _adapter.__signature__ = filtered_sig + _adapter.__annotations__ = filtered_annotations # Register the main tool with the same signature as original mcp.add_tool( _adapter, - name=name, + name=_tname, description=description or (fn.__doc__ or ""), structured_output=structured_output, ) - registered.add(name) + registered.add(_tname) # Also register a per-run status tool: -get_status - status_tool_name = f"{name}-get_status" + status_tool_name = f"{_tname}-get_status" if status_tool_name not in registered: - @mcp.tool(name=status_tool_name) - async def _sync_status(ctx: MCPContext, run_id: str) -> Dict[str, Any]: - return await _workflow_status( - ctx, run_id=run_id, workflow_name=workflow_name - ) - + def _make_sync_status(bound_wname: str): + @mcp.tool(name=status_tool_name) + async def _sync_status( + ctx: MCPContext, run_id: str + ) -> Dict[str, Any]: + try: + _set_upstream_from_request_ctx_if_available(ctx) + except Exception: + pass + return await _workflow_status( + ctx, run_id=run_id, workflow_name=bound_wname + ) + + return _sync_status + + _make_sync_status(_wname) registered.add(status_tool_name) elif mode == "async": # Create named aliases for async: -async-run and -get_status - run_tool_name = f"{name}-async-run" - status_tool_name = f"{name}-get_status" + run_tool_name = f"{_tname}-async-run" + status_tool_name = f"{_tname}-get_status" if run_tool_name not in registered: - @mcp.tool(name=run_tool_name) - async def _alias_run( - ctx: MCPContext, run_parameters: Dict[str, Any] | None = None - ) -> Dict[str, str]: - return await _workflow_run(ctx, workflow_name, run_parameters or {}) - + def _make_alias_run(bound_wname: str): + @mcp.tool(name=run_tool_name) + async def _alias_run( + ctx: MCPContext, run_parameters: Dict[str, Any] | None = None + ) -> Dict[str, str]: + try: + _set_upstream_from_request_ctx_if_available(ctx) + except Exception: + pass + return await _workflow_run( + ctx, bound_wname, run_parameters or {} + ) + + return _alias_run + + _make_alias_run(_wname) registered.add(run_tool_name) if status_tool_name not in registered: - @mcp.tool(name=status_tool_name) - async def _alias_status(ctx: MCPContext, run_id: str) -> Dict[str, Any]: - return await _workflow_status( - ctx, run_id=run_id, workflow_name=workflow_name - ) - + def _make_alias_status(bound_wname: str): + @mcp.tool(name=status_tool_name) + async def _alias_status( + ctx: MCPContext, run_id: str + ) -> Dict[str, Any]: + try: + _set_upstream_from_request_ctx_if_available(ctx) + except Exception: + pass + return await _workflow_status( + ctx, run_id=run_id, workflow_name=bound_wname + ) + + return _alias_status + + _make_alias_status(_wname) registered.add(status_tool_name) _set_registered_function_tools(mcp, registered) @@ -725,6 +928,7 @@ async def run( ctx: MCPContext, run_parameters: Dict[str, Any] | None = None, ) -> Dict[str, str]: + _set_upstream_from_request_ctx_if_available(ctx) return await _workflow_run(ctx, workflow_name, run_parameters) @mcp.tool( @@ -737,6 +941,7 @@ async def run( """, ) async def get_status(ctx: MCPContext, run_id: str) -> Dict[str, Any]: + _set_upstream_from_request_ctx_if_available(ctx) return await _workflow_status(ctx, run_id=run_id, workflow_name=workflow_name) @@ -792,6 +997,7 @@ async def _workflow_run( **kwargs: Any, ) -> Dict[str, str]: # Resolve workflows and app context irrespective of startup mode + # This now returns a context with upstream_session already set workflows_dict, 
app_context = _resolve_workflows_and_context(ctx) if not workflows_dict or not app_context: raise ToolError("Server context not available for MCPApp Server.") @@ -802,9 +1008,20 @@ async def _workflow_run( # Get the workflow class workflow_cls = workflows_dict[workflow_name] + # Bind the app-level logger (cached) to this per-request context so logs + # emitted from AutoWorkflow path forward upstream even outside request_ctx. + try: + app = _get_attached_app(ctx.fastmcp) + if app is not None and getattr(app, "name", None): + from mcp_agent.logging.logger import get_logger as _get_logger + + _get_logger(f"mcp_agent.{app.name}", context=app_context) + except Exception: + pass + # Create and initialize the workflow instance using the factory method try: - # Create workflow instance + # Create workflow instance with context that has upstream_session workflow = await workflow_cls.create(name=workflow_name, context=app_context) run_parameters = run_parameters or {} @@ -839,6 +1056,11 @@ async def _workflow_run( async def _workflow_status( ctx: MCPContext, run_id: str, workflow_name: str | None = None ) -> Dict[str, Any]: + # Ensure upstream session so status-related logs are forwarded + try: + _set_upstream_from_request_ctx_if_available(ctx) + except Exception: + pass workflow_registry: WorkflowRegistry | None = _resolve_workflow_registry(ctx) if not workflow_registry: From 38d23a3ddc44366f101be202270076681a85f2f7 Mon Sep 17 00:00:00 2001 From: Sarmad Qadri Date: Mon, 1 Sep 2025 10:42:50 -0400 Subject: [PATCH 06/24] Cleanup --- src/mcp_agent/app.py | 2 +- src/mcp_agent/logging/listeners.py | 53 ++++----------------- src/mcp_agent/logging/logger.py | 17 +------ src/mcp_agent/server/app_server.py | 74 ++---------------------------- 4 files changed, 14 insertions(+), 132 deletions(-) diff --git a/src/mcp_agent/app.py b/src/mcp_agent/app.py index d9b059705..95d617cac 100644 --- a/src/mcp_agent/app.py +++ b/src/mcp_agent/app.py @@ -194,7 +194,7 @@ def logger(self): ) else: # Update the logger's bound context in case upstream_session was set after logger creation - if self._context and hasattr(self._logger, '_bound_context'): + if self._context and hasattr(self._logger, "_bound_context"): self._logger._bound_context = self._context return self._logger diff --git a/src/mcp_agent/logging/listeners.py b/src/mcp_agent/logging/listeners.py index 2b47621a5..0abc4329f 100644 --- a/src/mcp_agent/logging/listeners.py +++ b/src/mcp_agent/logging/listeners.py @@ -234,44 +234,19 @@ async def _process_batch(self, events: List[Event]): class MCPUpstreamLoggingListener(FilteredListener): """ - Sends matched log events to the connected MCP client via upstream_session - using notifications/message. - - This relies on a globally available Context (see get_current_context()) - which is established when the app initializes. If no upstream_session is - present, events are skipped. + Sends matched log events to the connected MCP client via the upstream_session + carried on each Event (runtime-only field). If no upstream_session is present, + the event is skipped. 
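+
+    Sketch of the resulting delivery path (send_log_message per the MCP
+    server session API this listener targets):
+        logger.info(...) -> Event(upstream_session=...) -> this listener
+        -> upstream_session.send_log_message(level=..., data=..., logger=...)
+        -> client receives a notifications/message notification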
""" def __init__(self, event_filter: EventFilter | None = None) -> None: super().__init__(event_filter=event_filter) async def handle_matched_event(self, event: Event) -> None: - # Prefer an upstream session bound directly to this event (most precise) - upstream_session: Optional[UpstreamServerSessionProtocol] = None - resolved_via: str | None = None - candidate = getattr(event, "upstream_session", None) - if candidate is not None: - upstream_session = candidate # type: ignore[assignment] - resolved_via = "event" - - # No fallback to global get_current_context() to avoid unsafe globals - - # Finally, try the low-level request contextvar if in a request handler - if upstream_session is None: - try: - from mcp.server.lowlevel.server import ( - request_ctx as _lowlevel_request_ctx, - ) - - try: - req_ctx = _lowlevel_request_ctx.get() - upstream_session = getattr(req_ctx, "session", None) - if upstream_session is not None: - resolved_via = "request_ctx" - except LookupError: - upstream_session = None - except Exception: - upstream_session = None + # Use upstream session provided on the event + upstream_session: Optional[UpstreamServerSessionProtocol] = getattr( + event, "upstream_session", None + ) if upstream_session is None: # No upstream_session available, event cannot be forwarded @@ -317,18 +292,6 @@ async def handle_matched_event(self, event: Event) -> None: data=data, logger=logger_name, ) - # Diagnostic path for success - try: - import sys - - print( - f"[mcp_agent] forwarded event via {resolved_via or 'unknown'} (namespace={event.namespace}, name={event.name})", - file=sys.stderr, - ) - except Exception: - pass except Exception as e: # Avoid raising inside listener; best-effort delivery - import sys - - print(f"[mcp_agent] upstream log send failed: {e}", file=sys.stderr) + _ = e diff --git a/src/mcp_agent/logging/logger.py b/src/mcp_agent/logging/logger.py index c16c7346b..dc64b1d0c 100644 --- a/src/mcp_agent/logging/logger.py +++ b/src/mcp_agent/logging/logger.py @@ -118,22 +118,7 @@ def event( except Exception: pass - # 2) Fallback to low-level request_ctx if available (in-request logs) - if "upstream_session" not in extra_event_fields: - try: - from mcp.server.lowlevel.server import ( - request_ctx as _lowlevel_request_ctx, - ) # type: ignore - - try: - req_ctx = _lowlevel_request_ctx.get() - extra_event_fields["upstream_session"] = getattr( - req_ctx, "session", None - ) - except LookupError: - pass - except Exception: - pass + # No further fallback: rely solely on the bound context for upstream_session evt = Event( type=etype, diff --git a/src/mcp_agent/server/app_server.py b/src/mcp_agent/server/app_server.py index bca2d2fd6..03802700e 100644 --- a/src/mcp_agent/server/app_server.py +++ b/src/mcp_agent/server/app_server.py @@ -126,81 +126,15 @@ def _set_upstream_from_request_ctx_if_available(ctx: MCPContext) -> None: app: MCPApp | None = _get_attached_app(ctx.fastmcp) if app is not None and getattr(app, "context", None) is not None: # Set on global app context so the logger can access it - old_session = app.context.upstream_session + # Previously captured; no need to keep old value # Use direct assignment for Pydantic model app.context.upstream_session = session - import sys - - # Verify it was actually set - actual_value = app.context.upstream_session - # Check via model_dump to see what Pydantic thinks - if hasattr(app.context, "model_dump"): - model_data = app.context.model_dump(exclude_none=False) - us_from_dump = model_data.get("upstream_session") - print( - f"[DEBUG] Set 
upstream_session on app.context for {app.name}, old={old_session is not None}, new={session is not None}, actual_after_set={actual_value is not None}, from_model_dump={us_from_dump is not None}, context_id={id(app.context)}", - file=sys.stderr, - ) - else: - print( - f"[DEBUG] Set upstream_session on app.context for {app.name}, old={old_session is not None}, new={session is not None}, actual_after_set={actual_value is not None}, context_id={id(app.context)}", - file=sys.stderr, - ) - try: - logger.debug( - f"Set upstream_session on app.context for {app.name}", - data={ - "server": getattr(ctx.fastmcp, "name", None), - "has_session": session is not None, - "app_name": app.name, - "old_session": old_session is not None, - }, - ) - except Exception: - pass + # Minimal, no diagnostics return else: - import sys - - print( - f"[DEBUG] Could not set upstream_session - app={app is not None}, context={app.context is not None if app else False}", - file=sys.stderr, - ) - logger.debug( - "Could not set upstream_session - no app or context", - data={ - "has_app": app is not None, - "has_context": app.context if app else None, - }, - ) - - # Fallback: try the low-level request_ctx contextvar - try: - from mcp.server.lowlevel.server import request_ctx as _lowlevel_request_ctx # type: ignore - - try: - req_ctx = _lowlevel_request_ctx.get() - except LookupError: - return - - session = getattr(req_ctx, "session", None) - if session is None: return - app: MCPApp | None = _get_attached_app(ctx.fastmcp) - if app is not None and getattr(app, "context", None) is not None: - # Set on global app context; listeners read from get_current_context() - app.context.upstream_session = session - try: - logger.debug( - "Attached upstream_session via lowlevel request_ctx", - data={"server": getattr(ctx.fastmcp, "name", None)}, - ) - except Exception: - pass - except Exception: - # Best-effort only - pass + # No low-level request_ctx fallback: upstream_session must come from app context def _resolve_workflows_and_context( @@ -714,7 +648,7 @@ async def _wait_for_completion( if mode == "sync" and fn is not None: sig = inspect.signature(fn) - return_ann = sig.return_annotation + # Preserve original return annotation implicitly via FastMCP tool # Build a per-tool wrapper bound to this workflow name def _make_wrapper(bound_wname: str): From 09099f9f640f06e932b0927eb019070c6d4738e2 Mon Sep 17 00:00:00 2001 From: Sarmad Qadri Date: Mon, 1 Sep 2025 11:06:34 -0400 Subject: [PATCH 07/24] Checkpoint --- src/mcp_agent/executor/workflow.py | 63 ++++++++++++++++++++ src/mcp_agent/logging/logger.py | 33 +++++++++++ src/mcp_agent/mcp/client_proxy.py | 70 +++++++++++++++++++++++ src/mcp_agent/server/app_server.py | 92 ++++++++++++++++++++++++++++++ 4 files changed, 258 insertions(+) create mode 100644 src/mcp_agent/mcp/client_proxy.py diff --git a/src/mcp_agent/executor/workflow.py b/src/mcp_agent/executor/workflow.py index b3a384d3c..8921b257b 100644 --- a/src/mcp_agent/executor/workflow.py +++ b/src/mcp_agent/executor/workflow.py @@ -22,6 +22,7 @@ ) from mcp_agent.executor.workflow_signal import Signal from mcp_agent.logging.logger import get_logger +from mcp_agent.mcp.client_proxy import log_via_proxy, ask_via_proxy if TYPE_CHECKING: from temporalio.client import WorkflowHandle @@ -246,6 +247,13 @@ async def run_async(self, *args, **kwargs) -> "WorkflowExecution": f"Workflow started with workflow ID: {self._workflow_id}, run ID: {self._run_id}" ) + # Hint the logger with the current run_id for Temporal proxy fallback + try: + if 
self.context.config.execution_engine == "temporal": + setattr(self._logger, "_temporal_run_id", self._run_id) + except Exception: + pass + # Define the workflow execution function async def _execute_workflow(): try: @@ -359,6 +367,61 @@ async def _execute_workflow(): workflow_id=self._workflow_id, ) + # Engine-aware helpers to unify upstream interactions + async def log_upstream( + self, + level: str, + namespace: str, + message: str, + data: Dict[str, Any] | None = None, + ): + if self.context.config.execution_engine == "temporal": + try: + await log_via_proxy( + self.context.server_registry, + run_id=self._run_id or "", + level=level, + namespace=namespace, + message=message, + data=data or {}, + ) + except Exception: + pass + else: + # asyncio: use local logger + if level == "debug": + self._logger.debug(message, **(data or {})) + elif level == "warning": + self._logger.warning(message, **(data or {})) + elif level == "error": + self._logger.error(message, **(data or {})) + else: + self._logger.info(message, **(data or {})) + + async def ask_user( + self, prompt: str, metadata: Dict[str, Any] | None = None + ) -> Any: + if self.context.config.execution_engine == "temporal": + try: + res = await ask_via_proxy( + self.context.server_registry, + run_id=self._run_id or "", + prompt=prompt, + metadata=metadata or {}, + ) + if isinstance(res, dict): + return res.get("result") if "result" in res else res + return res + except Exception as e: + return {"error": str(e)} + else: + handler = getattr(self.context, "human_input_handler", None) + if not handler: + return None + if asyncio.iscoroutinefunction(handler): # type: ignore[arg-type] + return await handler({"prompt": prompt, "metadata": metadata or {}}) + return handler({"prompt": prompt, "metadata": metadata or {}}) + async def resume( self, signal_name: str | None = "resume", payload: str | None = None ) -> bool: diff --git a/src/mcp_agent/logging/logger.py b/src/mcp_agent/logging/logger.py index dc64b1d0c..4c6fd8d15 100644 --- a/src/mcp_agent/logging/logger.py +++ b/src/mcp_agent/logging/logger.py @@ -129,6 +129,39 @@ def event( data=data, **extra_event_fields, ) + # If we are running under Temporal (logger tagged with a run_id) and we + # don't yet have an upstream session, opportunistically relay via proxy + # to keep the user-facing code unchanged. 
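+        # The relay below is deliberately fire-and-forget. A minimal
+        # standalone version of that pattern (names illustrative, not part
+        # of this module) would look like:
+        #
+        #     import asyncio
+        #
+        #     def fire_and_forget(coro) -> None:
+        #         try:
+        #             loop = asyncio.get_running_loop()
+        #         except RuntimeError:
+        #             coro.close()  # no running loop; drop the relay silently
+        #             return
+        #         task = loop.create_task(coro)
+        #         # Retrieve any exception so failures never propagate or warn.
+        #         task.add_done_callback(lambda t: t.cancelled() or t.exception())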
+ try: + if ( + extra_event_fields.get("upstream_session") is None + and getattr(self, "_temporal_run_id", None) + and getattr(self, "_bound_context", None) is not None + ): + from mcp_agent.mcp.client_proxy import ( + log_via_proxy as _log_proxy, + ) # lazy import + + server_registry = getattr(self._bound_context, "server_registry", None) + run_id = getattr(self, "_temporal_run_id", "") + if server_registry and run_id: + # Fire-and-forget best-effort proxy log; don't block emission + try: + loop = self._ensure_event_loop() + loop.create_task( + _log_proxy( + server_registry, + run_id=run_id, + level=str(etype), + namespace=self.namespace, + message=message, + data=data, + ) + ) + except Exception: + pass + except Exception: + pass self._emit_event(evt) def debug( diff --git a/src/mcp_agent/mcp/client_proxy.py b/src/mcp_agent/mcp/client_proxy.py new file mode 100644 index 000000000..bfcf935b4 --- /dev/null +++ b/src/mcp_agent/mcp/client_proxy.py @@ -0,0 +1,70 @@ +import asyncio +from contextlib import asynccontextmanager +from typing import Any, Dict, AsyncIterator + +from mcp_agent.mcp.gen_client import gen_client +from mcp_agent.mcp.mcp_server_registry import ServerRegistry + + +@asynccontextmanager +async def _proxy_client( + server_name: str, + server_registry: ServerRegistry, +) -> AsyncIterator[Any]: + async with gen_client(server_name, server_registry) as client: + yield client + + +async def log_via_proxy( + server_registry: ServerRegistry, + run_id: str, + level: str, + namespace: str, + message: str, + data: Dict[str, Any] | None = None, + server_name: str = "basic_agent_server", +) -> bool: + async with _proxy_client(server_name, server_registry) as client: + try: + await client.call_tool( + "workflows-proxy-log", + arguments={ + "run_id": run_id, + "level": level, + "namespace": namespace, + "message": message, + "data": data or {}, + }, + ) + return True + except Exception: + return False + + +async def ask_via_proxy( + server_registry: ServerRegistry, + run_id: str, + prompt: str, + metadata: Dict[str, Any] | None = None, + server_name: str = "basic_agent_server", +) -> Dict[str, Any]: + async with _proxy_client(server_name, server_registry) as client: + try: + resp = await client.call_tool( + "workflows-proxy-ask", + arguments={ + "run_id": run_id, + "prompt": prompt, + "metadata": metadata or {}, + }, + ) + sc = getattr(resp, "structuredContent", None) + if isinstance(sc, dict) and "result" in sc: + return ( + sc["result"] + if isinstance(sc["result"], dict) + else {"result": sc["result"]} + ) + return {"result": None} + except Exception as e: + return {"error": str(e)} diff --git a/src/mcp_agent/server/app_server.py b/src/mcp_agent/server/app_server.py index 03802700e..145201106 100644 --- a/src/mcp_agent/server/app_server.py +++ b/src/mcp_agent/server/app_server.py @@ -7,6 +7,7 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager from typing import Any, Dict, List, Optional, Set, Tuple, Type, TYPE_CHECKING +import asyncio from mcp.server.fastmcp import Context as MCPContext, FastMCP from mcp.server.fastmcp.exceptions import ToolError @@ -28,6 +29,25 @@ from mcp_agent.core.context import Context logger = get_logger(__name__) +# Simple in-memory registry mapping workflow run_id -> upstream session handle. +# Allows external workers (e.g., Temporal) to relay logs/prompts through MCPApp. 
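
The intended lifecycle, sketched with hypothetical event hooks (the
register/unregister helpers are the ones defined just below; run
registration and terminal-state cleanup appear later in this patch):

    from typing import Any

    async def on_run_started(run_id: str, session: Any) -> None:
        # Bind the upstream MCP session for the lifetime of the run.
        await _register_run_session(run_id, session)

    async def on_run_status_changed(run_id: str, status: str) -> None:
        # Terminal states release the mapping so sessions are not leaked.
        if status.lower() in ("completed", "error", "cancelled"):
            await _unregister_run_session(run_id)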
+_RUN_SESSION_REGISTRY: Dict[str, Any] = {} +_RUN_SESSION_LOCK = asyncio.Lock() + + +async def _register_run_session(run_id: str, session: Any) -> None: + async with _RUN_SESSION_LOCK: + _RUN_SESSION_REGISTRY[run_id] = session + + +async def _unregister_run_session(run_id: str) -> None: + async with _RUN_SESSION_LOCK: + _RUN_SESSION_REGISTRY.pop(run_id, None) + + +async def _get_run_session(run_id: str) -> Any | None: + async with _RUN_SESSION_LOCK: + return _RUN_SESSION_REGISTRY.get(run_id) class ServerContext(ContextDependent): @@ -529,6 +549,61 @@ async def cancel_workflow( else: logger.error(f"Failed to cancel workflow {workflow_name} with ID {run_id}") + # region Proxy tools for external runners (e.g., Temporal workers) + + @mcp.tool(name="workflows-proxy-log") + async def proxy_log( + run_id: str, + level: str, + namespace: str, + message: str, + data: Dict[str, Any] | None = None, + ) -> bool: + session = await _get_run_session(run_id) + if not session: + return False + lvl = str(level).lower() + if lvl not in ("debug", "info", "warning", "error"): + lvl = "info" + try: + await session.send_log_message( + level=lvl, # type: ignore[arg-type] + data={ + "message": message, + "namespace": namespace, + "data": data or {}, + }, + logger=namespace, + ) + return True + except Exception: + return False + + @mcp.tool(name="workflows-proxy-ask") + async def proxy_ask( + run_id: str, + prompt: str, + metadata: Dict[str, Any] | None = None, + ) -> Dict[str, Any]: + app = _get_attached_app(mcp) + if app is None or not getattr(app.context, "human_input_handler", None): + return {"error": "human_input not available"} + handler = app.context.human_input_handler + try: + if asyncio.iscoroutinefunction(handler): # type: ignore[arg-type] + result = await handler( + {"prompt": prompt, "metadata": metadata or {}, "run_id": run_id} + ) + else: + result = handler( + {"prompt": prompt, "metadata": metadata or {}, "run_id": run_id} + ) + return {"result": result} + except Exception as e: + return {"error": str(e)} + + # endregion + # endregion return mcp @@ -977,6 +1052,15 @@ async def _workflow_run( f"Workflow {workflow_name} started with workflow ID {execution.workflow_id} and run ID {execution.run_id}. 
Parameters: {run_parameters}" ) + # Register upstream session for this run so external workers can proxy logs/prompts + try: + if execution.run_id is not None: + await _register_run_session( + execution.run_id, getattr(ctx, "session", None) + ) + except Exception: + pass + return { "workflow_id": execution.workflow_id, "run_id": execution.run_id, @@ -1007,6 +1091,14 @@ async def _workflow_status( run_id=run_id, workflow_id=workflow_id ) + # Cleanup run registry on terminal states + try: + state = str(status.get("status", "")).lower() + if state in ("completed", "error", "cancelled"): + await _unregister_run_session(run_id) + except Exception: + pass + return status From 90265aef41d432bd08aa6dbdfbccb1bd519ea1b7 Mon Sep 17 00:00:00 2001 From: Sarmad Qadri Date: Mon, 1 Sep 2025 11:15:40 -0400 Subject: [PATCH 08/24] Checkpoint 2 --- src/mcp_agent/executor/temporal/__init__.py | 6 +++ .../executor/temporal/system_activities.py | 51 +++++++++++++++++++ src/mcp_agent/logging/logger.py | 33 ------------ src/mcp_agent/server/app_server.py | 46 ++++++++++++----- 4 files changed, 90 insertions(+), 46 deletions(-) create mode 100644 src/mcp_agent/executor/temporal/system_activities.py diff --git a/src/mcp_agent/executor/temporal/__init__.py b/src/mcp_agent/executor/temporal/__init__.py index 97ade82db..16468ae50 100644 --- a/src/mcp_agent/executor/temporal/__init__.py +++ b/src/mcp_agent/executor/temporal/__init__.py @@ -36,6 +36,7 @@ from mcp_agent.executor.workflow_signal import SignalHandler from mcp_agent.logging.logger import get_logger from mcp_agent.utils.common import unwrap +from mcp_agent.executor.temporal.system_activities import SystemActivities if TYPE_CHECKING: from mcp_agent.app import MCPApp @@ -490,6 +491,11 @@ async def create_temporal_worker_for_app(app: "MCPApp"): # Collect activities from the global registry activity_registry = running_app.context.task_registry + # Register system activities (logging, human input proxy) + sys_acts = SystemActivities(context=running_app.context) + app.workflow_task()(sys_acts.forward_log) + app.workflow_task()(sys_acts.request_user_input) + for name in activity_registry.list_activities(): activities.append(activity_registry.get_activity(name)) diff --git a/src/mcp_agent/executor/temporal/system_activities.py b/src/mcp_agent/executor/temporal/system_activities.py new file mode 100644 index 000000000..c483c787f --- /dev/null +++ b/src/mcp_agent/executor/temporal/system_activities.py @@ -0,0 +1,51 @@ +from typing import Any, Dict + +from temporalio import activity + +from mcp_agent.mcp.client_proxy import log_via_proxy, ask_via_proxy +from mcp_agent.core.context_dependent import ContextDependent + + +class SystemActivities(ContextDependent): + """Activities used by Temporal workflows to interact with the MCPApp gateway.""" + + @activity.defn(name="mcp_forward_log") + async def forward_log( + self, + run_id: str, + level: str, + namespace: str, + message: str, + data: Dict[str, Any] | None = None, + ) -> bool: + registry = self.context.server_registry + return await log_via_proxy( + registry, + run_id=run_id, + level=level, + namespace=namespace, + message=message, + data=data or {}, + ) + + @activity.defn(name="mcp_request_user_input") + async def request_user_input( + self, + session_id: str, + workflow_id: str, + run_id: str, + prompt: str, + signal_name: str = "human_input", + ) -> Dict[str, Any]: + # Reuse proxy ask API; returns {result} or {error} + registry = self.context.server_registry + return await ask_via_proxy( + registry, + run_id=run_id, 
+ prompt=prompt, + metadata={ + "session_id": session_id, + "workflow_id": workflow_id, + "signal_name": signal_name, + }, + ) diff --git a/src/mcp_agent/logging/logger.py b/src/mcp_agent/logging/logger.py index 4c6fd8d15..dc64b1d0c 100644 --- a/src/mcp_agent/logging/logger.py +++ b/src/mcp_agent/logging/logger.py @@ -129,39 +129,6 @@ def event( data=data, **extra_event_fields, ) - # If we are running under Temporal (logger tagged with a run_id) and we - # don't yet have an upstream session, opportunistically relay via proxy - # to keep the user-facing code unchanged. - try: - if ( - extra_event_fields.get("upstream_session") is None - and getattr(self, "_temporal_run_id", None) - and getattr(self, "_bound_context", None) is not None - ): - from mcp_agent.mcp.client_proxy import ( - log_via_proxy as _log_proxy, - ) # lazy import - - server_registry = getattr(self._bound_context, "server_registry", None) - run_id = getattr(self, "_temporal_run_id", "") - if server_registry and run_id: - # Fire-and-forget best-effort proxy log; don't block emission - try: - loop = self._ensure_event_loop() - loop.create_task( - _log_proxy( - server_registry, - run_id=run_id, - level=str(etype), - namespace=self.namespace, - message=message, - data=data, - ) - ) - except Exception: - pass - except Exception: - pass self._emit_event(evt) def debug( diff --git a/src/mcp_agent/server/app_server.py b/src/mcp_agent/server/app_server.py index 145201106..4a836d704 100644 --- a/src/mcp_agent/server/app_server.py +++ b/src/mcp_agent/server/app_server.py @@ -585,23 +585,43 @@ async def proxy_ask( prompt: str, metadata: Dict[str, Any] | None = None, ) -> Dict[str, Any]: - app = _get_attached_app(mcp) - if app is None or not getattr(app.context, "human_input_handler", None): - return {"error": "human_input not available"} - handler = app.context.human_input_handler + # Emit a human_input_request notification; client replies via human_input.submit + session = await _get_run_session(run_id) + if not session: + return {"error": "no upstream session for run"} + import uuid + + request_id = str(uuid.uuid4()) + payload = { + "kind": "human_input_request", + "request_id": request_id, + "prompt": {"text": prompt}, + "metadata": metadata or {}, + } try: - if asyncio.iscoroutinefunction(handler): # type: ignore[arg-type] - result = await handler( - {"prompt": prompt, "metadata": metadata or {}, "run_id": run_id} - ) - else: - result = handler( - {"prompt": prompt, "metadata": metadata or {}, "run_id": run_id} - ) - return {"result": result} + await session.send_log_message( + level="info", # type: ignore[arg-type] + data=payload, + logger="mcp_agent.human", + ) + return {"result": {"request_id": request_id}} except Exception as e: return {"error": str(e)} + @mcp.tool(name="human_input.submit") + async def human_input_submit( + request_id: str, text: str, workflow_id: str | None = None + ) -> Dict[str, Any]: + """Client replies to a human_input_request. Signal the Temporal workflow via the registry mapping. + + Note: For a full implementation you may want to persist request_id -> (workflow_id, run_id, signal_name), + but for now we only pass run_id inside client payloads if needed. This endpoint is a thin placeholder + that can be extended later to call TemporalClient.signal_workflow. + """ + # Best-effort stub; you can wire to TemporalClient.signal_workflow here if desired. + # Returning ok=True helps the client UX. 
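+        # The round trip this stub anticipates, from the client's side
+        # (a sketch: the payload shape matches proxy_ask above, while the
+        # client session API and UI hook are illustrative):
+        #
+        #     async def on_log_message(session, params: dict) -> None:
+        #         payload = params.get("data") or {}
+        #         if payload.get("kind") != "human_input_request":
+        #             return
+        #         answer = await prompt_user(payload["prompt"]["text"])  # hypothetical UI
+        #         await session.call_tool(
+        #             "human_input.submit",
+        #             arguments={"request_id": payload["request_id"], "text": answer},
+        #         )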
+ return {"ok": True, "request_id": request_id, "text": text} + # endregion # endregion From 75589d58f5e286d3e63a7adc470dbcb583cd860c Mon Sep 17 00:00:00 2001 From: Sarmad Qadri Date: Mon, 1 Sep 2025 11:18:56 -0400 Subject: [PATCH 09/24] Switch from mcp tool to custom route --- src/mcp_agent/mcp/client_proxy.py | 85 +++++++++--------- src/mcp_agent/server/app_server.py | 134 +++++++++++++++++------------ 2 files changed, 118 insertions(+), 101 deletions(-) diff --git a/src/mcp_agent/mcp/client_proxy.py b/src/mcp_agent/mcp/client_proxy.py index bfcf935b4..0cc0f345d 100644 --- a/src/mcp_agent/mcp/client_proxy.py +++ b/src/mcp_agent/mcp/client_proxy.py @@ -1,18 +1,17 @@ -import asyncio -from contextlib import asynccontextmanager -from typing import Any, Dict, AsyncIterator +from typing import Any, Dict + +import httpx -from mcp_agent.mcp.gen_client import gen_client from mcp_agent.mcp.mcp_server_registry import ServerRegistry -@asynccontextmanager -async def _proxy_client( - server_name: str, - server_registry: ServerRegistry, -) -> AsyncIterator[Any]: - async with gen_client(server_name, server_registry) as client: - yield client +def _resolve_gateway_url(server_registry: ServerRegistry, server_name: str) -> str: + cfg = server_registry.get_server_context(server_name) + # Prefer streamable-http if configured; else assume localhost: settings in examples + if cfg and getattr(cfg, "url", None): + return cfg.url.rstrip("/") + host = "http://{}:{}".format("127.0.0.1", 8000) + return host async def log_via_proxy( @@ -24,21 +23,23 @@ async def log_via_proxy( data: Dict[str, Any] | None = None, server_name: str = "basic_agent_server", ) -> bool: - async with _proxy_client(server_name, server_registry) as client: - try: - await client.call_tool( - "workflows-proxy-log", - arguments={ - "run_id": run_id, - "level": level, - "namespace": namespace, - "message": message, - "data": data or {}, - }, - ) - return True - except Exception: + base = _resolve_gateway_url(server_registry, server_name) + url = f"{base}/internal/workflows/log" + async with httpx.AsyncClient(timeout=10) as client: + r = await client.post( + url, + json={ + "run_id": run_id, + "level": level, + "namespace": namespace, + "message": message, + "data": data or {}, + }, + ) + if r.status_code >= 400: return False + resp = r.json() + return bool(resp.get("ok", False)) async def ask_via_proxy( @@ -48,23 +49,17 @@ async def ask_via_proxy( metadata: Dict[str, Any] | None = None, server_name: str = "basic_agent_server", ) -> Dict[str, Any]: - async with _proxy_client(server_name, server_registry) as client: - try: - resp = await client.call_tool( - "workflows-proxy-ask", - arguments={ - "run_id": run_id, - "prompt": prompt, - "metadata": metadata or {}, - }, - ) - sc = getattr(resp, "structuredContent", None) - if isinstance(sc, dict) and "result" in sc: - return ( - sc["result"] - if isinstance(sc["result"], dict) - else {"result": sc["result"]} - ) - return {"result": None} - except Exception as e: - return {"error": str(e)} + base = _resolve_gateway_url(server_registry, server_name) + url = f"{base}/internal/human/prompts" + async with httpx.AsyncClient(timeout=10) as client: + r = await client.post( + url, + json={ + "run_id": run_id, + "prompt": {"text": prompt}, + "metadata": metadata or {}, + }, + ) + if r.status_code >= 400: + return {"error": r.text} + return r.json() diff --git a/src/mcp_agent/server/app_server.py b/src/mcp_agent/server/app_server.py index 4a836d704..d815428aa 100644 --- a/src/mcp_agent/server/app_server.py +++ 
b/src/mcp_agent/server/app_server.py @@ -10,6 +10,8 @@ import asyncio from mcp.server.fastmcp import Context as MCPContext, FastMCP +from starlette.requests import Request +from starlette.responses import JSONResponse from mcp.server.fastmcp.exceptions import ToolError from mcp.server.fastmcp.tools import Tool as FastTool @@ -276,6 +278,71 @@ async def app_specific_lifespan(mcp: FastMCP) -> AsyncIterator[ServerContext]: # Don't clean up the MCPApp here - let the caller handle that pass + # Helper: install internal HTTP routes (not MCP tools) + def _install_internal_routes(mcp_server: FastMCP) -> None: + @mcp_server.custom_route( + "/internal/workflows/log", methods=["POST"], include_in_schema=False + ) + async def _internal_workflows_log(request: Request): + body = await request.json() + run_id = body.get("run_id") + level = str(body.get("level", "info")).lower() + namespace = body.get("namespace") or "mcp_agent" + message = body.get("message") or "" + data = body.get("data") or {} + + session = await _get_run_session(run_id) + if not session: + return JSONResponse( + {"ok": False, "error": "no session for run"}, status_code=404 + ) + if level not in ("debug", "info", "warning", "error"): + level = "info" + try: + await session.send_log_message( + level=level, # type: ignore[arg-type] + data={ + "message": message, + "namespace": namespace, + "data": data, + }, + logger=namespace, + ) + return JSONResponse({"ok": True}) + except Exception as e: + return JSONResponse({"ok": False, "error": str(e)}, status_code=500) + + @mcp_server.custom_route( + "/internal/human/prompts", methods=["POST"], include_in_schema=False + ) + async def _internal_human_prompts(request: Request): + body = await request.json() + run_id = body.get("run_id") + prompt = body.get("prompt") or {} + metadata = body.get("metadata") or {} + + session = await _get_run_session(run_id) + if not session: + return JSONResponse({"error": "no session for run"}, status_code=404) + import uuid + + request_id = str(uuid.uuid4()) + payload = { + "kind": "human_input_request", + "request_id": request_id, + "prompt": prompt if isinstance(prompt, dict) else {"text": str(prompt)}, + "metadata": metadata, + } + try: + await session.send_log_message( + level="info", # type: ignore[arg-type] + data=payload, + logger="mcp_agent.human", + ) + return JSONResponse({"request_id": request_id}) + except Exception as e: + return JSONResponse({"error": str(e)}, status_code=500) + # Create or attach FastMCP server if app.mcp: # Using an externally provided FastMCP instance: attach app and context @@ -294,6 +361,11 @@ async def app_specific_lifespan(mcp: FastMCP) -> AsyncIterator[ServerContext]: create_workflow_tools(mcp, server_context) # Register function-declared tools (from @app.tool/@app.async_tool) create_declared_function_tools(mcp, server_context) + # Install internal HTTP routes + try: + _install_internal_routes(mcp) + except Exception: + pass else: mcp = FastMCP( name=app.name or "mcp_agent_server", @@ -307,6 +379,11 @@ async def app_specific_lifespan(mcp: FastMCP) -> AsyncIterator[ServerContext]: # Store the server on the app so it's discoverable and can be extended further app.mcp = mcp setattr(mcp, "_mcp_agent_app", app) + # Install internal HTTP routes + try: + _install_internal_routes(mcp) + except Exception: + pass # Register logging/setLevel handler so client can adjust verbosity dynamically # This enables MCP logging capability in InitializeResult.capabilities.logging @@ -551,62 +628,7 @@ async def cancel_workflow( # region Proxy 
tools for external runners (e.g., Temporal workers) - @mcp.tool(name="workflows-proxy-log") - async def proxy_log( - run_id: str, - level: str, - namespace: str, - message: str, - data: Dict[str, Any] | None = None, - ) -> bool: - session = await _get_run_session(run_id) - if not session: - return False - lvl = str(level).lower() - if lvl not in ("debug", "info", "warning", "error"): - lvl = "info" - try: - await session.send_log_message( - level=lvl, # type: ignore[arg-type] - data={ - "message": message, - "namespace": namespace, - "data": data or {}, - }, - logger=namespace, - ) - return True - except Exception: - return False - - @mcp.tool(name="workflows-proxy-ask") - async def proxy_ask( - run_id: str, - prompt: str, - metadata: Dict[str, Any] | None = None, - ) -> Dict[str, Any]: - # Emit a human_input_request notification; client replies via human_input.submit - session = await _get_run_session(run_id) - if not session: - return {"error": "no upstream session for run"} - import uuid - - request_id = str(uuid.uuid4()) - payload = { - "kind": "human_input_request", - "request_id": request_id, - "prompt": {"text": prompt}, - "metadata": metadata or {}, - } - try: - await session.send_log_message( - level="info", # type: ignore[arg-type] - data=payload, - logger="mcp_agent.human", - ) - return {"result": {"request_id": request_id}} - except Exception as e: - return {"error": str(e)} + # Removed MCP tools for internal proxying in favor of private HTTP routes @mcp.tool(name="human_input.submit") async def human_input_submit( From 5d7ecfc8931d0f96c93e6447a39259ed0d93f904 Mon Sep 17 00:00:00 2001 From: Sarmad Qadri Date: Mon, 1 Sep 2025 11:40:31 -0400 Subject: [PATCH 10/24] Some fixes to determinism --- src/mcp_agent/executor/temporal/__init__.py | 2 +- src/mcp_agent/executor/workflow.py | 66 +++++++++++---------- src/mcp_agent/server/app_server.py | 38 ++++++++---- 3 files changed, 61 insertions(+), 45 deletions(-) diff --git a/src/mcp_agent/executor/temporal/__init__.py b/src/mcp_agent/executor/temporal/__init__.py index 16468ae50..b80bef236 100644 --- a/src/mcp_agent/executor/temporal/__init__.py +++ b/src/mcp_agent/executor/temporal/__init__.py @@ -176,7 +176,7 @@ async def _execute_task( try: result = await workflow.execute_activity( activity_task, - args=args, + *args, task_queue=self.config.task_queue, schedule_to_close_timeout=schedule_to_close, retry_policy=retry_policy, diff --git a/src/mcp_agent/executor/workflow.py b/src/mcp_agent/executor/workflow.py index 8921b257b..448af5706 100644 --- a/src/mcp_agent/executor/workflow.py +++ b/src/mcp_agent/executor/workflow.py @@ -22,7 +22,7 @@ ) from mcp_agent.executor.workflow_signal import Signal from mcp_agent.logging.logger import get_logger -from mcp_agent.mcp.client_proxy import log_via_proxy, ask_via_proxy +# (Temporal path now uses activities; HTTP proxy helpers unused here) if TYPE_CHECKING: from temporalio.client import WorkflowHandle @@ -376,51 +376,53 @@ async def log_upstream( data: Dict[str, Any] | None = None, ): if self.context.config.execution_engine == "temporal": + # Route via Temporal activity for determinism try: - await log_via_proxy( - self.context.server_registry, - run_id=self._run_id or "", - level=level, - namespace=namespace, - message=message, - data=data or {}, + act = self.context.task_registry.get_activity("mcp_forward_log") + await self.executor.execute( + act, + self._run_id or "", + level, + namespace, + message, + data or {}, ) except Exception: pass + return + # asyncio: use local logger + if level == 
"debug": + self._logger.debug(message, **(data or {})) + elif level == "warning": + self._logger.warning(message, **(data or {})) + elif level == "error": + self._logger.error(message, **(data or {})) else: - # asyncio: use local logger - if level == "debug": - self._logger.debug(message, **(data or {})) - elif level == "warning": - self._logger.warning(message, **(data or {})) - elif level == "error": - self._logger.error(message, **(data or {})) - else: - self._logger.info(message, **(data or {})) + self._logger.info(message, **(data or {})) async def ask_user( self, prompt: str, metadata: Dict[str, Any] | None = None ) -> Any: if self.context.config.execution_engine == "temporal": + # Route via Temporal activity for determinism; returns request_id or error try: - res = await ask_via_proxy( - self.context.server_registry, - run_id=self._run_id or "", - prompt=prompt, - metadata=metadata or {}, + act = self.context.task_registry.get_activity("mcp_request_user_input") + return await self.executor.execute( + act, + self.context.session_id or "", + self.id or self.name, + self._run_id or "", + prompt, + (metadata or {}).get("signal_name", "human_input"), ) - if isinstance(res, dict): - return res.get("result") if "result" in res else res - return res except Exception as e: return {"error": str(e)} - else: - handler = getattr(self.context, "human_input_handler", None) - if not handler: - return None - if asyncio.iscoroutinefunction(handler): # type: ignore[arg-type] - return await handler({"prompt": prompt, "metadata": metadata or {}}) - return handler({"prompt": prompt, "metadata": metadata or {}}) + handler = getattr(self.context, "human_input_handler", None) + if not handler: + return None + if asyncio.iscoroutinefunction(handler): # type: ignore[arg-type] + return await handler({"prompt": prompt, "metadata": metadata or {}}) + return handler({"prompt": prompt, "metadata": metadata or {}}) async def resume( self, signal_name: str | None = "resume", payload: str | None = None diff --git a/src/mcp_agent/server/app_server.py b/src/mcp_agent/server/app_server.py index d815428aa..38940e874 100644 --- a/src/mcp_agent/server/app_server.py +++ b/src/mcp_agent/server/app_server.py @@ -35,6 +35,7 @@ # Allows external workers (e.g., Temporal) to relay logs/prompts through MCPApp. _RUN_SESSION_REGISTRY: Dict[str, Any] = {} _RUN_SESSION_LOCK = asyncio.Lock() +_PENDING_PROMPTS: Dict[str, Dict[str, Any]] = {} async def _register_run_session(run_id: str, session: Any) -> None: @@ -631,18 +632,31 @@ async def cancel_workflow( # Removed MCP tools for internal proxying in favor of private HTTP routes @mcp.tool(name="human_input.submit") - async def human_input_submit( - request_id: str, text: str, workflow_id: str | None = None - ) -> Dict[str, Any]: - """Client replies to a human_input_request. Signal the Temporal workflow via the registry mapping. - - Note: For a full implementation you may want to persist request_id -> (workflow_id, run_id, signal_name), - but for now we only pass run_id inside client payloads if needed. This endpoint is a thin placeholder - that can be extended later to call TemporalClient.signal_workflow. - """ - # Best-effort stub; you can wire to TemporalClient.signal_workflow here if desired. - # Returning ok=True helps the client UX. 
- return {"ok": True, "request_id": request_id, "text": text} + async def human_input_submit(request_id: str, text: str) -> Dict[str, Any]: + """Client replies to a human_input_request; signal the Temporal workflow.""" + app_ref = _get_attached_app(mcp) + if app_ref is None or app_ref.context is None: + return {"ok": False, "error": "server not ready"} + info = _PENDING_PROMPTS.pop(request_id, None) + if not info: + return {"ok": False, "error": "unknown request_id"} + try: + from mcp_agent.executor.temporal import TemporalExecutor + + executor = app_ref.context.executor + if not isinstance(executor, TemporalExecutor): + return {"ok": False, "error": "temporal executor not active"} + client = await executor.ensure_client() + handle = client.get_workflow_handle( + workflow_id=info.get("workflow_id"), run_id=info.get("run_id") + ) + await handle.signal( + info.get("signal_name", "human_input"), + {"request_id": request_id, "text": text}, + ) + return {"ok": True} + except Exception as e: + return {"ok": False, "error": str(e)} # endregion From e2cd642265e76376ca967b4b8e2bcf7d5890b868 Mon Sep 17 00:00:00 2001 From: Sarmad Qadri Date: Tue, 2 Sep 2025 14:18:51 -0400 Subject: [PATCH 11/24] checkpoint - sessionproxy --- .../executor/temporal/session_proxy.py | 64 +++++++ .../executor/temporal/system_activities.py | 25 ++- src/mcp_agent/executor/workflow.py | 26 +++ src/mcp_agent/mcp/client_proxy.py | 85 +++++++-- src/mcp_agent/server/app_server.py | 178 +++++++++++++++++- 5 files changed, 360 insertions(+), 18 deletions(-) create mode 100644 src/mcp_agent/executor/temporal/session_proxy.py diff --git a/src/mcp_agent/executor/temporal/session_proxy.py b/src/mcp_agent/executor/temporal/session_proxy.py new file mode 100644 index 000000000..a41f8a85c --- /dev/null +++ b/src/mcp_agent/executor/temporal/session_proxy.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +from typing import Any, Dict, Optional + + +class SessionProxy: + """ + A 'virtual' MCP ServerSession bound to a Temporal workflow run. + + This proxy exposes a subset of the ServerSession API and routes calls + through generic Temporal activities to keep workflow code deterministic. 
+ + Methods: + - send_log_message(level, data, logger=None, related_request_id=None) + - notify(method, params) + - request(method, params) + """ + + def __init__(self, *, executor, run_id: str): + self._executor = executor + self._run_id = run_id + + @property + def run_id(self) -> str: + return self._run_id + + async def send_log_message( + self, + *, + level: str, + data: Dict[str, Any] | Any, + logger: Optional[str] = None, + related_request_id: Optional[str] = None, + ) -> None: + # Map to notifications/message via generic relay + params: Dict[str, Any] = { + "level": level, + "data": data, + "logger": logger, + } + if related_request_id is not None: + params["related_request_id"] = related_request_id + + activity = self._executor.context.task_registry.get_activity("mcp_relay_notify") + await self._executor.execute( + activity, self._run_id, "notifications/message", params + ) + + async def notify(self, method: str, params: Dict[str, Any] | None = None) -> bool: + activity = self._executor.context.task_registry.get_activity("mcp_relay_notify") + result = await self._executor.execute( + activity, self._run_id, method, params or {} + ) + return bool(result) + + async def request( + self, method: str, params: Dict[str, Any] | None = None + ) -> Dict[str, Any]: + activity = self._executor.context.task_registry.get_activity( + "mcp_relay_request" + ) + return await self._executor.execute( + activity, self._run_id, method, params or {} + ) diff --git a/src/mcp_agent/executor/temporal/system_activities.py b/src/mcp_agent/executor/temporal/system_activities.py index c483c787f..347fff88f 100644 --- a/src/mcp_agent/executor/temporal/system_activities.py +++ b/src/mcp_agent/executor/temporal/system_activities.py @@ -2,7 +2,12 @@ from temporalio import activity -from mcp_agent.mcp.client_proxy import log_via_proxy, ask_via_proxy +from mcp_agent.mcp.client_proxy import ( + log_via_proxy, + ask_via_proxy, + notify_via_proxy, + request_via_proxy, +) from mcp_agent.core.context_dependent import ContextDependent @@ -49,3 +54,21 @@ async def request_user_input( "signal_name": signal_name, }, ) + + @activity.defn(name="mcp_relay_notify") + async def relay_notify( + self, run_id: str, method: str, params: Dict[str, Any] | None = None + ) -> bool: + registry = self.context.server_registry + return await notify_via_proxy( + registry, run_id=run_id, method=method, params=params or {} + ) + + @activity.defn(name="mcp_relay_request") + async def relay_request( + self, run_id: str, method: str, params: Dict[str, Any] | None = None + ) -> Dict[str, Any]: + registry = self.context.server_registry + return await request_via_proxy( + registry, run_id=run_id, method=method, params=params or {} + ) diff --git a/src/mcp_agent/executor/workflow.py b/src/mcp_agent/executor/workflow.py index 448af5706..4acc5ca73 100644 --- a/src/mcp_agent/executor/workflow.py +++ b/src/mcp_agent/executor/workflow.py @@ -20,6 +20,7 @@ SignalMailbox, TemporalSignalHandler, ) +from mcp_agent.executor.temporal.session_proxy import SessionProxy from mcp_agent.executor.workflow_signal import Signal from mcp_agent.logging.logger import get_logger # (Temporal path now uses activities; HTTP proxy helpers unused here) @@ -251,6 +252,17 @@ async def run_async(self, *args, **kwargs) -> "WorkflowExecution": try: if self.context.config.execution_engine == "temporal": setattr(self._logger, "_temporal_run_id", self._run_id) + # Ensure upstream_session is a passthrough SessionProxy bound to this run + if ( + getattr(self.context, "upstream_session", 
None) is None + and self._run_id + ): + try: + self.context.upstream_session = SessionProxy( + executor=self.executor, run_id=self._run_id + ) + except Exception: + pass except Exception: pass @@ -790,6 +802,20 @@ async def initialize(self): "Signal handler not attached: executor.signal_bus is not a TemporalSignalHandler" ) + # Expose a virtual upstream session (passthrough) bound to this run via activities + # This lets any code use context.upstream_session like a real session. + try: + if ( + getattr(self.context, "upstream_session", None) is None + and self._run_id + ): + self.context.upstream_session = SessionProxy( + executor=self.executor, run_id=self._run_id + ) + except Exception: + # Non-fatal if context is immutable early; will be set after run_id assignment in run_async + pass + self._initialized = True self.state.updated_at = datetime.now(timezone.utc).timestamp() diff --git a/src/mcp_agent/mcp/client_proxy.py b/src/mcp_agent/mcp/client_proxy.py index 0cc0f345d..dee805127 100644 --- a/src/mcp_agent/mcp/client_proxy.py +++ b/src/mcp_agent/mcp/client_proxy.py @@ -1,29 +1,47 @@ -from typing import Any, Dict +from typing import Any, Dict, Optional +import os import httpx from mcp_agent.mcp.mcp_server_registry import ServerRegistry -def _resolve_gateway_url(server_registry: ServerRegistry, server_name: str) -> str: - cfg = server_registry.get_server_context(server_name) - # Prefer streamable-http if configured; else assume localhost: settings in examples - if cfg and getattr(cfg, "url", None): - return cfg.url.rstrip("/") - host = "http://{}:{}".format("127.0.0.1", 8000) - return host +def _resolve_gateway_url( + server_registry: Optional[ServerRegistry] = None, + server_name: Optional[str] = None, + gateway_url: Optional[str] = None, +) -> str: + # Highest precedence: explicit override + if gateway_url: + return gateway_url.rstrip("/") + + # Next: environment variable + env_url = os.environ.get("MCP_GATEWAY_URL") + if env_url: + return env_url.rstrip("/") + + # Next: a registry entry (if provided) + if server_registry and server_name: + cfg = server_registry.get_server_context(server_name) + if cfg and getattr(cfg, "url", None): + return cfg.url.rstrip("/") + + # Fallback: default local server + return "http://127.0.0.1:8000" async def log_via_proxy( - server_registry: ServerRegistry, + server_registry: Optional[ServerRegistry], run_id: str, level: str, namespace: str, message: str, data: Dict[str, Any] | None = None, - server_name: str = "basic_agent_server", + *, + server_name: Optional[str] = None, + gateway_url: Optional[str] = None, ) -> bool: - base = _resolve_gateway_url(server_registry, server_name) + base = _resolve_gateway_url(server_registry, server_name, gateway_url) url = f"{base}/internal/workflows/log" async with httpx.AsyncClient(timeout=10) as client: r = await client.post( @@ -43,13 +61,15 @@ async def log_via_proxy( async def ask_via_proxy( - server_registry: ServerRegistry, + server_registry: Optional[ServerRegistry], run_id: str, prompt: str, metadata: Dict[str, Any] | None = None, - server_name: str = "basic_agent_server", + *, + server_name: Optional[str] = None, + gateway_url: Optional[str] = None, ) -> Dict[str, Any]: - base = _resolve_gateway_url(server_registry, server_name) + base = _resolve_gateway_url(server_registry, server_name, gateway_url) url = f"{base}/internal/human/prompts" async with httpx.AsyncClient(timeout=10) as client: r = await client.post( @@ -63,3 +83,40 @@ async def ask_via_proxy( if r.status_code >= 400: return {"error": r.text} 
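        # _resolve_gateway_url (above) picks the base URL in this order
        # (values are examples only):
        #   explicit argument:  _resolve_gateway_url(None, None, "http://gw:9000/") -> "http://gw:9000"
        #   environment:        MCP_GATEWAY_URL=http://10.0.0.5:8001 -> "http://10.0.0.5:8001"
        #   registry entry:     its configured url, with any trailing "/" stripped
        #   fallback:           "http://127.0.0.1:8000"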
return r.json() + + +async def notify_via_proxy( + server_registry: Optional[ServerRegistry], + run_id: str, + method: str, + params: Dict[str, Any] | None = None, + *, + server_name: Optional[str] = None, + gateway_url: Optional[str] = None, +) -> bool: + base = _resolve_gateway_url(server_registry, server_name, gateway_url) + url = f"{base}/internal/session/by-run/{run_id}/notify" + async with httpx.AsyncClient(timeout=10) as client: + r = await client.post(url, json={"method": method, "params": params or {}}) + if r.status_code >= 400: + return False + resp = r.json() if r.content else {"ok": True} + return bool(resp.get("ok", True)) + + +async def request_via_proxy( + server_registry: Optional[ServerRegistry], + run_id: str, + method: str, + params: Dict[str, Any] | None = None, + *, + server_name: Optional[str] = None, + gateway_url: Optional[str] = None, +) -> Dict[str, Any]: + base = _resolve_gateway_url(server_registry, server_name, gateway_url) + url = f"{base}/internal/session/by-run/{run_id}/request" + async with httpx.AsyncClient(timeout=20) as client: + r = await client.post(url, json={"method": method, "params": params or {}}) + if r.status_code >= 400: + return {"error": r.text} + return r.json() diff --git a/src/mcp_agent/server/app_server.py b/src/mcp_agent/server/app_server.py index 38940e874..38835e824 100644 --- a/src/mcp_agent/server/app_server.py +++ b/src/mcp_agent/server/app_server.py @@ -7,6 +7,7 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager from typing import Any, Dict, List, Optional, Set, Tuple, Type, TYPE_CHECKING +import os import asyncio from mcp.server.fastmcp import Context as MCPContext, FastMCP @@ -36,6 +37,9 @@ _RUN_SESSION_REGISTRY: Dict[str, Any] = {} _RUN_SESSION_LOCK = asyncio.Lock() _PENDING_PROMPTS: Dict[str, Dict[str, Any]] = {} +_PENDING_PROMPTS_LOCK = asyncio.Lock() +_IDEMPOTENCY_KEYS_SEEN: Dict[str, Set[str]] = {} +_IDEMPOTENCY_KEYS_LOCK = asyncio.Lock() async def _register_run_session(run_id: str, session: Any) -> None: @@ -281,6 +285,165 @@ async def app_specific_lifespan(mcp: FastMCP) -> AsyncIterator[ServerContext]: # Helper: install internal HTTP routes (not MCP tools) def _install_internal_routes(mcp_server: FastMCP) -> None: + @mcp_server.custom_route( + "/internal/session/by-run/{run_id}/notify", + methods=["POST"], + include_in_schema=False, + ) + async def _relay_notify(request: Request): + body = await request.json() + run_id = request.path_params.get("run_id") + method = body.get("method") + params = body.get("params") or {} + + # Optional shared-secret auth + gw_token = os.environ.get("MCP_GATEWAY_TOKEN") + if gw_token and request.headers.get("X-MCP-Gateway-Token") != gw_token: + return JSONResponse( + {"ok": False, "error": "unauthorized"}, status_code=401 + ) + + # Optional idempotency handling + idempotency_key = params.get("idempotency_key") + if idempotency_key: + async with _IDEMPOTENCY_KEYS_LOCK: + seen = _IDEMPOTENCY_KEYS_SEEN.setdefault(run_id or "", set()) + if idempotency_key in seen: + return JSONResponse({"ok": True, "idempotent": True}) + seen.add(idempotency_key) + + session = await _get_run_session(run_id) + if not session: + return JSONResponse( + {"ok": False, "error": "session_not_available"}, status_code=503 + ) + + try: + # Special-case the common logging notification helper + if method == "notifications/message": + level = str(params.get("level", "info")) + data = params.get("data") + logger_name = params.get("logger") + related_request_id = 
params.get("related_request_id") + await session.send_log_message( # type: ignore[attr-defined] + level=level, # type: ignore[arg-type] + data=data, + logger=logger_name, + related_request_id=related_request_id, + ) + elif method == "notifications/progress": + # Minimal support for progress relay + progress_token = params.get("progressToken") + progress = params.get("progress") + total = params.get("total") + message = params.get("message") + await session.send_progress_notification( # type: ignore[attr-defined] + progress_token=progress_token, + progress=progress, + total=total, + message=message, + ) + elif method == "notifications/resources/list_changed": + await session.send_resource_list_changed() # type: ignore[attr-defined] + elif method == "notifications/tools/list_changed": + await session.send_tool_list_changed() # type: ignore[attr-defined] + elif method == "notifications/prompts/list_changed": + await session.send_prompt_list_changed() # type: ignore[attr-defined] + else: + # Unsupported generic notification at this layer + return JSONResponse( + {"ok": False, "error": f"unsupported method: {method}"}, + status_code=400, + ) + + return JSONResponse({"ok": True}) + except Exception as e: + return JSONResponse({"ok": False, "error": str(e)}, status_code=500) + + @mcp_server.custom_route( + "/internal/session/by-run/{run_id}/request", + methods=["POST"], + include_in_schema=False, + ) + async def _relay_request(request: Request): + from mcp.types import ( + CreateMessageRequest, + CreateMessageRequestParams, + CreateMessageResult, + ElicitRequest, + ElicitRequestParams, + ElicitResult, + ListRootsRequest, + ListRootsResult, + PingRequest, + EmptyResult, + ServerRequest, + ) + + body = await request.json() + run_id = request.path_params.get("run_id") + method = body.get("method") + params = body.get("params") or {} + + session = await _get_run_session(run_id) + if not session: + return JSONResponse({"error": "session_not_available"}, status_code=503) + + try: + # Map a small set of supported server->client requests + if method == "sampling/createMessage": + req = ServerRequest( + CreateMessageRequest( + method="sampling/createMessage", + params=CreateMessageRequestParams(**params), + ) + ) + result = await session.send_request( # type: ignore[attr-defined] + request=req, + result_type=CreateMessageResult, + ) + return JSONResponse( + result.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + elif method == "elicitation/create": + req = ServerRequest( + ElicitRequest( + method="elicitation/create", + params=ElicitRequestParams(**params), + ) + ) + result = await session.send_request( # type: ignore[attr-defined] + request=req, + result_type=ElicitResult, + ) + return JSONResponse( + result.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + elif method == "roots/list": + req = ServerRequest(ListRootsRequest(method="roots/list")) + result = await session.send_request( # type: ignore[attr-defined] + request=req, + result_type=ListRootsResult, + ) + return JSONResponse( + result.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + elif method == "ping": + req = ServerRequest(PingRequest(method="ping")) + result = await session.send_request( # type: ignore[attr-defined] + request=req, + result_type=EmptyResult, + ) + return JSONResponse( + result.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + else: + return JSONResponse( + {"error": f"unsupported method: {method}"}, status_code=400 + ) + except Exception as e: + return 
JSONResponse({"error": str(e)}, status_code=500) + @mcp_server.custom_route( "/internal/workflows/log", methods=["POST"], include_in_schema=False ) @@ -295,7 +458,7 @@ async def _internal_workflows_log(request: Request): session = await _get_run_session(run_id) if not session: return JSONResponse( - {"ok": False, "error": "no session for run"}, status_code=404 + {"ok": False, "error": "session_not_available"}, status_code=503 ) if level not in ("debug", "info", "warning", "error"): level = "info" @@ -324,7 +487,7 @@ async def _internal_human_prompts(request: Request): session = await _get_run_session(run_id) if not session: - return JSONResponse({"error": "no session for run"}, status_code=404) + return JSONResponse({"error": "session_not_available"}, status_code=503) import uuid request_id = str(uuid.uuid4()) @@ -335,6 +498,14 @@ async def _internal_human_prompts(request: Request): "metadata": metadata, } try: + # Store pending prompt correlation for submit tool + async with _PENDING_PROMPTS_LOCK: + _PENDING_PROMPTS[request_id] = { + "workflow_id": metadata.get("workflow_id"), + "run_id": run_id, + "signal_name": metadata.get("signal_name", "human_input"), + "session_id": metadata.get("session_id"), + } await session.send_log_message( level="info", # type: ignore[arg-type] data=payload, @@ -637,7 +808,8 @@ async def human_input_submit(request_id: str, text: str) -> Dict[str, Any]: app_ref = _get_attached_app(mcp) if app_ref is None or app_ref.context is None: return {"ok": False, "error": "server not ready"} - info = _PENDING_PROMPTS.pop(request_id, None) + async with _PENDING_PROMPTS_LOCK: + info = _PENDING_PROMPTS.pop(request_id, None) if not info: return {"ok": False, "error": "unknown request_id"} try: From 67025a184cb91dc756984fd7a9e8aca3927e177d Mon Sep 17 00:00:00 2001 From: Sarmad Qadri Date: Tue, 2 Sep 2025 14:27:51 -0400 Subject: [PATCH 12/24] Some more fixes --- src/mcp_agent/executor/temporal/__init__.py | 8 +++-- src/mcp_agent/mcp/client_proxy.py | 38 +++++++++++++++++---- src/mcp_agent/server/app_server.py | 38 ++++++++++++++------- 3 files changed, 63 insertions(+), 21 deletions(-) diff --git a/src/mcp_agent/executor/temporal/__init__.py b/src/mcp_agent/executor/temporal/__init__.py index b80bef236..d88a2070a 100644 --- a/src/mcp_agent/executor/temporal/__init__.py +++ b/src/mcp_agent/executor/temporal/__init__.py @@ -491,10 +491,12 @@ async def create_temporal_worker_for_app(app: "MCPApp"): # Collect activities from the global registry activity_registry = running_app.context.task_registry - # Register system activities (logging, human input proxy) + # Register system activities (logging, human input proxy, generic relays) sys_acts = SystemActivities(context=running_app.context) - app.workflow_task()(sys_acts.forward_log) - app.workflow_task()(sys_acts.request_user_input) + app.workflow_task(name="mcp_forward_log")(sys_acts.forward_log) + app.workflow_task(name="mcp_request_user_input")(sys_acts.request_user_input) + app.workflow_task(name="mcp_relay_notify")(sys_acts.relay_notify) + app.workflow_task(name="mcp_relay_request")(sys_acts.relay_request) for name in activity_registry.list_activities(): activities.append(activity_registry.get_activity(name)) diff --git a/src/mcp_agent/mcp/client_proxy.py b/src/mcp_agent/mcp/client_proxy.py index dee805127..f8cff746c 100644 --- a/src/mcp_agent/mcp/client_proxy.py +++ b/src/mcp_agent/mcp/client_proxy.py @@ -43,7 +43,12 @@ async def log_via_proxy( ) -> bool: base = _resolve_gateway_url(server_registry, server_name, gateway_url) url 
= f"{base}/internal/workflows/log" - async with httpx.AsyncClient(timeout=10) as client: + headers: Dict[str, str] = {} + tok = os.environ.get("MCP_GATEWAY_TOKEN") + if tok: + headers["X-MCP-Gateway-Token"] = tok + timeout = float(os.environ.get("MCP_GATEWAY_TIMEOUT", "10")) + async with httpx.AsyncClient(timeout=timeout) as client: r = await client.post( url, json={ @@ -53,6 +58,7 @@ async def log_via_proxy( "message": message, "data": data or {}, }, + headers=headers, ) if r.status_code >= 400: return False @@ -71,7 +77,12 @@ async def ask_via_proxy( ) -> Dict[str, Any]: base = _resolve_gateway_url(server_registry, server_name, gateway_url) url = f"{base}/internal/human/prompts" - async with httpx.AsyncClient(timeout=10) as client: + headers: Dict[str, str] = {} + tok = os.environ.get("MCP_GATEWAY_TOKEN") + if tok: + headers["X-MCP-Gateway-Token"] = tok + timeout = float(os.environ.get("MCP_GATEWAY_TIMEOUT", "10")) + async with httpx.AsyncClient(timeout=timeout) as client: r = await client.post( url, json={ @@ -79,6 +90,7 @@ async def ask_via_proxy( "prompt": {"text": prompt}, "metadata": metadata or {}, }, + headers=headers, ) if r.status_code >= 400: return {"error": r.text} @@ -96,8 +108,15 @@ async def notify_via_proxy( ) -> bool: base = _resolve_gateway_url(server_registry, server_name, gateway_url) url = f"{base}/internal/session/by-run/{run_id}/notify" - async with httpx.AsyncClient(timeout=10) as client: - r = await client.post(url, json={"method": method, "params": params or {}}) + headers: Dict[str, str] = {} + tok = os.environ.get("MCP_GATEWAY_TOKEN") + if tok: + headers["X-MCP-Gateway-Token"] = tok + timeout = float(os.environ.get("MCP_GATEWAY_TIMEOUT", "10")) + async with httpx.AsyncClient(timeout=timeout) as client: + r = await client.post( + url, json={"method": method, "params": params or {}}, headers=headers + ) if r.status_code >= 400: return False resp = r.json() if r.content else {"ok": True} @@ -115,8 +134,15 @@ async def request_via_proxy( ) -> Dict[str, Any]: base = _resolve_gateway_url(server_registry, server_name, gateway_url) url = f"{base}/internal/session/by-run/{run_id}/request" - async with httpx.AsyncClient(timeout=20) as client: - r = await client.post(url, json={"method": method, "params": params or {}}) + headers: Dict[str, str] = {} + tok = os.environ.get("MCP_GATEWAY_TOKEN") + if tok: + headers["X-MCP-Gateway-Token"] = tok + timeout = float(os.environ.get("MCP_GATEWAY_TIMEOUT", "20")) + async with httpx.AsyncClient(timeout=timeout) as client: + r = await client.post( + url, json={"method": method, "params": params or {}}, headers=headers + ) if r.status_code >= 400: return {"error": r.text} return r.json() diff --git a/src/mcp_agent/server/app_server.py b/src/mcp_agent/server/app_server.py index 38835e824..20ab08ff9 100644 --- a/src/mcp_agent/server/app_server.py +++ b/src/mcp_agent/server/app_server.py @@ -343,18 +343,16 @@ async def _relay_notify(request: Request): total=total, message=message, ) - elif method == "notifications/resources/list_changed": - await session.send_resource_list_changed() # type: ignore[attr-defined] - elif method == "notifications/tools/list_changed": - await session.send_tool_list_changed() # type: ignore[attr-defined] - elif method == "notifications/prompts/list_changed": - await session.send_prompt_list_changed() # type: ignore[attr-defined] else: - # Unsupported generic notification at this layer - return JSONResponse( - {"ok": False, "error": f"unsupported method: {method}"}, - status_code=400, - ) + # Generic 
passthrough using low-level RPC if available + rpc = getattr(session, "rpc", None) + if rpc and hasattr(rpc, "notify"): + await rpc.notify(method, params) + else: + return JSONResponse( + {"ok": False, "error": f"unsupported method: {method}"}, + status_code=400, + ) return JSONResponse({"ok": True}) except Exception as e: @@ -390,7 +388,12 @@ async def _relay_request(request: Request): return JSONResponse({"error": "session_not_available"}, status_code=503) try: - # Map a small set of supported server->client requests + # Prefer generic request passthrough if available + rpc = getattr(session, "rpc", None) + if rpc and hasattr(rpc, "request"): + result = await rpc.request(method, params) + return JSONResponse(result) + # Fallback: Map a small set of supported server->client requests if method == "sampling/createMessage": req = ServerRequest( CreateMessageRequest( @@ -455,6 +458,13 @@ async def _internal_workflows_log(request: Request): message = body.get("message") or "" data = body.get("data") or {} + # Optional shared-secret auth + gw_token = os.environ.get("MCP_GATEWAY_TOKEN") + if gw_token and request.headers.get("X-MCP-Gateway-Token") != gw_token: + return JSONResponse( + {"ok": False, "error": "unauthorized"}, status_code=401 + ) + session = await _get_run_session(run_id) if not session: return JSONResponse( @@ -484,6 +494,10 @@ async def _internal_human_prompts(request: Request): run_id = body.get("run_id") prompt = body.get("prompt") or {} metadata = body.get("metadata") or {} + # Optional shared-secret auth + gw_token = os.environ.get("MCP_GATEWAY_TOKEN") + if gw_token and request.headers.get("X-MCP-Gateway-Token") != gw_token: + return JSONResponse({"error": "unauthorized"}, status_code=401) session = await _get_run_session(run_id) if not session: From 5ff81d3b6cf513ee48fd9e0383dabe9593a3c2b9 Mon Sep 17 00:00:00 2001 From: Sarmad Qadri Date: Wed, 3 Sep 2025 15:27:00 -0400 Subject: [PATCH 13/24] temp --- src/mcp_agent/core/context.py | 4 ++ src/mcp_agent/executor/temporal/__init__.py | 3 ++ .../executor/temporal/system_activities.py | 26 +++++++++- src/mcp_agent/executor/workflow.py | 35 +++++++++++++ src/mcp_agent/mcp/client_proxy.py | 12 +++-- src/mcp_agent/server/app_server.py | 49 ++++++++++++++++++- 6 files changed, 122 insertions(+), 7 deletions(-) diff --git a/src/mcp_agent/core/context.py b/src/mcp_agent/core/context.py index 462e62653..d449c938a 100644 --- a/src/mcp_agent/core/context.py +++ b/src/mcp_agent/core/context.py @@ -89,6 +89,10 @@ class Context(BaseModel): # Token counting and cost tracking token_counter: Optional[TokenCounter] = None + # Dynamic gateway configuration (per-run overrides via Temporal memo) + gateway_url: str | None = None + gateway_token: str | None = None + model_config = ConfigDict( extra="allow", arbitrary_types_allowed=True, # Tell Pydantic to defer type evaluation diff --git a/src/mcp_agent/executor/temporal/__init__.py b/src/mcp_agent/executor/temporal/__init__.py index d88a2070a..f0aec1bda 100644 --- a/src/mcp_agent/executor/temporal/__init__.py +++ b/src/mcp_agent/executor/temporal/__init__.py @@ -279,6 +279,7 @@ async def start_workflow( wait_for_result: bool = False, workflow_id: str | None = None, task_queue: str | None = None, + workflow_memo: Dict[str, Any] | None = None, **kwargs: Any, ) -> WorkflowHandle: """ @@ -363,6 +364,7 @@ async def start_workflow( task_queue=task_queue, id_reuse_policy=id_reuse_policy, rpc_metadata=self.config.rpc_metadata or {}, + memo=workflow_memo or {}, ) else: handle: WorkflowHandle = await 
self.client.start_workflow( @@ -371,6 +373,7 @@ async def start_workflow( task_queue=task_queue, id_reuse_policy=id_reuse_policy, rpc_metadata=self.config.rpc_metadata or {}, + memo=workflow_memo or {}, ) # Wait for the result if requested diff --git a/src/mcp_agent/executor/temporal/system_activities.py b/src/mcp_agent/executor/temporal/system_activities.py index 347fff88f..cd6d2b2a4 100644 --- a/src/mcp_agent/executor/temporal/system_activities.py +++ b/src/mcp_agent/executor/temporal/system_activities.py @@ -24,6 +24,8 @@ async def forward_log( data: Dict[str, Any] | None = None, ) -> bool: registry = self.context.server_registry + gateway_url = getattr(self.context, "gateway_url", None) + gateway_token = getattr(self.context, "gateway_token", None) return await log_via_proxy( registry, run_id=run_id, @@ -31,6 +33,8 @@ async def forward_log( namespace=namespace, message=message, data=data or {}, + gateway_url=gateway_url, + gateway_token=gateway_token, ) @activity.defn(name="mcp_request_user_input") @@ -44,6 +48,8 @@ async def request_user_input( ) -> Dict[str, Any]: # Reuse proxy ask API; returns {result} or {error} registry = self.context.server_registry + gateway_url = getattr(self.context, "gateway_url", None) + gateway_token = getattr(self.context, "gateway_token", None) return await ask_via_proxy( registry, run_id=run_id, @@ -53,6 +59,8 @@ async def request_user_input( "workflow_id": workflow_id, "signal_name": signal_name, }, + gateway_url=gateway_url, + gateway_token=gateway_token, ) @activity.defn(name="mcp_relay_notify") @@ -60,8 +68,15 @@ async def relay_notify( self, run_id: str, method: str, params: Dict[str, Any] | None = None ) -> bool: registry = self.context.server_registry + gateway_url = getattr(self.context, "gateway_url", None) + gateway_token = getattr(self.context, "gateway_token", None) return await notify_via_proxy( - registry, run_id=run_id, method=method, params=params or {} + registry, + run_id=run_id, + method=method, + params=params or {}, + gateway_url=gateway_url, + gateway_token=gateway_token, ) @activity.defn(name="mcp_relay_request") @@ -69,6 +84,13 @@ async def relay_request( self, run_id: str, method: str, params: Dict[str, Any] | None = None ) -> Dict[str, Any]: registry = self.context.server_registry + gateway_url = getattr(self.context, "gateway_url", None) + gateway_token = getattr(self.context, "gateway_token", None) return await request_via_proxy( - registry, run_id=run_id, method=method, params=params or {} + registry, + run_id=run_id, + method=method, + params=params or {}, + gateway_url=gateway_url, + gateway_token=gateway_token, ) diff --git a/src/mcp_agent/executor/workflow.py b/src/mcp_agent/executor/workflow.py index 4acc5ca73..9197700d2 100644 --- a/src/mcp_agent/executor/workflow.py +++ b/src/mcp_agent/executor/workflow.py @@ -218,6 +218,7 @@ async def run_async(self, *args, **kwargs) -> "WorkflowExecution": # Using __mcp_agent_ prefix to avoid conflicts with user parameters provided_workflow_id = kwargs.pop("__mcp_agent_workflow_id", None) provided_task_queue = kwargs.pop("__mcp_agent_task_queue", None) + workflow_memo = kwargs.pop("__mcp_agent_workflow_memo", None) self.update_status("scheduled") @@ -235,6 +236,7 @@ async def run_async(self, *args, **kwargs) -> "WorkflowExecution": *args, workflow_id=provided_workflow_id, task_queue=provided_task_queue, + workflow_memo=workflow_memo, **kwargs, ) self._workflow_id = handle.id @@ -802,6 +804,39 @@ async def initialize(self): "Signal handler not attached: executor.signal_bus is not a 
TemporalSignalHandler" ) + # Read memo (if any) and set gateway overrides on context for activities + try: + from temporalio import workflow as _twf + + # Preferred API: direct memo mapping from Temporal runtime + memo_map = None + try: + memo_map = _twf.memo() + except Exception: + # Fallback to info().memo if available + try: + _info = _twf.info() + memo_map = getattr(_info, "memo", None) + except Exception: + memo_map = None + + if isinstance(memo_map, dict): + gw = memo_map.get("gateway_url") + gt = memo_map.get("gateway_token") + if gw: + try: + self.context.gateway_url = gw + except Exception: + pass + if gt: + try: + self.context.gateway_token = gt + except Exception: + pass + except Exception: + # Safe to ignore if called outside workflow sandbox or memo unavailable + pass + # Expose a virtual upstream session (passthrough) bound to this run via activities # This lets any code use context.upstream_session like a real session. try: diff --git a/src/mcp_agent/mcp/client_proxy.py b/src/mcp_agent/mcp/client_proxy.py index f8cff746c..80e9feec6 100644 --- a/src/mcp_agent/mcp/client_proxy.py +++ b/src/mcp_agent/mcp/client_proxy.py @@ -40,11 +40,12 @@ async def log_via_proxy( *, server_name: Optional[str] = None, gateway_url: Optional[str] = None, + gateway_token: Optional[str] = None, ) -> bool: base = _resolve_gateway_url(server_registry, server_name, gateway_url) url = f"{base}/internal/workflows/log" headers: Dict[str, str] = {} - tok = os.environ.get("MCP_GATEWAY_TOKEN") + tok = gateway_token or os.environ.get("MCP_GATEWAY_TOKEN") if tok: headers["X-MCP-Gateway-Token"] = tok timeout = float(os.environ.get("MCP_GATEWAY_TIMEOUT", "10")) @@ -74,11 +75,12 @@ async def ask_via_proxy( *, server_name: Optional[str] = None, gateway_url: Optional[str] = None, + gateway_token: Optional[str] = None, ) -> Dict[str, Any]: base = _resolve_gateway_url(server_registry, server_name, gateway_url) url = f"{base}/internal/human/prompts" headers: Dict[str, str] = {} - tok = os.environ.get("MCP_GATEWAY_TOKEN") + tok = gateway_token or os.environ.get("MCP_GATEWAY_TOKEN") if tok: headers["X-MCP-Gateway-Token"] = tok timeout = float(os.environ.get("MCP_GATEWAY_TIMEOUT", "10")) @@ -105,11 +107,12 @@ async def notify_via_proxy( *, server_name: Optional[str] = None, gateway_url: Optional[str] = None, + gateway_token: Optional[str] = None, ) -> bool: base = _resolve_gateway_url(server_registry, server_name, gateway_url) url = f"{base}/internal/session/by-run/{run_id}/notify" headers: Dict[str, str] = {} - tok = os.environ.get("MCP_GATEWAY_TOKEN") + tok = gateway_token or os.environ.get("MCP_GATEWAY_TOKEN") if tok: headers["X-MCP-Gateway-Token"] = tok timeout = float(os.environ.get("MCP_GATEWAY_TIMEOUT", "10")) @@ -131,11 +134,12 @@ async def request_via_proxy( *, server_name: Optional[str] = None, gateway_url: Optional[str] = None, + gateway_token: Optional[str] = None, ) -> Dict[str, Any]: base = _resolve_gateway_url(server_registry, server_name, gateway_url) url = f"{base}/internal/session/by-run/{run_id}/request" headers: Dict[str, str] = {} - tok = os.environ.get("MCP_GATEWAY_TOKEN") + tok = gateway_token or os.environ.get("MCP_GATEWAY_TOKEN") if tok: headers["X-MCP-Gateway-Token"] = tok timeout = float(os.environ.get("MCP_GATEWAY_TIMEOUT", "20")) diff --git a/src/mcp_agent/server/app_server.py b/src/mcp_agent/server/app_server.py index 20ab08ff9..1ee591df7 100644 --- a/src/mcp_agent/server/app_server.py +++ b/src/mcp_agent/server/app_server.py @@ -1287,8 +1287,55 @@ async def _workflow_run( if task_queue: 
run_parameters["__mcp_agent_task_queue"] = task_queue + # Build memo for Temporal runs if gateway info is available + workflow_memo = None + try: + # Prefer explicit kwargs, else infer from request headers/environment + # FastMCP keeps raw request under ctx.request_context.request if available + gateway_url = kwargs.get("gateway_url") + gateway_token = kwargs.get("gateway_token") + + if gateway_url is None: + try: + req = getattr(ctx.request_context, "request", None) + if req is not None: + # Custom header if present + h = req.headers + gateway_url = ( + h.get("X-MCP-Gateway-URL") + or h.get("X-Forwarded-Url") + or h.get("X-Forwarded-Proto") + ) + # Best-effort reconstruction if only proto/host provided + if gateway_url is None: + proto = h.get("X-Forwarded-Proto") or "http" + host = h.get("X-Forwarded-Host") or h.get("Host") + if host: + gateway_url = f"{proto}://{host}" + except Exception: + pass + + if gateway_token is None: + try: + req = getattr(ctx.request_context, "request", None) + if req is not None: + gateway_token = req.headers.get("X-MCP-Gateway-Token") + except Exception: + pass + + if gateway_url or gateway_token: + workflow_memo = { + "gateway_url": gateway_url, + "gateway_token": gateway_token, + } + except Exception: + workflow_memo = None + # Run the workflow asynchronously and get its ID - execution = await workflow.run_async(**run_parameters) + execution = await workflow.run_async( + __mcp_agent_workflow_memo=workflow_memo, + **run_parameters, + ) logger.info( f"Workflow {workflow_name} started with workflow ID {execution.workflow_id} and run ID {execution.run_id}. Parameters: {run_parameters}" From 24bc4f716ae7a1a9e97eb78c4a32fab1086ba1d9 Mon Sep 17 00:00:00 2001 From: Roman van der Krogt Date: Thu, 4 Sep 2025 18:37:16 +0100 Subject: [PATCH 14/24] fixes for temporal --- .../asyncio/basic_agent_server.py | 3 +- .../temporal/basic_agent_server.py | 8 +++ examples/mcp_agent_server/temporal/client.py | 34 ++++++++++- src/mcp_agent/app.py | 7 +++ src/mcp_agent/core/context.py | 1 + src/mcp_agent/executor/temporal/__init__.py | 1 + .../executor/temporal/session_proxy.py | 34 ++++++----- src/mcp_agent/executor/workflow.py | 24 ++++++-- src/mcp_agent/logging/logger.py | 8 +-- src/mcp_agent/logging/transport.py | 18 ++++++ src/mcp_agent/mcp/client_proxy.py | 1 + src/mcp_agent/server/app_server.py | 57 +++++++++++-------- 12 files changed, 141 insertions(+), 55 deletions(-) diff --git a/examples/mcp_agent_server/asyncio/basic_agent_server.py b/examples/mcp_agent_server/asyncio/basic_agent_server.py index 9a8476005..22d02c437 100644 --- a/examples/mcp_agent_server/asyncio/basic_agent_server.py +++ b/examples/mcp_agent_server/asyncio/basic_agent_server.py @@ -10,10 +10,9 @@ import argparse import asyncio import os -import logging from typing import Dict, Any, Optional -from mcp.server.fastmcp import FastMCP, Context as MCPContext +from mcp.server.fastmcp import FastMCP from mcp_agent.core.context import Context as AppContext from mcp_agent.app import MCPApp diff --git a/examples/mcp_agent_server/temporal/basic_agent_server.py b/examples/mcp_agent_server/temporal/basic_agent_server.py index d56a492fc..92f44de2b 100644 --- a/examples/mcp_agent_server/temporal/basic_agent_server.py +++ b/examples/mcp_agent_server/temporal/basic_agent_server.py @@ -58,12 +58,20 @@ async def run( context = app.context context.config.mcp.servers["filesystem"].args.extend([os.getcwd()]) + # Use of the app.logger will forward logs back to the mcp client + app_logger = app.logger + + 
app_logger.info("Starting finder agent") async with finder_agent: finder_llm = await finder_agent.attach_llm(OpenAIAugmentedLLM) result = await finder_llm.generate_str( message=input, ) + + # forwards the log to the caller + app_logger.info(f"Finder agent completed with result {result}") + # print to the console (for when running locally) print(f"Agent result: {result}") return WorkflowResult(value=result) diff --git a/examples/mcp_agent_server/temporal/client.py b/examples/mcp_agent_server/temporal/client.py index 634dea5c3..fef970f6d 100644 --- a/examples/mcp_agent_server/temporal/client.py +++ b/examples/mcp_agent_server/temporal/client.py @@ -7,6 +7,12 @@ from mcp_agent.executor.workflow import WorkflowExecution from mcp_agent.mcp.gen_client import gen_client +from datetime import timedelta +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from mcp import ClientSession +from mcp.types import LoggingMessageNotificationParams +from mcp_agent.mcp.mcp_agent_client_session import MCPAgentClientSession + async def main(): # Create MCPApp to get the server registry @@ -27,7 +33,33 @@ async def main(): ) # Connect to the workflow server - async with gen_client("basic_agent_server", context.server_registry) as server: + # Define a logging callback to receive server-side log notifications + async def on_server_log(params: LoggingMessageNotificationParams) -> None: + # Pretty-print server logs locally for demonstration + level = params.level.upper() + name = params.logger or "server" + # params.data can be any JSON-serializable data + print(f"[SERVER LOG] [{level}] [{name}] {params.data}") + + # Provide a client session factory that installs our logging callback + def make_session( + read_stream: MemoryObjectReceiveStream, + write_stream: MemoryObjectSendStream, + read_timeout_seconds: timedelta | None, + ) -> ClientSession: + return MCPAgentClientSession( + read_stream=read_stream, + write_stream=write_stream, + read_timeout_seconds=read_timeout_seconds, + logging_callback=on_server_log, + ) + + # Connect to the workflow server + async with gen_client( + "basic_agent_server", + context.server_registry, + client_session_factory=make_session, + ) as server: # Call the BasicAgentWorkflow run_result = await server.call_tool( "workflows-BasicAgentWorkflow-run", diff --git a/src/mcp_agent/app.py b/src/mcp_agent/app.py index 95d617cac..261153429 100644 --- a/src/mcp_agent/app.py +++ b/src/mcp_agent/app.py @@ -898,6 +898,13 @@ def decorator(target: Callable[..., R]) -> Callable[..., R]: ) if task_defn: + # prevent trying to decorate an already decorated function + if hasattr(target, "__temporal_activity_definition"): + self.logger.debug( + f"target {name} has __temporal_activity_definition" + ) + return target # Already decorated with @activity + if isinstance(target, MethodType): self_ref = target.__self__ diff --git a/src/mcp_agent/core/context.py b/src/mcp_agent/core/context.py index d449c938a..c43047bd9 100644 --- a/src/mcp_agent/core/context.py +++ b/src/mcp_agent/core/context.py @@ -92,6 +92,7 @@ class Context(BaseModel): # Dynamic gateway configuration (per-run overrides via Temporal memo) gateway_url: str | None = None gateway_token: str | None = None + execution_id: str | None = None model_config = ConfigDict( extra="allow", diff --git a/src/mcp_agent/executor/temporal/__init__.py b/src/mcp_agent/executor/temporal/__init__.py index f0aec1bda..380423014 100644 --- a/src/mcp_agent/executor/temporal/__init__.py +++ b/src/mcp_agent/executor/temporal/__init__.py @@ 
-32,6 +32,7 @@ from mcp_agent.config import TemporalSettings from mcp_agent.executor.executor import Executor, ExecutorConfig, R + from mcp_agent.executor.temporal.workflow_signal import TemporalSignalHandler from mcp_agent.executor.workflow_signal import SignalHandler from mcp_agent.logging.logger import get_logger diff --git a/src/mcp_agent/executor/temporal/session_proxy.py b/src/mcp_agent/executor/temporal/session_proxy.py index a41f8a85c..b4563297c 100644 --- a/src/mcp_agent/executor/temporal/session_proxy.py +++ b/src/mcp_agent/executor/temporal/session_proxy.py @@ -1,7 +1,8 @@ -from __future__ import annotations - from typing import Any, Dict, Optional +from mcp_agent.core.context import Context +from mcp_agent.executor.temporal.system_activities import SystemActivities + class SessionProxy: """ @@ -16,13 +17,14 @@ class SessionProxy: - request(method, params) """ - def __init__(self, *, executor, run_id: str): + def __init__(self, *, executor, execution_id: str, context: Context): self._executor = executor - self._run_id = run_id + self._execution_id = execution_id + self.sys_acts = SystemActivities(context) @property - def run_id(self) -> str: - return self._run_id + def execution_id(self) -> str: + return self._execution_id async def send_log_message( self, @@ -41,24 +43,20 @@ async def send_log_message( if related_request_id is not None: params["related_request_id"] = related_request_id - activity = self._executor.context.task_registry.get_activity("mcp_relay_notify") - await self._executor.execute( - activity, self._run_id, "notifications/message", params + # We are outside of the temporal loop. So even though we'd like to do something like + # result = await self._executor.execute(self.sys_acts.relay_notify, self.execution_id, "notifications/message", params) + # we can't. 
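+        # Outside the Temporal workflow sandbox the activity implementation is
+        # just an ordinary coroutine, so we invoke it directly instead: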
+ await self.sys_acts.relay_notify( + self.execution_id, "notifications/message", params ) async def notify(self, method: str, params: Dict[str, Any] | None = None) -> bool: - activity = self._executor.context.task_registry.get_activity("mcp_relay_notify") - result = await self._executor.execute( - activity, self._run_id, method, params or {} + result = await self.sys_acts.relay_notify( + self.execution_id, method, params or {} ) return bool(result) async def request( self, method: str, params: Dict[str, Any] | None = None ) -> Dict[str, Any]: - activity = self._executor.context.task_registry.get_activity( - "mcp_relay_request" - ) - return await self._executor.execute( - activity, self._run_id, method, params or {} - ) + await self.sys_acts.relay_request(self.execution_id, method, params or {}) diff --git a/src/mcp_agent/executor/workflow.py b/src/mcp_agent/executor/workflow.py index 9197700d2..823584e61 100644 --- a/src/mcp_agent/executor/workflow.py +++ b/src/mcp_agent/executor/workflow.py @@ -257,11 +257,13 @@ async def run_async(self, *args, **kwargs) -> "WorkflowExecution": # Ensure upstream_session is a passthrough SessionProxy bound to this run if ( getattr(self.context, "upstream_session", None) is None - and self._run_id + and self.context.execution_id ): try: self.context.upstream_session = SessionProxy( - executor=self.executor, run_id=self._run_id + executor=self.executor, + execution_id=self.context.execution_id, + context=self.context, ) except Exception: pass @@ -823,6 +825,7 @@ async def initialize(self): if isinstance(memo_map, dict): gw = memo_map.get("gateway_url") gt = memo_map.get("gateway_token") + e_id = memo_map.get("execution_id") if gw: try: self.context.gateway_url = gw @@ -833,6 +836,11 @@ async def initialize(self): self.context.gateway_token = gt except Exception: pass + if e_id: + try: + self.context.execution_id = e_id + except Exception: + pass except Exception: # Safe to ignore if called outside workflow sandbox or memo unavailable pass @@ -842,11 +850,19 @@ async def initialize(self): try: if ( getattr(self.context, "upstream_session", None) is None - and self._run_id + and self.context.execution_id ): self.context.upstream_session = SessionProxy( - executor=self.executor, run_id=self._run_id + executor=self.executor, + execution_id=self.context.execution_id, + context=self.context, ) + + app = self.context.app + if app: + # Ensure the app's logger is bound to the current context with upstream_session + if app._logger and hasattr(app._logger, "_bound_context"): + app._logger._bound_context = self.context except Exception: # Non-fatal if context is immutable early; will be set after run_id assignment in run_async pass diff --git a/src/mcp_agent/logging/logger.py b/src/mcp_agent/logging/logger.py index dc64b1d0c..09357c77d 100644 --- a/src/mcp_agent/logging/logger.py +++ b/src/mcp_agent/logging/logger.py @@ -78,12 +78,7 @@ def _emit_event(self, event: Event): # Handle Temporal workflow environment where run_until_complete() is not implemented # In Temporal, we can't block on async operations, so we'll need to avoid this # Simply log to stdout/stderr as a fallback - import sys - - print( - f"[{event.type}] {event.namespace}: {event.message}", - file=sys.stderr, - ) + self.event_bus.emit_with_stderr_transport(event) def event( self, @@ -113,6 +108,7 @@ def event( if getattr(self, "_bound_context", None) is not None else None ) + if upstream is not None: extra_event_fields["upstream_session"] = upstream except Exception: diff --git 
a/src/mcp_agent/logging/transport.py b/src/mcp_agent/logging/transport.py index 2bf78a968..c067712fe 100644 --- a/src/mcp_agent/logging/transport.py +++ b/src/mcp_agent/logging/transport.py @@ -8,6 +8,7 @@ import json import uuid import datetime +import sys from abc import ABC, abstractmethod from typing import Dict, List, Protocol from pathlib import Path @@ -432,6 +433,23 @@ async def emit(self, event: Event): # Then queue for listeners await self._queue.put(event) + def emit_with_stderr_transport(self, event: Event): + print( + f"[{event.type}] {event.namespace}: {event.message}", + file=sys.stderr, + ) + + # Initialize queue and start processing if needed + if not hasattr(self, "_queue"): + self.init_queue() + # Auto-start the event processing task if not running + if not self._running: + self._running = True + self._task = asyncio.create_task(self._process_events()) + + # Then queue for listeners + self._queue.put_nowait(event) + async def _send_to_transport(self, event: Event): """Send event to transport with error handling.""" try: diff --git a/src/mcp_agent/mcp/client_proxy.py b/src/mcp_agent/mcp/client_proxy.py index 80e9feec6..8250a8aed 100644 --- a/src/mcp_agent/mcp/client_proxy.py +++ b/src/mcp_agent/mcp/client_proxy.py @@ -116,6 +116,7 @@ async def notify_via_proxy( if tok: headers["X-MCP-Gateway-Token"] = tok timeout = float(os.environ.get("MCP_GATEWAY_TIMEOUT", "10")) + async with httpx.AsyncClient(timeout=timeout) as client: r = await client.post( url, json={"method": method, "params": params or {}}, headers=headers diff --git a/src/mcp_agent/server/app_server.py b/src/mcp_agent/server/app_server.py index 1ee591df7..10ace99fe 100644 --- a/src/mcp_agent/server/app_server.py +++ b/src/mcp_agent/server/app_server.py @@ -8,6 +8,7 @@ from contextlib import asynccontextmanager from typing import Any, Dict, List, Optional, Set, Tuple, Type, TYPE_CHECKING import os +import uuid import asyncio from mcp.server.fastmcp import Context as MCPContext, FastMCP @@ -32,9 +33,10 @@ from mcp_agent.core.context import Context logger = get_logger(__name__) -# Simple in-memory registry mapping workflow run_id -> upstream session handle. +# Simple in-memory registry mapping workflow execution_id -> upstream session handle. # Allows external workers (e.g., Temporal) to relay logs/prompts through MCPApp. 
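+# Sessions are stored under execution_id; _RUN_EXECUTION_ID_REGISTRY maps
+# run_id -> execution_id so a finished run can be unregistered by its run_id.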
_RUN_SESSION_REGISTRY: Dict[str, Any] = {} +_RUN_EXECUTION_ID_REGISTRY: Dict[str, str] = {} _RUN_SESSION_LOCK = asyncio.Lock() _PENDING_PROMPTS: Dict[str, Dict[str, Any]] = {} _PENDING_PROMPTS_LOCK = asyncio.Lock() @@ -42,19 +44,22 @@ _IDEMPOTENCY_KEYS_LOCK = asyncio.Lock() -async def _register_run_session(run_id: str, session: Any) -> None: +async def _register_session(run_id: str, execution_id: str, session: Any) -> None: async with _RUN_SESSION_LOCK: - _RUN_SESSION_REGISTRY[run_id] = session + _RUN_SESSION_REGISTRY[execution_id] = session + _RUN_EXECUTION_ID_REGISTRY[run_id] = execution_id -async def _unregister_run_session(run_id: str) -> None: +async def _unregister_session(run_id: str) -> None: async with _RUN_SESSION_LOCK: - _RUN_SESSION_REGISTRY.pop(run_id, None) + execution_id = _RUN_EXECUTION_ID_REGISTRY.pop(run_id, None) + if execution_id: + _RUN_SESSION_REGISTRY.pop(execution_id, None) -async def _get_run_session(run_id: str) -> Any | None: +async def _get_session(execution_id: str) -> Any | None: async with _RUN_SESSION_LOCK: - return _RUN_SESSION_REGISTRY.get(run_id) + return _RUN_SESSION_REGISTRY.get(execution_id) class ServerContext(ContextDependent): @@ -286,13 +291,13 @@ async def app_specific_lifespan(mcp: FastMCP) -> AsyncIterator[ServerContext]: # Helper: install internal HTTP routes (not MCP tools) def _install_internal_routes(mcp_server: FastMCP) -> None: @mcp_server.custom_route( - "/internal/session/by-run/{run_id}/notify", + "/internal/session/by-run/{execution_id}/notify", methods=["POST"], include_in_schema=False, ) async def _relay_notify(request: Request): body = await request.json() - run_id = request.path_params.get("run_id") + execution_id = request.path_params.get("execution_id") method = body.get("method") params = body.get("params") or {} @@ -307,12 +312,12 @@ async def _relay_notify(request: Request): idempotency_key = params.get("idempotency_key") if idempotency_key: async with _IDEMPOTENCY_KEYS_LOCK: - seen = _IDEMPOTENCY_KEYS_SEEN.setdefault(run_id or "", set()) + seen = _IDEMPOTENCY_KEYS_SEEN.setdefault(execution_id or "", set()) if idempotency_key in seen: return JSONResponse({"ok": True, "idempotent": True}) seen.add(idempotency_key) - session = await _get_run_session(run_id) + session = await _get_session(execution_id) if not session: return JSONResponse( {"ok": False, "error": "session_not_available"}, status_code=503 @@ -359,7 +364,7 @@ async def _relay_notify(request: Request): return JSONResponse({"ok": False, "error": str(e)}, status_code=500) @mcp_server.custom_route( - "/internal/session/by-run/{run_id}/request", + "/internal/session/by-run/{execution_id}/request", methods=["POST"], include_in_schema=False, ) @@ -379,11 +384,11 @@ async def _relay_request(request: Request): ) body = await request.json() - run_id = request.path_params.get("run_id") + execution_id = request.path_params.get("execution_id") method = body.get("method") params = body.get("params") or {} - session = await _get_run_session(run_id) + session = await _get_session(execution_id) if not session: return JSONResponse({"error": "session_not_available"}, status_code=503) @@ -452,7 +457,7 @@ async def _relay_request(request: Request): ) async def _internal_workflows_log(request: Request): body = await request.json() - run_id = body.get("run_id") + execution_id = body.get("execution_id") level = str(body.get("level", "info")).lower() namespace = body.get("namespace") or "mcp_agent" message = body.get("message") or "" @@ -465,7 +470,7 @@ async def 
_internal_workflows_log(request: Request): {"ok": False, "error": "unauthorized"}, status_code=401 ) - session = await _get_run_session(run_id) + session = await _get_session(execution_id) if not session: return JSONResponse( {"ok": False, "error": "session_not_available"}, status_code=503 @@ -491,15 +496,16 @@ async def _internal_workflows_log(request: Request): ) async def _internal_human_prompts(request: Request): body = await request.json() - run_id = body.get("run_id") + execution_id = body.get("execution_id") prompt = body.get("prompt") or {} metadata = body.get("metadata") or {} + # Optional shared-secret auth gw_token = os.environ.get("MCP_GATEWAY_TOKEN") if gw_token and request.headers.get("X-MCP-Gateway-Token") != gw_token: return JSONResponse({"error": "unauthorized"}, status_code=401) - session = await _get_run_session(run_id) + session = await _get_session(execution_id) if not session: return JSONResponse({"error": "session_not_available"}, status_code=503) import uuid @@ -516,7 +522,7 @@ async def _internal_human_prompts(request: Request): async with _PENDING_PROMPTS_LOCK: _PENDING_PROMPTS[request_id] = { "workflow_id": metadata.get("workflow_id"), - "run_id": run_id, + "execution_id": execution_id, "signal_name": metadata.get("signal_name", "human_input"), "session_id": metadata.get("session_id"), } @@ -1247,6 +1253,10 @@ async def _workflow_run( run_parameters: Dict[str, Any] | None = None, **kwargs: Any, ) -> Dict[str, str]: + # Generate a unique execution ID to track this run. We need to pass this to the workflow, and the run_id is only established + # after we create the workflow + execution_id = str(uuid.uuid4()) + # Resolve workflows and app context irrespective of startup mode # This now returns a context with upstream_session already set workflows_dict, app_context = _resolve_workflows_and_context(ctx) @@ -1327,6 +1337,7 @@ async def _workflow_run( workflow_memo = { "gateway_url": gateway_url, "gateway_token": gateway_token, + "execution_id": execution_id, } except Exception: workflow_memo = None @@ -1343,16 +1354,14 @@ async def _workflow_run( # Register upstream session for this run so external workers can proxy logs/prompts try: - if execution.run_id is not None: - await _register_run_session( - execution.run_id, getattr(ctx, "session", None) - ) + await _register_session(execution_id, getattr(ctx, "session", None)) except Exception: pass return { "workflow_id": execution.workflow_id, "run_id": execution.run_id, + "execution_id": execution_id, } except Exception as e: @@ -1384,7 +1393,7 @@ async def _workflow_status( try: state = str(status.get("status", "")).lower() if state in ("completed", "error", "cancelled"): - await _unregister_run_session(run_id) + await _unregister_session(run_id) except Exception: pass From 463ddb577b77164ce6ba7fa13bca38383804190b Mon Sep 17 00:00:00 2001 From: Roman van der Krogt Date: Thu, 4 Sep 2025 19:31:41 +0100 Subject: [PATCH 15/24] linter & test fixes --- .../executor/temporal/session_proxy.py | 4 +++ .../executor/temporal/system_activities.py | 16 ++++----- src/mcp_agent/executor/workflow.py | 36 +++++++++++-------- src/mcp_agent/logging/logger.py | 16 ++++----- src/mcp_agent/mcp/client_proxy.py | 16 ++++----- src/mcp_agent/server/app_server.py | 9 +++-- tests/executor/test_workflow.py | 5 ++- tests/logging/test_upstream_logging.py | 11 ++---- tests/test_app.py | 6 ++-- 9 files changed, 66 insertions(+), 53 deletions(-) diff --git a/src/mcp_agent/executor/temporal/session_proxy.py 
b/src/mcp_agent/executor/temporal/session_proxy.py index b4563297c..e29010f65 100644 --- a/src/mcp_agent/executor/temporal/session_proxy.py +++ b/src/mcp_agent/executor/temporal/session_proxy.py @@ -26,6 +26,10 @@ def __init__(self, *, executor, execution_id: str, context: Context): def execution_id(self) -> str: return self._execution_id + @execution_id.setter + def execution_id(self, value: str): + self._execution_id = value + async def send_log_message( self, *, diff --git a/src/mcp_agent/executor/temporal/system_activities.py b/src/mcp_agent/executor/temporal/system_activities.py index cd6d2b2a4..b215ece79 100644 --- a/src/mcp_agent/executor/temporal/system_activities.py +++ b/src/mcp_agent/executor/temporal/system_activities.py @@ -17,7 +17,7 @@ class SystemActivities(ContextDependent): @activity.defn(name="mcp_forward_log") async def forward_log( self, - run_id: str, + execution_id: str, level: str, namespace: str, message: str, @@ -28,7 +28,7 @@ async def forward_log( gateway_token = getattr(self.context, "gateway_token", None) return await log_via_proxy( registry, - run_id=run_id, + execution_id=execution_id, level=level, namespace=namespace, message=message, @@ -42,7 +42,7 @@ async def request_user_input( self, session_id: str, workflow_id: str, - run_id: str, + execution_id: str, prompt: str, signal_name: str = "human_input", ) -> Dict[str, Any]: @@ -52,7 +52,7 @@ async def request_user_input( gateway_token = getattr(self.context, "gateway_token", None) return await ask_via_proxy( registry, - run_id=run_id, + execution_id=execution_id, prompt=prompt, metadata={ "session_id": session_id, @@ -65,14 +65,14 @@ async def request_user_input( @activity.defn(name="mcp_relay_notify") async def relay_notify( - self, run_id: str, method: str, params: Dict[str, Any] | None = None + self, execution_id: str, method: str, params: Dict[str, Any] | None = None ) -> bool: registry = self.context.server_registry gateway_url = getattr(self.context, "gateway_url", None) gateway_token = getattr(self.context, "gateway_token", None) return await notify_via_proxy( registry, - run_id=run_id, + execution_id=execution_id, method=method, params=params or {}, gateway_url=gateway_url, @@ -81,14 +81,14 @@ async def relay_notify( @activity.defn(name="mcp_relay_request") async def relay_request( - self, run_id: str, method: str, params: Dict[str, Any] | None = None + self, execution_id: str, method: str, params: Dict[str, Any] | None = None ) -> Dict[str, Any]: registry = self.context.server_registry gateway_url = getattr(self.context, "gateway_url", None) gateway_token = getattr(self.context, "gateway_token", None) return await request_via_proxy( registry, - run_id=run_id, + execution_id=execution_id, method=method, params=params or {}, gateway_url=gateway_url, diff --git a/src/mcp_agent/executor/workflow.py b/src/mcp_agent/executor/workflow.py index 823584e61..710067456 100644 --- a/src/mcp_agent/executor/workflow.py +++ b/src/mcp_agent/executor/workflow.py @@ -826,6 +826,11 @@ async def initialize(self): gw = memo_map.get("gateway_url") gt = memo_map.get("gateway_token") e_id = memo_map.get("execution_id") + + self._logger.debug( + f"Proxy parameters: gateway_url={gw}, gateway_token={gt}, execution_id={e_id}" + ) + if gw: try: self.context.gateway_url = gw @@ -848,21 +853,24 @@ async def initialize(self): # Expose a virtual upstream session (passthrough) bound to this run via activities # This lets any code use context.upstream_session like a real session. 
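+            # First initialization binds a SessionProxy as the upstream session;
+            # later re-initializations only refresh the proxy's execution_id.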
try: - if ( - getattr(self.context, "upstream_session", None) is None - and self.context.execution_id - ): - self.context.upstream_session = SessionProxy( - executor=self.executor, - execution_id=self.context.execution_id, - context=self.context, - ) + upstream_session = getattr(self.context, "upstream_session", None) + + if upstream_session is None: + if self.context.execution_id: + self.context.upstream_session = SessionProxy( + executor=self.executor, + execution_id=self.context.execution_id, + context=self.context, + ) - app = self.context.app - if app: - # Ensure the app's logger is bound to the current context with upstream_session - if app._logger and hasattr(app._logger, "_bound_context"): - app._logger._bound_context = self.context + app = self.context.app + if app: + # Ensure the app's logger is bound to the current context with upstream_session + if app._logger and hasattr(app._logger, "_bound_context"): + app._logger._bound_context = self.context + elif self.context.execution_id: + # ensure the upstream session's execution_id is the current one. (We may be in a different workflow.) + upstream_session.execution_id = self.context.execution_id except Exception: # Non-fatal if context is immutable early; will be set after run_id assignment in run_async pass diff --git a/src/mcp_agent/logging/logger.py b/src/mcp_agent/logging/logger.py index 09357c77d..89089fbb9 100644 --- a/src/mcp_agent/logging/logger.py +++ b/src/mcp_agent/logging/logger.py @@ -386,8 +386,7 @@ def get_logger(namespace: str, session_id: str | None = None, context=None) -> L Args: namespace: The namespace for the logger (e.g. "agent.helper", "workflow.demo") session_id: Optional session ID to associate with all events from this logger - context: Deprecated/ignored. Present for backwards compatibility. 
- + context: Optional context to bind to the logger Returns: A Logger instance for the given namespace """ @@ -398,12 +397,11 @@ def get_logger(namespace: str, session_id: str | None = None, context=None) -> L logger = Logger(namespace, session_id, bound_context=context) _loggers[namespace] = logger return logger - # Update session_id/bound context if caller provides them - if session_id is not None: - existing.session_id = session_id - if context is not None: - try: + else: + # Update session_id/bound context if caller provides them + if session_id is not None: + existing.session_id = session_id + if context is not None: existing._bound_context = context - except Exception: - pass + return existing diff --git a/src/mcp_agent/mcp/client_proxy.py b/src/mcp_agent/mcp/client_proxy.py index 8250a8aed..e07db0d80 100644 --- a/src/mcp_agent/mcp/client_proxy.py +++ b/src/mcp_agent/mcp/client_proxy.py @@ -32,7 +32,7 @@ def _resolve_gateway_url( async def log_via_proxy( server_registry: Optional[ServerRegistry], - run_id: str, + execution_id: str, level: str, namespace: str, message: str, @@ -53,7 +53,7 @@ async def log_via_proxy( r = await client.post( url, json={ - "run_id": run_id, + "execution_id": execution_id, "level": level, "namespace": namespace, "message": message, @@ -69,7 +69,7 @@ async def log_via_proxy( async def ask_via_proxy( server_registry: Optional[ServerRegistry], - run_id: str, + execution_id: str, prompt: str, metadata: Dict[str, Any] | None = None, *, @@ -88,7 +88,7 @@ async def ask_via_proxy( r = await client.post( url, json={ - "run_id": run_id, + "execution_id": execution_id, "prompt": {"text": prompt}, "metadata": metadata or {}, }, @@ -101,7 +101,7 @@ async def ask_via_proxy( async def notify_via_proxy( server_registry: Optional[ServerRegistry], - run_id: str, + execution_id: str, method: str, params: Dict[str, Any] | None = None, *, @@ -110,7 +110,7 @@ async def notify_via_proxy( gateway_token: Optional[str] = None, ) -> bool: base = _resolve_gateway_url(server_registry, server_name, gateway_url) - url = f"{base}/internal/session/by-run/{run_id}/notify" + url = f"{base}/internal/session/by-run/{execution_id}/notify" headers: Dict[str, str] = {} tok = gateway_token or os.environ.get("MCP_GATEWAY_TOKEN") if tok: @@ -129,7 +129,7 @@ async def notify_via_proxy( async def request_via_proxy( server_registry: Optional[ServerRegistry], - run_id: str, + execution_id: str, method: str, params: Dict[str, Any] | None = None, *, @@ -138,7 +138,7 @@ async def request_via_proxy( gateway_token: Optional[str] = None, ) -> Dict[str, Any]: base = _resolve_gateway_url(server_registry, server_name, gateway_url) - url = f"{base}/internal/session/by-run/{run_id}/request" + url = f"{base}/internal/session/by-run/{execution_id}/request" headers: Dict[str, str] = {} tok = gateway_token or os.environ.get("MCP_GATEWAY_TOKEN") if tok: diff --git a/src/mcp_agent/server/app_server.py b/src/mcp_agent/server/app_server.py index 10ace99fe..292746e31 100644 --- a/src/mcp_agent/server/app_server.py +++ b/src/mcp_agent/server/app_server.py @@ -1349,12 +1349,17 @@ async def _workflow_run( ) logger.info( - f"Workflow {workflow_name} started with workflow ID {execution.workflow_id} and run ID {execution.run_id}. Parameters: {run_parameters}" + f"Workflow {workflow_name} started execution {execution_id} for workflow ID {execution.workflow_id}, " + f"run ID {execution.run_id}. 
Parameters: {run_parameters}" ) # Register upstream session for this run so external workers can proxy logs/prompts try: - await _register_session(execution_id, getattr(ctx, "session", None)) + await _register_session( + run_id=execution.run_id, + execution_id=execution_id, + session=getattr(ctx, "session", None), + ) except Exception: pass diff --git a/tests/executor/test_workflow.py b/tests/executor/test_workflow.py index 71df7c544..bcb11bff1 100644 --- a/tests/executor/test_workflow.py +++ b/tests/executor/test_workflow.py @@ -346,7 +346,10 @@ async def test_run_async_with_temporal_custom_params(self, mock_context): # Verify start_workflow was called with correct parameters workflow.executor.start_workflow.assert_called_once_with( - "TestWorkflow", workflow_id=custom_workflow_id, task_queue=custom_task_queue + "TestWorkflow", + workflow_id=custom_workflow_id, + task_queue=custom_task_queue, + workflow_memo=None, ) # Verify execution uses the handle's ID diff --git a/tests/logging/test_upstream_logging.py b/tests/logging/test_upstream_logging.py index 96184ec99..012b3a3de 100644 --- a/tests/logging/test_upstream_logging.py +++ b/tests/logging/test_upstream_logging.py @@ -30,19 +30,13 @@ async def test_upstream_logging_listener_sends_notifications(monkeypatch): dummy_session = DummyUpstreamSession() - # Monkeypatch get_current_context to return an object with upstream_session - def _fake_get_current_context(): - return SimpleNamespace(upstream_session=dummy_session) - - monkeypatch.setattr( - "mcp_agent.core.context.get_current_context", _fake_get_current_context - ) + current_context = SimpleNamespace(upstream_session=dummy_session) # Configure logging with low threshold so our event passes await LoggingConfig.configure(event_filter=EventFilter(min_level="debug")) try: - logger = get_logger("tests.logging") + logger = get_logger("tests.logging", context=current_context) logger.info("hello world", name="unit", foo="bar") # Give the async bus a moment to process @@ -76,4 +70,3 @@ async def test_logging_capability_registered_in_fastmcp(): # The presence of a SetLevelRequest handler indicates logging capability will be advertised assert types.SetLevelRequest in low.request_handlers - diff --git a/tests/test_app.py b/tests/test_app.py index 677163457..b6a3e8993 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -414,7 +414,7 @@ async def test_logger_property(self, basic_app): # First call creates the logger assert basic_app.logger is mock_logger mock_get_logger.assert_called_once_with( - f"mcp_agent.{basic_app.name}", session_id=None + f"mcp_agent.{basic_app.name}", session_id=None, context=None ) # Reset mock @@ -450,7 +450,9 @@ async def test_logger_property_with_session_id(self, basic_app, mock_context): # Get the logger - this should call get_logger with the session_id assert basic_app.logger is mock_logger mock_get_logger.assert_called_once_with( - f"mcp_agent.{basic_app.name}", session_id=mock_context.session_id + f"mcp_agent.{basic_app.name}", + session_id=mock_context.session_id, + context=mock_context, ) # From b71467c2652d943244bcedbe7f8398454e76e1c0 Mon Sep 17 00:00:00 2001 From: Roman van der Krogt Date: Thu, 4 Sep 2025 22:24:42 +0100 Subject: [PATCH 16/24] make it work again --- examples/mcp_agent_server/temporal/client.py | 170 +++++++------- src/mcp_agent/app.py | 2 + src/mcp_agent/core/context.py | 1 - src/mcp_agent/executor/temporal/__init__.py | 5 +- .../executor/temporal/interceptor.py | 207 ++++++++++++++++++ .../executor/temporal/session_proxy.py | 18 +- 
.../executor/temporal/temporal_context.py | 14 ++ src/mcp_agent/executor/workflow.py | 50 ++--- src/mcp_agent/logging/logger.py | 17 +- src/mcp_agent/logging/transport.py | 1 - src/mcp_agent/server/app_server.py | 5 +- tests/test_app.py | 6 +- 12 files changed, 351 insertions(+), 145 deletions(-) create mode 100644 src/mcp_agent/executor/temporal/interceptor.py create mode 100644 src/mcp_agent/executor/temporal/temporal_context.py diff --git a/examples/mcp_agent_server/temporal/client.py b/examples/mcp_agent_server/temporal/client.py index c2ef23bf8..198ce0789 100644 --- a/examples/mcp_agent_server/temporal/client.py +++ b/examples/mcp_agent_server/temporal/client.py @@ -1,4 +1,6 @@ +import asyncio import json +import time from mcp_agent.app import MCPApp from mcp_agent.config import MCPServerSettings from mcp_agent.executor.workflow import WorkflowExecution @@ -7,8 +9,8 @@ from datetime import timedelta from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp import ClientSession -from mcp.types import LoggingMessageNotificationParams from mcp_agent.mcp.mcp_agent_client_session import MCPAgentClientSession +from mcp.types import CallToolResult, LoggingMessageNotificationParams try: from exceptiongroup import ExceptionGroup as _ExceptionGroup # Python 3.10 backport @@ -83,6 +85,17 @@ def make_session( f"Started BasicAgentWorkflow-run. workflow ID={execution.workflow_id}, run ID={run_id}" ) + get_status_result = await server.call_tool( + "workflows-BasicAgentWorkflow-get_status", + arguments={"run_id": run_id}, + ) + + execution = WorkflowExecution(**json.loads(run_result.content[0].text)) + run_id = execution.run_id + logger.info( + f"Started BasicAgentWorkflow-run. workflow ID={execution.workflow_id}, run ID={run_id}" + ) + # Wait for the workflow to complete while True: get_status_result = await server.call_tool( @@ -90,103 +103,88 @@ def make_session( arguments={"run_id": run_id}, ) - execution = WorkflowExecution( - **json.loads(run_result.content[0].text) - ) - run_id = execution.run_id + workflow_status = _tool_result_to_json(get_status_result) + if workflow_status is None: + logger.error( + f"Failed to parse workflow status response: {get_status_result}" + ) + break + logger.info( - f"Started BasicAgentWorkflow-run. workflow ID={execution.workflow_id}, run ID={run_id}" + f"Workflow run {run_id} status:", + data=workflow_status, ) - # Wait for the workflow to complete - while True: - get_status_result = await server.call_tool( - "workflows-BasicAgentWorkflow-get_status", - arguments={"run_id": run_id}, + if not workflow_status.get("status"): + logger.error( + f"Workflow run {run_id} status is empty. get_status_result:", + data=get_status_result, ) + break - workflow_status = _tool_result_to_json(get_status_result) - if workflow_status is None: - logger.error( - f"Failed to parse workflow status response: {get_status_result}" - ) - break + if workflow_status.get("status") == "completed": + logger.info( + f"Workflow run {run_id} completed successfully! 
Result:", + data=workflow_status.get("result"), + ) + break + elif workflow_status.get("status") == "error": + logger.error( + f"Workflow run {run_id} failed with error:", + data=workflow_status, + ) + break + elif workflow_status.get("status") == "running": logger.info( - f"Workflow run {run_id} status:", + f"Workflow run {run_id} is still running...", + ) + elif workflow_status.get("status") == "cancelled": + logger.error( + f"Workflow run {run_id} was cancelled.", + data=workflow_status, + ) + break + else: + logger.error( + f"Unknown workflow status: {workflow_status.get('status')}", data=workflow_status, ) + break - if not workflow_status.get("status"): - logger.error( - f"Workflow run {run_id} status is empty. get_status_result:", - data=get_status_result, - ) - break - - if workflow_status.get("status") == "completed": - logger.info( - f"Workflow run {run_id} completed successfully! Result:", - data=workflow_status.get("result"), - ) - - break - elif workflow_status.get("status") == "error": - logger.error( - f"Workflow run {run_id} failed with error:", - data=workflow_status, - ) - break - elif workflow_status.get("status") == "running": - logger.info( - f"Workflow run {run_id} is still running...", - ) - elif workflow_status.get("status") == "cancelled": - logger.error( - f"Workflow run {run_id} was cancelled.", - data=workflow_status, - ) - break - else: - logger.error( - f"Unknown workflow status: {workflow_status.get('status')}", - data=workflow_status, - ) - break - - await asyncio.sleep(5) - - # TODO: UNCOMMENT ME to try out cancellation: - # await server.call_tool( - # "workflows-cancel", - # arguments={"workflow_id": "BasicAgentWorkflow", "run_id": run_id}, - # ) - - print(run_result) - - # Call the sync tool 'finder_tool' (no run/status loop) - try: - finder_result = await server.call_tool( - "finder_tool", - arguments={ - "request": "Summarize the Model Context Protocol introduction from https://modelcontextprotocol.io/introduction." - }, + await asyncio.sleep(5) + + # TODO: UNCOMMENT ME to try out cancellation: + # await server.call_tool( + # "workflows-cancel", + # arguments={"workflow_id": "BasicAgentWorkflow", "run_id": run_id}, + # ) + + print(run_result) + + # Call the sync tool 'finder_tool' (no run/status loop) + try: + finder_result = await server.call_tool( + "finder_tool", + arguments={ + "request": "Summarize the Model Context Protocol introduction from https://modelcontextprotocol.io/introduction." 
+ }, + ) + finder_payload = _tool_result_to_json(finder_result) or ( + ( + finder_result.structuredContent.get("result") + if getattr(finder_result, "structuredContent", None) + else None ) - finder_payload = _tool_result_to_json(finder_result) or ( - ( - finder_result.structuredContent.get("result") - if getattr(finder_result, "structuredContent", None) - else None - ) - or ( - finder_result.content[0].text - if getattr(finder_result, "content", None) - else None - ) + or ( + finder_result.content[0].text + if getattr(finder_result, "content", None) + else None ) - logger.info("finder_tool result:", data=finder_payload) - except Exception as e: - logger.error("finder_tool call failed", data=str(e)) + ) + logger.info("finder_tool result:", data=finder_payload) + except Exception as e: + logger.error("finder_tool call failed", data=str(e)) except Exception as e: # Tolerate benign shutdown races from SSE client (BrokenResourceError within ExceptionGroup) if _ExceptionGroup is not None and isinstance(e, _ExceptionGroup): diff --git a/src/mcp_agent/app.py b/src/mcp_agent/app.py index c6a0a45c2..87697bdd0 100644 --- a/src/mcp_agent/app.py +++ b/src/mcp_agent/app.py @@ -195,12 +195,14 @@ def logger(self): try: if self._context is not None: self._logger._bound_context = self._context # type: ignore[attr-defined] + except Exception: pass else: # Update the logger's bound context in case upstream_session was set after logger creation if self._context and hasattr(self._logger, "_bound_context"): self._logger._bound_context = self._context + return self._logger async def initialize(self): diff --git a/src/mcp_agent/core/context.py b/src/mcp_agent/core/context.py index c43047bd9..d449c938a 100644 --- a/src/mcp_agent/core/context.py +++ b/src/mcp_agent/core/context.py @@ -92,7 +92,6 @@ class Context(BaseModel): # Dynamic gateway configuration (per-run overrides via Temporal memo) gateway_url: str | None = None gateway_token: str | None = None - execution_id: str | None = None model_config = ConfigDict( extra="allow", diff --git a/src/mcp_agent/executor/temporal/__init__.py b/src/mcp_agent/executor/temporal/__init__.py index e9e8fb55e..4863f97bf 100644 --- a/src/mcp_agent/executor/temporal/__init__.py +++ b/src/mcp_agent/executor/temporal/__init__.py @@ -38,6 +38,7 @@ from mcp_agent.logging.logger import get_logger from mcp_agent.utils.common import unwrap from mcp_agent.executor.temporal.system_activities import SystemActivities +from mcp_agent.executor.temporal.interceptor import ContextPropagationInterceptor if TYPE_CHECKING: from mcp_agent.app import MCPApp @@ -265,9 +266,9 @@ async def ensure_client(self): api_key=self.config.api_key, tls=self.config.tls, data_converter=pydantic_data_converter, - interceptors=[TracingInterceptor()] + interceptors=[TracingInterceptor(), ContextPropagationInterceptor()] if self.context.tracing_enabled - else [], + else [ContextPropagationInterceptor()], rpc_metadata=self.config.rpc_metadata or {}, ) diff --git a/src/mcp_agent/executor/temporal/interceptor.py b/src/mcp_agent/executor/temporal/interceptor.py new file mode 100644 index 000000000..93b441ba4 --- /dev/null +++ b/src/mcp_agent/executor/temporal/interceptor.py @@ -0,0 +1,207 @@ +from __future__ import annotations + +from contextlib import contextmanager +from typing import Any, Mapping, Protocol, Type + +import temporalio.activity +import temporalio.api.common.v1 +import temporalio.client +import temporalio.converter +import temporalio.worker +import temporalio.workflow +from mcp_agent.logging.logger import 
get_logger
+from mcp_agent.executor.temporal.temporal_context import (
+    EXECUTION_ID_KEY,
+    get_execution_id,
+    set_execution_id,
+)
+
+
+class _InputWithHeaders(Protocol):
+    headers: Mapping[str, temporalio.api.common.v1.Payload]
+
+
+logger = get_logger(__name__)
+
+
+def set_header_from_context(
+    input: _InputWithHeaders, payload_converter: temporalio.converter.PayloadConverter
+) -> None:
+    execution_id_val = get_execution_id()
+
+    if execution_id_val:
+        input.headers = {
+            **input.headers,
+            EXECUTION_ID_KEY: payload_converter.to_payload(execution_id_val),
+        }
+
+
+@contextmanager
+def context_from_header(
+    input: _InputWithHeaders, payload_converter: temporalio.converter.PayloadConverter
+):
+    execution_id_payload = input.headers.get(EXECUTION_ID_KEY)
+    execution_id_from_header = (
+        payload_converter.from_payload(execution_id_payload, str)
+        if execution_id_payload
+        else None
+    )
+    set_execution_id(execution_id_from_header if execution_id_from_header else None)
+
+    yield
+
+
+class ContextPropagationInterceptor(
+    temporalio.client.Interceptor, temporalio.worker.Interceptor
+):
+    """Interceptor that propagates the current execution ID through client, workflow and activity calls.
+
+    This interceptor implements `temporalio.client.Interceptor` and `temporalio.worker.Interceptor` so that
+
+    (1) the execution ID is taken from context by the client code and sent in a header field with outbound requests
+    (2) workflows take this value from their task input, set it in context, and propagate it into the header field of
+        their outbound calls
+    (3) activities similarly take the value from their task input and set it in context so that it's available for their
+        outbound calls
+    """
+
+    def __init__(
+        self,
+        payload_converter: temporalio.converter.PayloadConverter = temporalio.converter.default().payload_converter,
+    ) -> None:
+        logger.debug("Creating context propagation interceptor")
+        self._payload_converter = payload_converter
+
+    def intercept_client(
+        self, next: temporalio.client.OutboundInterceptor
+    ) -> temporalio.client.OutboundInterceptor:
+        logger.debug("Creating client outbound interceptor")
+        return _ContextPropagationClientOutboundInterceptor(
+            next, self._payload_converter
+        )
+
+    def intercept_activity(
+        self, next: temporalio.worker.ActivityInboundInterceptor
+    ) -> temporalio.worker.ActivityInboundInterceptor:
+        logger.debug("Creating activity inbound interceptor")
+        return _ContextPropagationActivityInboundInterceptor(next)
+
+    def workflow_interceptor_class(
+        self, input: temporalio.worker.WorkflowInterceptorClassInput
+    ) -> Type[_ContextPropagationWorkflowInboundInterceptor]:
+        logger.debug("Selecting workflow inbound interceptor class")
+        return _ContextPropagationWorkflowInboundInterceptor
+
+
+class _ContextPropagationClientOutboundInterceptor(
+    temporalio.client.OutboundInterceptor
+):
+    def __init__(
+        self,
+        next: temporalio.client.OutboundInterceptor,
+        payload_converter: temporalio.converter.PayloadConverter,
+    ) -> None:
+        super().__init__(next)
+        logger.debug("Creating client outbound interceptor")
+        self._payload_converter = payload_converter
+
+    async def start_workflow(
+        self, input: temporalio.client.StartWorkflowInput
+    ) -> temporalio.client.WorkflowHandle[Any, Any]:
+        set_header_from_context(input, self._payload_converter)
+        return await super().start_workflow(input)
+
+    async def query_workflow(self, input: temporalio.client.QueryWorkflowInput) -> Any:
+        set_header_from_context(input, self._payload_converter)
+        return await super().query_workflow(input)
+
+    async def signal_workflow(
+        self, input: temporalio.client.SignalWorkflowInput
+    ) -> None:
+        set_header_from_context(input, self._payload_converter)
+        await super().signal_workflow(input)
+
+    async def start_workflow_update(
+        self, input: temporalio.client.StartWorkflowUpdateInput
+    ) -> temporalio.client.WorkflowUpdateHandle[Any]:
+        set_header_from_context(input, self._payload_converter)
+        return await self.next.start_workflow_update(input)
+
+
+class _ContextPropagationActivityInboundInterceptor(
+    temporalio.worker.ActivityInboundInterceptor
+):
+    async def execute_activity(
+        self, input: temporalio.worker.ExecuteActivityInput
+    ) -> Any:
+        with context_from_header(input, temporalio.activity.payload_converter()):
+            return await self.next.execute_activity(input)
+
+
+class _ContextPropagationWorkflowInboundInterceptor(
+    temporalio.worker.WorkflowInboundInterceptor
+):
+    def init(self, outbound: temporalio.worker.WorkflowOutboundInterceptor) -> None:
+        logger.debug("Creating workflow inbound interceptor")
+
+        self.next.init(_ContextPropagationWorkflowOutboundInterceptor(outbound))
+
+    async def execute_workflow(
+        self, input: temporalio.worker.ExecuteWorkflowInput
+    ) -> Any:
+        with context_from_header(input, temporalio.workflow.payload_converter()):
+            return await self.next.execute_workflow(input)
+
+    async def handle_signal(self, input: temporalio.worker.HandleSignalInput) -> None:
+        with context_from_header(input, temporalio.workflow.payload_converter()):
+            return await self.next.handle_signal(input)
+
+    async def handle_query(self, input: temporalio.worker.HandleQueryInput) -> Any:
+        with context_from_header(input, temporalio.workflow.payload_converter()):
+            return await self.next.handle_query(input)
+
+    def handle_update_validator(
+        self, input: temporalio.worker.HandleUpdateInput
+    ) -> None:
+        with context_from_header(input, temporalio.workflow.payload_converter()):
+            self.next.handle_update_validator(input)
+
+    async def handle_update_handler(
+        self, input: temporalio.worker.HandleUpdateInput
+    ) -> Any:
+        with context_from_header(input, temporalio.workflow.payload_converter()):
+            return await self.next.handle_update_handler(input)
+
+
+class _ContextPropagationWorkflowOutboundInterceptor(
+    temporalio.worker.WorkflowOutboundInterceptor
+):
+    async def signal_child_workflow(
+        self, input: temporalio.worker.SignalChildWorkflowInput
+    ) -> None:
+        set_header_from_context(input, temporalio.workflow.payload_converter())
+        return await self.next.signal_child_workflow(input)
+
+    async def signal_external_workflow(
+        self, input: temporalio.worker.SignalExternalWorkflowInput
+    ) -> None:
+        set_header_from_context(input, temporalio.workflow.payload_converter())
+        return await self.next.signal_external_workflow(input)
+
+    def start_activity(
+        self, input: temporalio.worker.StartActivityInput
+    ) -> temporalio.workflow.ActivityHandle:
+        set_header_from_context(input, temporalio.workflow.payload_converter())
+        return self.next.start_activity(input)
+
+    async def start_child_workflow(
+        self, input: temporalio.worker.StartChildWorkflowInput
+    ) -> temporalio.workflow.ChildWorkflowHandle:
+        set_header_from_context(input, temporalio.workflow.payload_converter())
+        return await self.next.start_child_workflow(input)
+
+    def start_local_activity(
+        self, input: temporalio.worker.StartLocalActivityInput
+    ) -> temporalio.workflow.ActivityHandle:
+        set_header_from_context(input, temporalio.workflow.payload_converter())
+        return self.next.start_local_activity(input)
diff --git a/src/mcp_agent/executor/temporal/session_proxy.py b/src/mcp_agent/executor/temporal/session_proxy.py
index e29010f65..e12e2de3d 100644
--- a/src/mcp_agent/executor/temporal/session_proxy.py
+++ b/src/mcp_agent/executor/temporal/session_proxy.py
@@ -2,6 +2,7 @@

 from mcp_agent.core.context import Context
 from mcp_agent.executor.temporal.system_activities import SystemActivities
+from mcp_agent.executor.temporal.temporal_context import get_execution_id


 class SessionProxy:
@@ -17,19 +18,10 @@ class SessionProxy:
     - request(method, params)
     """

-    def __init__(self, *, executor, execution_id: str, context: Context):
+    def __init__(self, *, executor, context: Context):
         self._executor = executor
-        self._execution_id = execution_id
         self.sys_acts = SystemActivities(context)

-    @property
-    def execution_id(self) -> str:
-        return self._execution_id
-
-    @execution_id.setter
-    def execution_id(self, value: str):
-        self._execution_id = value
-
     async def send_log_message(
         self,
         *,
@@ -51,16 +43,16 @@ def __init__(self, *, executor, execution_id: str, context: Context):
         # result = await self._executor.execute(self.sys_acts.relay_notify, self.execution_id, "notifications/message", params)
         # we can't.
         await self.sys_acts.relay_notify(
-            self.execution_id, "notifications/message", params
+            get_execution_id(), "notifications/message", params
         )

     async def notify(self, method: str, params: Dict[str, Any] | None = None) -> bool:
         result = await self.sys_acts.relay_notify(
-            self.execution_id, method, params or {}
+            get_execution_id(), method, params or {}
         )
         return bool(result)

     async def request(
         self, method: str, params: Dict[str, Any] | None = None
     ) -> Dict[str, Any]:
-        await self.sys_acts.relay_request(self.execution_id, method, params or {})
+        return await self.sys_acts.relay_request(
+            get_execution_id(), method, params or {}
+        )
diff --git a/src/mcp_agent/executor/temporal/temporal_context.py b/src/mcp_agent/executor/temporal/temporal_context.py
new file mode 100644
index 000000000..c5fd4b2ba
--- /dev/null
+++ b/src/mcp_agent/executor/temporal/temporal_context.py
@@ -0,0 +1,14 @@
+from typing import Optional
+
+EXECUTION_ID_KEY = "__execution_id"
+
+_execution_id: Optional[str] = None
+
+
+def set_execution_id(execution_id: Optional[str]) -> None:
+    global _execution_id
+    _execution_id = execution_id
+
+
+def get_execution_id() -> Optional[str]:
+    return _execution_id
diff --git a/src/mcp_agent/executor/workflow.py b/src/mcp_agent/executor/workflow.py
index 5f6b5b623..35bcad1d5 100644
--- a/src/mcp_agent/executor/workflow.py
+++ b/src/mcp_agent/executor/workflow.py
@@ -258,18 +258,13 @@ async def run_async(self, *args, **kwargs) -> "WorkflowExecution":
             if self.context.config.execution_engine == "temporal":
                 setattr(self._logger, "_temporal_run_id", self._run_id)
                 # Ensure upstream_session is a passthrough SessionProxy bound to this run
-                if (
-                    getattr(self.context, "upstream_session", None) is None
-                    and self.context.execution_id
-                ):
-                    try:
-                        self.context.upstream_session = SessionProxy(
-                            executor=self.executor,
-                            execution_id=self.context.execution_id,
-                            context=self.context,
-                        )
-                    except Exception:
-                        pass
+                upstream_session = getattr(self.context, "upstream_session", None)
+
+                if upstream_session is None:
+                    self.context.upstream_session = SessionProxy(
+                        executor=self.executor,
+                        context=self.context,
+                    )
         except Exception:
             pass

@@ -838,10 +833,9 @@ async def initialize(self):
             if isinstance(memo_map, dict):
                 gw = memo_map.get("gateway_url")
                 gt = memo_map.get("gateway_token")
-                e_id = memo_map.get("execution_id")
self._logger.debug( - f"Proxy parameters: gateway_url={gw}, gateway_token={gt}, execution_id={e_id}" + f"Proxy parameters: gateway_url={gw}, gateway_token={gt}" ) if gw: @@ -854,11 +848,6 @@ async def initialize(self): self.context.gateway_token = gt except Exception: pass - if e_id: - try: - self.context.execution_id = e_id - except Exception: - pass except Exception: # Safe to ignore if called outside workflow sandbox or memo unavailable pass @@ -869,21 +858,16 @@ async def initialize(self): upstream_session = getattr(self.context, "upstream_session", None) if upstream_session is None: - if self.context.execution_id: - self.context.upstream_session = SessionProxy( - executor=self.executor, - execution_id=self.context.execution_id, - context=self.context, - ) + self.context.upstream_session = SessionProxy( + executor=self.executor, + context=self.context, + ) - app = self.context.app - if app: - # Ensure the app's logger is bound to the current context with upstream_session - if app._logger and hasattr(app._logger, "_bound_context"): - app._logger._bound_context = self.context - elif self.context.execution_id: - # ensure the upstream session's execution_id is the current one. (We may be in a different workflow.) - upstream_session.execution_id = self.context.execution_id + app = self.context.app + if app: + # Ensure the app's logger is bound to the current context with upstream_session + if app._logger and hasattr(app._logger, "_bound_context"): + app._logger._bound_context = self.context except Exception: # Non-fatal if context is immutable early; will be set after run_id assignment in run_async pass diff --git a/src/mcp_agent/logging/logger.py b/src/mcp_agent/logging/logger.py index ca0ab6b9c..faf5b7117 100644 --- a/src/mcp_agent/logging/logger.py +++ b/src/mcp_agent/logging/logger.py @@ -15,6 +15,8 @@ from contextlib import asynccontextmanager, contextmanager +import temporalio + from mcp_agent.logging.events import ( Event, EventContext, @@ -72,13 +74,18 @@ def _emit_event(self, event: Event): asyncio.create_task(self.event_bus.emit(event)) else: # If no loop is running, run it until the emit completes - try: - loop.run_until_complete(self.event_bus.emit(event)) - except NotImplementedError: + if isinstance( + loop, temporalio.worker._workflow_instance._WorkflowInstanceImpl + ): # Handle Temporal workflow environment where run_until_complete() is not implemented # In Temporal, we can't block on async operations, so we'll need to avoid this # Simply log to stdout/stderr as a fallback self.event_bus.emit_with_stderr_transport(event) + else: + try: + loop.run_until_complete(self.event_bus.emit(event)) + except NotImplementedError: + pass def event( self, @@ -102,6 +109,7 @@ def event( # can forward reliably, regardless of the current task context. 
# 1) Prefer logger-bound app context (set at creation or refreshed by caller) extra_event_fields: Dict[str, Any] = {} + try: upstream = ( getattr(self._bound_context, "upstream_session", None) @@ -396,7 +404,7 @@ def get_logger(namespace: str, session_id: str | None = None, context=None) -> L with _logger_lock: existing = _loggers.get(namespace) if existing is None: - logger = Logger(namespace, session_id, bound_context=context) + logger = Logger(namespace, session_id, context) _loggers[namespace] = logger return logger @@ -405,4 +413,5 @@ def get_logger(namespace: str, session_id: str | None = None, context=None) -> L existing.session_id = session_id if context is not None: existing._bound_context = context + return existing diff --git a/src/mcp_agent/logging/transport.py b/src/mcp_agent/logging/transport.py index c067712fe..c795c47a8 100644 --- a/src/mcp_agent/logging/transport.py +++ b/src/mcp_agent/logging/transport.py @@ -447,7 +447,6 @@ def emit_with_stderr_transport(self, event: Event): self._running = True self._task = asyncio.create_task(self._process_events()) - # Then queue for listeners self._queue.put_nowait(event) async def _send_to_transport(self, event: Event): diff --git a/src/mcp_agent/server/app_server.py b/src/mcp_agent/server/app_server.py index cf8ac2e4a..a5aa5cf61 100644 --- a/src/mcp_agent/server/app_server.py +++ b/src/mcp_agent/server/app_server.py @@ -25,6 +25,7 @@ WorkflowRegistry, InMemoryWorkflowRegistry, ) +from mcp_agent.executor.temporal.temporal_context import set_execution_id from mcp_agent.logging.logger import get_logger from mcp_agent.logging.logger import LoggingConfig from mcp_agent.mcp.mcp_server_registry import ServerRegistry @@ -1354,6 +1355,7 @@ async def _workflow_run( # Generate a unique execution ID to track this run. 
We need to pass this to the workflow, and the run_id is only established # after we create the workflow execution_id = str(uuid.uuid4()) + set_execution_id(execution_id) # Resolve workflows and app context irrespective of startup mode # This now returns a context with upstream_session already set @@ -1496,7 +1498,8 @@ async def _workflow_status( try: state = str(status.get("status", "")).lower() if state in ("completed", "error", "cancelled"): - await _unregister_session(run_id) + # await _unregister_session(run_id) + pass except Exception: pass diff --git a/tests/test_app.py b/tests/test_app.py index b6a3e8993..677163457 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -414,7 +414,7 @@ async def test_logger_property(self, basic_app): # First call creates the logger assert basic_app.logger is mock_logger mock_get_logger.assert_called_once_with( - f"mcp_agent.{basic_app.name}", session_id=None, context=None + f"mcp_agent.{basic_app.name}", session_id=None ) # Reset mock @@ -450,9 +450,7 @@ async def test_logger_property_with_session_id(self, basic_app, mock_context): # Get the logger - this should call get_logger with the session_id assert basic_app.logger is mock_logger mock_get_logger.assert_called_once_with( - f"mcp_agent.{basic_app.name}", - session_id=mock_context.session_id, - context=mock_context, + f"mcp_agent.{basic_app.name}", session_id=mock_context.session_id ) # From 6135e951ad00cdca18e2a6080a67ea89015682ce Mon Sep 17 00:00:00 2001 From: Sarmad Qadri Date: Fri, 5 Sep 2025 09:47:49 -0400 Subject: [PATCH 17/24] Fixes to execution ID, make proxy a proper ServerSession type (#419) --- .../executor/temporal/interceptor.py | 7 - .../executor/temporal/session_proxy.py | 314 ++++++++++++++++-- .../executor/temporal/temporal_context.py | 10 +- src/mcp_agent/server/app_server.py | 29 +- 4 files changed, 294 insertions(+), 66 deletions(-) diff --git a/src/mcp_agent/executor/temporal/interceptor.py b/src/mcp_agent/executor/temporal/interceptor.py index 93b441ba4..a085f731f 100644 --- a/src/mcp_agent/executor/temporal/interceptor.py +++ b/src/mcp_agent/executor/temporal/interceptor.py @@ -69,13 +69,11 @@ def __init__( self, payload_converter: temporalio.converter.PayloadConverter = temporalio.converter.default().payload_converter, ) -> None: - logger.info("ZZ Creating interceptor with payload converter") self._payload_converter = payload_converter def intercept_client( self, next: temporalio.client.OutboundInterceptor ) -> temporalio.client.OutboundInterceptor: - logger.info("ZZ Creating client interceptor class") return _ContextPropagationClientOutboundInterceptor( next, self._payload_converter ) @@ -83,13 +81,11 @@ def intercept_client( def intercept_activity( self, next: temporalio.worker.ActivityInboundInterceptor ) -> temporalio.worker.ActivityInboundInterceptor: - logger.info("ZZ Creating activity interceptor class") return _ContextPropagationActivityInboundInterceptor(next) def workflow_interceptor_class( self, input: temporalio.worker.WorkflowInterceptorClassInput ) -> Type[_ContextPropagationWorkflowInboundInterceptor]: - logger.info("ZZ Creating workflow interceptor class") return _ContextPropagationWorkflowInboundInterceptor @@ -102,7 +98,6 @@ def __init__( payload_converter: temporalio.converter.PayloadConverter, ) -> None: super().__init__(next) - logger.info("ZZ Creating client outbound interceptor") self._payload_converter = payload_converter async def start_workflow( @@ -142,8 +137,6 @@ class _ContextPropagationWorkflowInboundInterceptor( 
temporalio.worker.WorkflowInboundInterceptor ): def init(self, outbound: temporalio.worker.WorkflowOutboundInterceptor) -> None: - logger.info("ZZ Creating worker inbound interceptor") - self.next.init(_ContextPropagationWorkflowOutboundInterceptor(outbound)) async def execute_workflow( diff --git a/src/mcp_agent/executor/temporal/session_proxy.py b/src/mcp_agent/executor/temporal/session_proxy.py index e12e2de3d..18b30d1a0 100644 --- a/src/mcp_agent/executor/temporal/session_proxy.py +++ b/src/mcp_agent/executor/temporal/session_proxy.py @@ -1,58 +1,308 @@ -from typing import Any, Dict, Optional +from __future__ import annotations + +from typing import Any, Dict, List, Type + +import anyio +import mcp.types as types +from anyio.streams.memory import ( + MemoryObjectReceiveStream, + MemoryObjectSendStream, +) +from temporalio import workflow as _twf + +from mcp.server.models import InitializationOptions +from mcp.server.session import ServerSession as _BaseServerSession +from mcp.shared.message import ServerMessageMetadata from mcp_agent.core.context import Context from mcp_agent.executor.temporal.system_activities import SystemActivities from mcp_agent.executor.temporal.temporal_context import get_execution_id -class SessionProxy: +class SessionProxy(_BaseServerSession): """ - A 'virtual' MCP ServerSession bound to a Temporal workflow run. + SessionProxy acts like an MCP `ServerSession` for code running under the + Temporal engine. It forwards server->client messages through the MCPApp + gateway so that logs, notifications, and requests reach the original + upstream MCP client. - This proxy exposes a subset of the ServerSession API and routes calls - through generic Temporal activities to keep workflow code deterministic. + Behavior: + - Inside a Temporal workflow (deterministic scope), all network I/O is + performed via registered Temporal activities. + - Outside a workflow (e.g., inside an activity or plain asyncio code), + calls are executed directly using the SystemActivities helpers. - Methods: - - send_log_message(level, data, logger=None, related_request_id=None) - - notify(method, params) - - request(method, params) + This keeps workflow logic deterministic while remaining a drop-in proxy + for the common ServerSession methods used by the agent runtime. """ - def __init__(self, *, executor, context: Context): + def __init__(self, *, executor, context: Context) -> None: + # Create inert in-memory streams to satisfy base constructor. We do not + # use these streams; all communication is proxied via HTTP gateway. 
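+        # A buffer size of 0 makes the streams rendezvous-only: any accidental
+        # write would block immediately rather than queue silently, so misuse
+        # of these placeholder streams surfaces quickly.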
+ send_read, recv_read = anyio.create_memory_object_stream(0) + send_write, recv_write = anyio.create_memory_object_stream(0) + + init_opts = InitializationOptions( + server_name="mcp_agent_proxy", + server_version="0.0.0", + capabilities=types.ServerCapabilities(), + instructions=None, + ) + # Initialize base class in stateless mode to skip handshake state + super().__init__( + recv_read, # type: ignore[arg-type] + send_write, # type: ignore[arg-type] + init_opts, + stateless=True, + ) + + # Keep references so streams aren't GC'd + self._dummy_streams: tuple[ + MemoryObjectSendStream[Any], + MemoryObjectReceiveStream[Any], + MemoryObjectSendStream[Any], + MemoryObjectReceiveStream[Any], + ] = (send_read, recv_read, send_write, recv_write) + self._executor = executor - self.sys_acts = SystemActivities(context) + self._context = context + # Local helper used when we're not inside a workflow runtime + self._sys_acts = SystemActivities(context) + # Provide a low-level RPC facade similar to real ServerSession + self.rpc = _RPC(self) + + # ---------------------- + # Generic passthroughs + # ---------------------- + async def notify(self, method: str, params: Dict[str, Any] | None = None) -> bool: + """Send a server->client notification via the gateway. + + Returns True on best-effort success. + """ + exec_id = get_execution_id() + if not exec_id: + return False + + if _in_workflow_runtime(): + try: + act = self._context.task_registry.get_activity("mcp_relay_notify") + await self._executor.execute(act, exec_id, method, params or {}) + return True + except Exception: + return False + # Non-workflow (activity/asyncio) + return bool(await self._sys_acts.relay_notify(exec_id, method, params or {})) + + async def request( + self, method: str, params: Dict[str, Any] | None = None + ) -> Dict[str, Any]: + """Send a server->client request and return the client's response. + The result is a plain JSON-serializable dict. 
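+
+        A minimal illustrative call (the `session` variable name is
+        hypothetical; the method mirrors how ``list_roots`` below uses this
+        helper, and it assumes an execution ID is set for the current run):
+
+            result = await session.request("roots/list", {})
+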
+ """ + exec_id = get_execution_id() + if not exec_id: + return {"error": "missing_execution_id"} + + if _in_workflow_runtime(): + act = self._context.task_registry.get_activity("mcp_relay_request") + return await self._executor.execute(act, exec_id, method, params or {}) + return await self._sys_acts.relay_request(exec_id, method, params or {}) + + # ---------------------- + # ServerSession-like API + # ---------------------- + async def send_notification( + self, + notification: types.ServerNotification, + related_request_id: types.RequestId | None = None, + ) -> None: + root = notification.root + params: Dict[str, Any] | None = None + try: + if getattr(root, "params", None) is not None: + params = root.params.model_dump(by_alias=True, mode="json") # type: ignore[attr-defined] + else: + params = {} + except Exception: + params = {} + # Best-effort pass-through of related_request_id when provided + if related_request_id is not None: + params = dict(params or {}) + params["related_request_id"] = related_request_id + await self.notify(root.method, params) # type: ignore[attr-defined] + + async def send_request( + self, + request: types.ServerRequest, + result_type: Type[Any], + metadata: ServerMessageMetadata | None = None, + ) -> Any: + root = request.root + params: Dict[str, Any] | None = None + try: + if getattr(root, "params", None) is not None: + params = root.params.model_dump(by_alias=True, mode="json") # type: ignore[attr-defined] + else: + params = {} + except Exception: + params = {} + # Note: metadata (e.g., related_request_id) is handled server-side where applicable + payload = await self.request(root.method, params) # type: ignore[attr-defined] + # Attempt to validate into the requested result type + try: + return result_type.model_validate(payload) # type: ignore[attr-defined] + except Exception: + return payload async def send_log_message( self, - *, - level: str, - data: Dict[str, Any] | Any, - logger: Optional[str] = None, - related_request_id: Optional[str] = None, + level: types.LoggingLevel, + data: Any, + logger: str | None = None, + related_request_id: types.RequestId | None = None, + ) -> None: + """Best-effort log forwarding to the client's UI.""" + # Prefer activity-based forwarding inside workflow for determinism + exec_id = get_execution_id() + if _in_workflow_runtime() and exec_id: + try: + act = self._context.task_registry.get_activity("mcp_forward_log") + namespace = ( + (data or {}).get("namespace") + if isinstance(data, dict) + else (logger or "mcp_agent") + ) + message = (data or {}).get("message") if isinstance(data, dict) else "" + await self._executor.execute( + act, + exec_id, + str(level), + namespace or (logger or "mcp_agent"), + message or "", + (data or {}), + ) + return + except Exception: + # Fall back to notify path below + pass + + params: Dict[str, Any] = {"level": str(level), "data": data, "logger": logger} + if related_request_id is not None: + params["related_request_id"] = related_request_id + await self.notify("notifications/message", params) + + async def send_progress_notification( + self, + progress_token: str | int, + progress: float, + total: float | None = None, + message: str | None = None, + related_request_id: str | None = None, ) -> None: - # Map to notifications/message via generic relay params: Dict[str, Any] = { - "level": level, - "data": data, - "logger": logger, + "progressToken": progress_token, + "progress": progress, } + if total is not None: + params["total"] = total + if message is not None: + params["message"] = 
message if related_request_id is not None: params["related_request_id"] = related_request_id + await self.notify("notifications/progress", params) - # We are outside of the temporal loop. So even though we'd like to do something like - # result = await self._executor.execute(self.sys_acts.relay_notify, self.execution_id, "notifications/message", params) - # we can't. - await self.sys_acts.relay_notify( - get_execution_id(), "notifications/message", params - ) + async def send_resource_updated(self, uri: types.AnyUrl) -> None: + await self.notify("notifications/resources/updated", {"uri": str(uri)}) - async def notify(self, method: str, params: Dict[str, Any] | None = None) -> bool: - result = await self.sys_acts.relay_notify( - get_execution_id(), method, params or {} - ) - return bool(result) + async def send_resource_list_changed(self) -> None: + await self.notify("notifications/resources/list_changed", {}) + + async def send_tool_list_changed(self) -> None: + await self.notify("notifications/tools/list_changed", {}) + + async def send_prompt_list_changed(self) -> None: + await self.notify("notifications/prompts/list_changed", {}) + + async def send_ping(self) -> types.EmptyResult: + result = await self.request("ping", {}) + return types.EmptyResult.model_validate(result) + + async def list_roots(self) -> types.ListRootsResult: + result = await self.request("roots/list", {}) + return types.ListRootsResult.model_validate(result) + + async def create_message( + self, + messages: List[types.SamplingMessage], + *, + max_tokens: int, + system_prompt: str | None = None, + include_context: types.IncludeContext | None = None, + temperature: float | None = None, + stop_sequences: List[str] | None = None, + metadata: Dict[str, Any] | None = None, + model_preferences: types.ModelPreferences | None = None, + related_request_id: types.RequestId | None = None, + ) -> types.CreateMessageResult: + params: Dict[str, Any] = { + "messages": [m.model_dump(by_alias=True, mode="json") for m in messages], + "maxTokens": max_tokens, + } + if system_prompt is not None: + params["systemPrompt"] = system_prompt + if include_context is not None: + params["includeContext"] = include_context + if temperature is not None: + params["temperature"] = temperature + if stop_sequences is not None: + params["stopSequences"] = stop_sequences + if metadata is not None: + params["metadata"] = metadata + if model_preferences is not None: + params["modelPreferences"] = model_preferences.model_dump( + by_alias=True, mode="json" + ) + if related_request_id is not None: + # Threading ID through JSON-RPC metadata is handled by gateway; include for completeness + params["related_request_id"] = related_request_id + + result = await self.request("sampling/createMessage", params) + return types.CreateMessageResult.model_validate(result) + + async def elicit( + self, + message: str, + requestedSchema: types.ElicitRequestedSchema, + related_request_id: types.RequestId | None = None, + ) -> types.ElicitResult: + params: Dict[str, Any] = { + "message": message, + "requestedSchema": requestedSchema, + } + if related_request_id is not None: + params["related_request_id"] = related_request_id + result = await self.request("elicitation/create", params) + return types.ElicitResult.model_validate(result) + + +def _in_workflow_runtime() -> bool: + """Return True if currently executing inside a Temporal workflow sandbox.""" + try: + return _twf._Runtime.current() is not None # type: ignore[attr-defined] + except Exception: + return False + + +class 
_RPC: + """Lightweight facade to mimic the low-level RPC interface on sessions.""" + + def __init__(self, proxy: SessionProxy) -> None: + self._proxy = proxy + + async def notify(self, method: str, params: Dict[str, Any] | None = None) -> None: + await self._proxy.notify(method, params or {}) async def request( self, method: str, params: Dict[str, Any] | None = None ) -> Dict[str, Any]: - await self.sys_acts.relay_request(get_execution_id(), method, params or {}) + return await self._proxy.request(method, params or {}) diff --git a/src/mcp_agent/executor/temporal/temporal_context.py b/src/mcp_agent/executor/temporal/temporal_context.py index c5fd4b2ba..c34b53569 100644 --- a/src/mcp_agent/executor/temporal/temporal_context.py +++ b/src/mcp_agent/executor/temporal/temporal_context.py @@ -2,13 +2,13 @@ EXECUTION_ID_KEY = "__execution_id" -_execution_id: Optional[str] = None +_EXECUTION_ID: str | None = None -def set_execution_id(execution_id: str) -> None: - global _execution_id - _execution_id = execution_id +def set_execution_id(execution_id: Optional[str]) -> None: + global _EXECUTION_ID + _EXECUTION_ID = execution_id def get_execution_id() -> Optional[str]: - return _execution_id + return _EXECUTION_ID diff --git a/src/mcp_agent/server/app_server.py b/src/mcp_agent/server/app_server.py index a5aa5cf61..09226f165 100644 --- a/src/mcp_agent/server/app_server.py +++ b/src/mcp_agent/server/app_server.py @@ -25,7 +25,9 @@ WorkflowRegistry, InMemoryWorkflowRegistry, ) -from mcp_agent.executor.temporal.temporal_context import set_execution_id +from mcp_agent.executor.temporal.temporal_context import ( + set_execution_id, +) from mcp_agent.logging.logger import get_logger from mcp_agent.logging.logger import LoggingConfig from mcp_agent.mcp.mcp_server_registry import ServerRegistry @@ -628,25 +630,6 @@ async def _set_level( # If handler registration fails, continue without dynamic level updates pass - # Register logging/setLevel handler so client can adjust verbosity dynamically - # This enables MCP logging capability in InitializeResult.capabilities.logging - lowlevel_server = getattr(mcp, "_mcp_server", None) - try: - if lowlevel_server is not None: - - @lowlevel_server.set_logging_level() - async def _set_level( - level: str, - ) -> None: # mcp.types.LoggingLevel is a Literal[str] - try: - LoggingConfig.set_min_level(level) - except Exception: - # Best-effort, do not crash server on invalid level - pass - except Exception: - # If handler registration fails, continue without dynamic level updates - pass - # region Workflow Tools @mcp.tool(name="workflows-list") @@ -1498,8 +1481,10 @@ async def _workflow_status( try: state = str(status.get("status", "")).lower() if state in ("completed", "error", "cancelled"): - # await _unregister_session(run_id) - pass + try: + await _unregister_session(run_id) + except Exception: + pass except Exception: pass From 015395db601207051e34fda0a5cd2ebe72c394d2 Mon Sep 17 00:00:00 2001 From: Sarmad Qadri Date: Fri, 5 Sep 2025 11:00:54 -0400 Subject: [PATCH 18/24] run formatter --- .../cli/cloud/commands/auth/whoami/main.py | 1 - .../cli/cloud/commands/logger/__init__.py | 2 +- .../commands/logger/configure/__init__.py | 2 +- .../cloud/commands/logger/configure/main.py | 76 +++-- .../cloud/commands/logger/tail/__init__.py | 2 +- .../cli/cloud/commands/logger/tail/main.py | 310 ++++++++++-------- .../cli/cloud/commands/logger/utils.py | 52 +-- src/mcp_agent/mcp/mcp_aggregator.py | 29 +- tests/mcp/test_mcp_aggregator.py | 267 ++++++++++----- 9 files changed, 459 
insertions(+), 282 deletions(-) diff --git a/src/mcp_agent/cli/cloud/commands/auth/whoami/main.py b/src/mcp_agent/cli/cloud/commands/auth/whoami/main.py index 7965fdef8..14e3fa2a6 100644 --- a/src/mcp_agent/cli/cloud/commands/auth/whoami/main.py +++ b/src/mcp_agent/cli/cloud/commands/auth/whoami/main.py @@ -1,6 +1,5 @@ """MCP Agent Cloud whoami command implementation.""" - from rich.console import Console from rich.panel import Panel from rich.table import Table diff --git a/src/mcp_agent/cli/cloud/commands/logger/__init__.py b/src/mcp_agent/cli/cloud/commands/logger/__init__.py index bb0332d7a..66805dc7e 100644 --- a/src/mcp_agent/cli/cloud/commands/logger/__init__.py +++ b/src/mcp_agent/cli/cloud/commands/logger/__init__.py @@ -6,4 +6,4 @@ from .tail.main import tail_logs -__all__ = ["tail_logs"] \ No newline at end of file +__all__ = ["tail_logs"] diff --git a/src/mcp_agent/cli/cloud/commands/logger/configure/__init__.py b/src/mcp_agent/cli/cloud/commands/logger/configure/__init__.py index bcfed9693..988a5e789 100644 --- a/src/mcp_agent/cli/cloud/commands/logger/configure/__init__.py +++ b/src/mcp_agent/cli/cloud/commands/logger/configure/__init__.py @@ -2,4 +2,4 @@ from .main import configure_logger -__all__ = ["configure_logger"] \ No newline at end of file +__all__ = ["configure_logger"] diff --git a/src/mcp_agent/cli/cloud/commands/logger/configure/main.py b/src/mcp_agent/cli/cloud/commands/logger/configure/main.py index 448cc0e63..fff24b2ff 100644 --- a/src/mcp_agent/cli/cloud/commands/logger/configure/main.py +++ b/src/mcp_agent/cli/cloud/commands/logger/configure/main.py @@ -32,10 +32,10 @@ def configure_logger( ), ) -> None: """Configure OTEL endpoint and headers for log collection. - + This command allows you to configure the OpenTelemetry endpoint and headers that will be used for collecting logs from your deployed MCP apps. - + Examples: mcp-agent cloud logger configure https://otel.example.com:4318/v1/logs mcp-agent cloud logger configure https://otel.example.com --headers "Authorization=Bearer token,X-Custom=value" @@ -44,9 +44,9 @@ def configure_logger( if not endpoint and not test: console.print("[red]Error: Must specify endpoint or use --test[/red]") raise typer.Exit(1) - + config_path = _find_config_file() - + if test: if config_path and config_path.exists(): config = _load_config(config_path) @@ -54,7 +54,9 @@ def configure_logger( endpoint = otel_config.get("endpoint") headers_dict = otel_config.get("headers", {}) else: - console.print("[yellow]No configuration file found. Use --endpoint to set up OTEL configuration.[/yellow]") + console.print( + "[yellow]No configuration file found. 
Use --endpoint to set up OTEL configuration.[/yellow]" + ) raise typer.Exit(1) else: headers_dict = {} @@ -64,54 +66,68 @@ def configure_logger( key, value = header_pair.strip().split("=", 1) headers_dict[key.strip()] = value.strip() except ValueError: - console.print("[red]Error: Headers must be in format 'key=value,key2=value2'[/red]") + console.print( + "[red]Error: Headers must be in format 'key=value,key2=value2'[/red]" + ) raise typer.Exit(1) - + if endpoint: console.print(f"[blue]Testing connection to {endpoint}...[/blue]") - + try: with httpx.Client(timeout=10.0) as client: response = client.get( - endpoint.replace("/v1/logs", "/health") if "/v1/logs" in endpoint else f"{endpoint}/health", - headers=headers_dict + endpoint.replace("/v1/logs", "/health") + if "/v1/logs" in endpoint + else f"{endpoint}/health", + headers=headers_dict, ) - - if response.status_code in [200, 404]: # 404 is fine, means endpoint exists + + if response.status_code in [ + 200, + 404, + ]: # 404 is fine, means endpoint exists console.print("[green]✓ Connection successful[/green]") else: - console.print(f"[yellow]⚠ Got status {response.status_code}, but endpoint is reachable[/yellow]") - + console.print( + f"[yellow]⚠ Got status {response.status_code}, but endpoint is reachable[/yellow]" + ) + except httpx.RequestError as e: console.print(f"[red]✗ Connection failed: {e}[/red]") if not test: - console.print("[yellow]Configuration will be saved anyway. Check your endpoint URL and network connection.[/yellow]") - + console.print( + "[yellow]Configuration will be saved anyway. Check your endpoint URL and network connection.[/yellow]" + ) + if not test: if not config_path: config_path = Path.cwd() / "mcp_agent.config.yaml" - + config = _load_config(config_path) if config_path.exists() else {} - + if "otel" not in config: config["otel"] = {} - + config["otel"]["endpoint"] = endpoint config["otel"]["headers"] = headers_dict - + try: config_path.parent.mkdir(parents=True, exist_ok=True) with open(config_path, "w") as f: yaml.dump(config, f, default_flow_style=False, sort_keys=False) - - console.print(Panel( - f"[green]✓ OTEL configuration saved to {config_path}[/green]\n\n" - f"Endpoint: {endpoint}\n" - f"Headers: {len(headers_dict)} configured" + (f" ({', '.join(headers_dict.keys())})" if headers_dict else ""), - title="Configuration Saved", - border_style="green" - )) - + + console.print( + Panel( + f"[green]✓ OTEL configuration saved to {config_path}[/green]\n\n" + f"Endpoint: {endpoint}\n" + f"Headers: {len(headers_dict)} configured" + + (f" ({', '.join(headers_dict.keys())})" if headers_dict else ""), + title="Configuration Saved", + border_style="green", + ) + ) + except Exception as e: console.print(f"[red]Error saving configuration: {e}[/red]") raise typer.Exit(1) @@ -134,4 +150,4 @@ def _load_config(config_path: Path) -> dict: with open(config_path, "r") as f: return yaml.safe_load(f) or {} except Exception as e: - raise CLIError(f"Failed to load config from {config_path}: {e}") \ No newline at end of file + raise CLIError(f"Failed to load config from {config_path}: {e}") diff --git a/src/mcp_agent/cli/cloud/commands/logger/tail/__init__.py b/src/mcp_agent/cli/cloud/commands/logger/tail/__init__.py index 36a27364c..81c22db0a 100644 --- a/src/mcp_agent/cli/cloud/commands/logger/tail/__init__.py +++ b/src/mcp_agent/cli/cloud/commands/logger/tail/__init__.py @@ -2,4 +2,4 @@ from .main import tail_logs -__all__ = ["tail_logs"] \ No newline at end of file +__all__ = ["tail_logs"] diff --git 
a/src/mcp_agent/cli/cloud/commands/logger/tail/main.py b/src/mcp_agent/cli/cloud/commands/logger/tail/main.py index b9ecf5368..d99c6df70 100644 --- a/src/mcp_agent/cli/cloud/commands/logger/tail/main.py +++ b/src/mcp_agent/cli/cloud/commands/logger/tail/main.py @@ -18,7 +18,10 @@ from mcp_agent.cli.exceptions import CLIError from mcp_agent.cli.auth import load_credentials, UserCredentials from mcp_agent.cli.core.constants import DEFAULT_API_BASE_URL -from mcp_agent.cli.cloud.commands.logger.utils import parse_app_identifier, resolve_server_url +from mcp_agent.cli.cloud.commands.logger.utils import ( + parse_app_identifier, + resolve_server_url, +) console = Console() @@ -73,89 +76,103 @@ def tail_logs( ), ) -> None: """Tail logs for an MCP app deployment. - + Retrieve and optionally stream logs from deployed MCP apps. Supports filtering by time duration, text patterns, and continuous streaming. - + Examples: # Get last 50 logs from an app mcp-agent cloud logger tail app_abc123 --limit 50 - + # Stream logs continuously mcp-agent cloud logger tail https://app.mcpac.dev/abc123 --follow - + # Show logs from the last hour with error filtering mcp-agent cloud logger tail app_abc123 --since 1h --grep "ERROR|WARN" - + # Follow logs and filter for specific patterns mcp-agent cloud logger tail app_abc123 --follow --grep "authentication.*failed" """ - + credentials = load_credentials() if not credentials: - console.print("[red]Error: Not authenticated. Run 'mcp-agent login' first.[/red]") + console.print( + "[red]Error: Not authenticated. Run 'mcp-agent login' first.[/red]" + ) raise typer.Exit(4) - + # Validate conflicting options if follow and since: - console.print("[red]Error: --since cannot be used with --follow (streaming mode)[/red]") + console.print( + "[red]Error: --since cannot be used with --follow (streaming mode)[/red]" + ) raise typer.Exit(6) - + if follow and limit != DEFAULT_LOG_LIMIT: - console.print("[red]Error: --limit cannot be used with --follow (streaming mode)[/red]") + console.print( + "[red]Error: --limit cannot be used with --follow (streaming mode)[/red]" + ) raise typer.Exit(6) - + if follow and order_by: - console.print("[red]Error: --order-by cannot be used with --follow (streaming mode)[/red]") + console.print( + "[red]Error: --order-by cannot be used with --follow (streaming mode)[/red]" + ) raise typer.Exit(6) - + if follow and (asc or desc): - console.print("[red]Error: --asc/--desc cannot be used with --follow (streaming mode)[/red]") + console.print( + "[red]Error: --asc/--desc cannot be used with --follow (streaming mode)[/red]" + ) raise typer.Exit(6) - + # Validate order_by values if order_by and order_by not in ["timestamp", "severity"]: console.print("[red]Error: --order-by must be 'timestamp' or 'severity'[/red]") raise typer.Exit(6) - + # Validate that both --asc and --desc are not used together if asc and desc: console.print("[red]Error: Cannot use both --asc and --desc together[/red]") raise typer.Exit(6) - + # Validate format values if format and format not in ["text", "json", "yaml"]: console.print("[red]Error: --format must be 'text', 'json', or 'yaml'[/red]") raise typer.Exit(6) - + app_id, config_id, server_url = parse_app_identifier(app_identifier) - + try: if follow: - asyncio.run(_stream_logs( - app_id=app_id, - config_id=config_id, - server_url=server_url, - credentials=credentials, - grep_pattern=grep, - app_identifier=app_identifier, - format=format, - )) + asyncio.run( + _stream_logs( + app_id=app_id, + config_id=config_id, + 
server_url=server_url, + credentials=credentials, + grep_pattern=grep, + app_identifier=app_identifier, + format=format, + ) + ) else: - asyncio.run(_fetch_logs( - app_id=app_id, - config_id=config_id, - server_url=server_url, - credentials=credentials, - since=since, - grep_pattern=grep, - limit=limit, - order_by=order_by, - asc=asc, - desc=desc, - format=format, - )) - + asyncio.run( + _fetch_logs( + app_id=app_id, + config_id=config_id, + server_url=server_url, + credentials=credentials, + since=since, + grep_pattern=grep, + limit=limit, + order_by=order_by, + asc=asc, + desc=desc, + format=format, + ) + ) + except KeyboardInterrupt: console.print("\n[yellow]Interrupted by user[/yellow]") sys.exit(0) @@ -166,7 +183,7 @@ def tail_logs( async def _fetch_logs( app_id: Optional[str], - config_id: Optional[str], + config_id: Optional[str], server_url: Optional[str], credentials: UserCredentials, since: Optional[str], @@ -178,38 +195,40 @@ async def _fetch_logs( format: str, ) -> None: """Fetch logs one-time via HTTP API.""" - + api_base = DEFAULT_API_BASE_URL headers = { "Authorization": f"Bearer {credentials.api_key}", "Content-Type": "application/json", } - + payload = {} - + if app_id: payload["app_id"] = app_id elif config_id: payload["app_configuration_id"] = config_id else: - raise CLIError("Unable to determine app or configuration ID from provided identifier") - + raise CLIError( + "Unable to determine app or configuration ID from provided identifier" + ) + if since: payload["since"] = since if limit: payload["limit"] = limit - + if order_by: if order_by == "timestamp": payload["orderBy"] = "LOG_ORDER_BY_TIMESTAMP" elif order_by == "severity": payload["orderBy"] = "LOG_ORDER_BY_LEVEL" - + if asc: payload["order"] = "LOG_ORDER_ASC" elif desc: payload["order"] = "LOG_ORDER_DESC" - + with Progress( SpinnerColumn(), TextColumn("[progress.description]{task.description}"), @@ -217,7 +236,7 @@ async def _fetch_logs( transient=True, ) as progress: progress.add_task("Fetching logs...", total=None) - + try: async with httpx.AsyncClient(timeout=30.0) as client: response = await client.post( @@ -225,128 +244,148 @@ async def _fetch_logs( json=payload, headers=headers, ) - + if response.status_code == 401: - raise CLIError("Authentication failed. Try running 'mcp-agent login'") + raise CLIError( + "Authentication failed. 
Try running 'mcp-agent login'" + ) elif response.status_code == 404: raise CLIError("App or configuration not found") elif response.status_code != 200: - raise CLIError(f"API request failed: {response.status_code} {response.text}") - + raise CLIError( + f"API request failed: {response.status_code} {response.text}" + ) + data = response.json() log_entries = data.get("logEntries", []) - + except httpx.RequestError as e: raise CLIError(f"Failed to connect to API: {e}") - - filtered_logs = _filter_logs(log_entries, grep_pattern) if grep_pattern else log_entries - + + filtered_logs = ( + _filter_logs(log_entries, grep_pattern) if grep_pattern else log_entries + ) + if not filtered_logs: console.print("[yellow]No logs found matching the criteria[/yellow]") return - - _display_logs(filtered_logs, title=f"Logs for {app_id or config_id}", format=format) + _display_logs(filtered_logs, title=f"Logs for {app_id or config_id}", format=format) async def _stream_logs( app_id: Optional[str], config_id: Optional[str], - server_url: Optional[str], + server_url: Optional[str], credentials: UserCredentials, grep_pattern: Optional[str], app_identifier: str, format: str, ) -> None: """Stream logs continuously via SSE.""" - + if not server_url: server_url = await resolve_server_url(app_id, config_id, credentials) - + parsed = urlparse(server_url) stream_url = f"{parsed.scheme}://{parsed.netloc}/logs" hostname = parsed.hostname or "" - deployment_id = hostname.split('.')[0] if '.' in hostname else hostname - + deployment_id = hostname.split(".")[0] if "." in hostname else hostname + headers = { "Accept": "text/event-stream", "Cache-Control": "no-cache", "X-Routing-Key": deployment_id, } - + if credentials.api_key: headers["Authorization"] = f"Bearer {credentials.api_key}" - - console.print(f"[blue]Streaming logs from {app_identifier} (Press Ctrl+C to stop)[/blue]") - + + console.print( + f"[blue]Streaming logs from {app_identifier} (Press Ctrl+C to stop)[/blue]" + ) + # Setup signal handler for graceful shutdown def signal_handler(signum, frame): console.print("\n[yellow]Stopping log stream...[/yellow]") sys.exit(0) - + signal.signal(signal.SIGINT, signal_handler) - + try: async with httpx.AsyncClient(timeout=None) as client: async with client.stream("GET", stream_url, headers=headers) as response: - if response.status_code == 401: - raise CLIError("Authentication failed. Try running 'mcp-agent login'") + raise CLIError( + "Authentication failed. 
Try running 'mcp-agent login'" + ) elif response.status_code == 404: raise CLIError("Log stream not found for the specified app") elif response.status_code != 200: - raise CLIError(f"Failed to connect to log stream: {response.status_code}") - + raise CLIError( + f"Failed to connect to log stream: {response.status_code}" + ) + console.print("[green]✓ Connected to log stream[/green]\n") - + buffer = "" async for chunk in response.aiter_text(): buffer += chunk - lines = buffer.split('\n') - + lines = buffer.split("\n") + for line in lines[:-1]: - if line.startswith('data:'): - data_content = line.removeprefix('data:') - + if line.startswith("data:"): + data_content = line.removeprefix("data:") + try: log_data = json.loads(data_content) - - if 'message' in log_data: - timestamp = log_data.get('time') + + if "message" in log_data: + timestamp = log_data.get("time") if timestamp: - formatted_timestamp = _convert_timestamp_to_local(timestamp) + formatted_timestamp = ( + _convert_timestamp_to_local(timestamp) + ) else: formatted_timestamp = datetime.now().isoformat() - + log_entry = { - 'timestamp': formatted_timestamp, - 'message': log_data['message'], - 'level': log_data.get('level', 'INFO') + "timestamp": formatted_timestamp, + "message": log_data["message"], + "level": log_data.get("level", "INFO"), } - - if not grep_pattern or _matches_pattern(log_entry['message'], grep_pattern): + + if not grep_pattern or _matches_pattern( + log_entry["message"], grep_pattern + ): _display_log_entry(log_entry, format=format) - + except json.JSONDecodeError: # Skip malformed JSON continue - + except httpx.RequestError as e: raise CLIError(f"Failed to connect to log stream: {e}") - - -def _filter_logs(log_entries: List[Dict[str, Any]], pattern: str) -> List[Dict[str, Any]]: +def _filter_logs( + log_entries: List[Dict[str, Any]], pattern: str +) -> List[Dict[str, Any]]: """Filter log entries by pattern.""" if not pattern: return log_entries - + try: regex = re.compile(pattern, re.IGNORECASE) - return [entry for entry in log_entries if regex.search(entry.get('message', ''))] + return [ + entry for entry in log_entries if regex.search(entry.get("message", "")) + ] except re.error: - return [entry for entry in log_entries if pattern.lower() in entry.get('message', '').lower()] + return [ + entry + for entry in log_entries + if pattern.lower() in entry.get("message", "").lower() + ] def _matches_pattern(message: str, pattern: str) -> bool: @@ -361,21 +400,21 @@ def _matches_pattern(message: str, pattern: str) -> bool: def _clean_log_entry(entry: Dict[str, Any]) -> Dict[str, Any]: """Clean up a log entry for structured output formats.""" cleaned_entry = entry.copy() - cleaned_entry['severity'] = _parse_log_level(entry.get('level', 'INFO')) - cleaned_entry['message'] = _clean_message(entry.get('message', '')) - cleaned_entry.pop('level', None) + cleaned_entry["severity"] = _parse_log_level(entry.get("level", "INFO")) + cleaned_entry["message"] = _clean_message(entry.get("message", "")) + cleaned_entry.pop("level", None) return cleaned_entry def _display_text_log_entry(entry: Dict[str, Any]) -> None: """Display a single log entry in text format.""" - timestamp = _format_timestamp(entry.get('timestamp', '')) - raw_level = entry.get('level', 'INFO') + timestamp = _format_timestamp(entry.get("timestamp", "")) + raw_level = entry.get("level", "INFO") level = _parse_log_level(raw_level) - message = _clean_message(entry.get('message', '')) - + message = _clean_message(entry.get("message", "")) + level_style = 
_get_level_style(level) - + console.print( f"[bright_black not bold]{timestamp}[/bright_black not bold] " f"[{level_style}]{level:7}[/{level_style}] " @@ -383,11 +422,13 @@ def _display_text_log_entry(entry: Dict[str, Any]) -> None: ) -def _display_logs(log_entries: List[Dict[str, Any]], title: str = "Logs", format: str = "text") -> None: +def _display_logs( + log_entries: List[Dict[str, Any]], title: str = "Logs", format: str = "text" +) -> None: """Display logs in the specified format.""" if not log_entries: return - + if format == "json": cleaned_entries = [_clean_log_entry(entry) for entry in log_entries] print(json.dumps(cleaned_entries, indent=2)) @@ -397,7 +438,7 @@ def _display_logs(log_entries: List[Dict[str, Any]], title: str = "Logs", format else: # text format (default) if title: console.print(f"[bold blue]{title}[/bold blue]\n") - + for entry in log_entries: _display_text_log_entry(entry) @@ -426,20 +467,20 @@ def _format_timestamp(timestamp_str: str) -> str: try: if timestamp_str: # Parse UTC timestamp and convert to local time - dt_utc = datetime.fromisoformat(timestamp_str.replace('Z', '+00:00')) + dt_utc = datetime.fromisoformat(timestamp_str.replace("Z", "+00:00")) dt_local = dt_utc.astimezone() - return dt_local.strftime('%H:%M:%S') - return datetime.now().strftime('%H:%M:%S') + return dt_local.strftime("%H:%M:%S") + return datetime.now().strftime("%H:%M:%S") except (ValueError, TypeError): return timestamp_str[:8] if len(timestamp_str) >= 8 else timestamp_str def _parse_log_level(level: str) -> str: """Parse log level from API format to clean display format.""" - if level.startswith('LOG_LEVEL_'): - clean_level = level.replace('LOG_LEVEL_', '') - if clean_level == 'UNSPECIFIED': - return 'UNKNOWN' + if level.startswith("LOG_LEVEL_"): + clean_level = level.replace("LOG_LEVEL_", "") + if clean_level == "UNSPECIFIED": + return "UNKNOWN" return clean_level return level.upper() @@ -447,29 +488,36 @@ def _parse_log_level(level: str) -> str: def _clean_message(message: str) -> str: """Remove redundant log level prefix from message if present.""" prefixes = [ - 'ERROR:', 'WARNING:', 'INFO:', 'DEBUG:', 'TRACE:', - 'WARN:', 'FATAL:', 'UNKNOWN:', 'UNSPECIFIED:' + "ERROR:", + "WARNING:", + "INFO:", + "DEBUG:", + "TRACE:", + "WARN:", + "FATAL:", + "UNKNOWN:", + "UNSPECIFIED:", ] - + for prefix in prefixes: if message.startswith(prefix): - return message[len(prefix):].lstrip() - + return message[len(prefix) :].lstrip() + return message def _get_level_style(level: str) -> str: """Get Rich style for log level.""" level = level.upper() - if level in ['ERROR', 'FATAL']: + if level in ["ERROR", "FATAL"]: return "red bold" - elif level in ['WARN', 'WARNING']: + elif level in ["WARN", "WARNING"]: return "yellow bold" - elif level == 'INFO': + elif level == "INFO": return "blue" - elif level in ['DEBUG', 'TRACE']: + elif level in ["DEBUG", "TRACE"]: return "dim" - elif level in ['UNKNOWN', 'UNSPECIFIED']: + elif level in ["UNKNOWN", "UNSPECIFIED"]: return "magenta" else: return "white" diff --git a/src/mcp_agent/cli/cloud/commands/logger/utils.py b/src/mcp_agent/cli/cloud/commands/logger/utils.py index 6c9a8c2a7..3e05bad20 100644 --- a/src/mcp_agent/cli/cloud/commands/logger/utils.py +++ b/src/mcp_agent/cli/cloud/commands/logger/utils.py @@ -9,35 +9,37 @@ from mcp_agent.cli.core.constants import DEFAULT_API_BASE_URL -def parse_app_identifier(identifier: str) -> Tuple[Optional[str], Optional[str], Optional[str]]: +def parse_app_identifier( + identifier: str, +) -> Tuple[Optional[str], 
Optional[str], Optional[str]]: """Parse app identifier to extract app ID, config ID, and server URL.""" - + # Check if it's a URL - if identifier.startswith(('http://', 'https://')): + if identifier.startswith(("http://", "https://")): return None, None, identifier - + # Check if it's an MCPAppConfig ID (starts with apcnf_) - if identifier.startswith('apcnf_'): + if identifier.startswith("apcnf_"): return None, identifier, None - + # Check if it's an MCPApp ID (starts with app_) - if identifier.startswith('app_'): + if identifier.startswith("app_"): return identifier, None, None - + # If no specific prefix, assume it's an app ID for backward compatibility return identifier, None, None async def resolve_server_url( app_id: Optional[str], - config_id: Optional[str], + config_id: Optional[str], credentials: UserCredentials, ) -> str: """Resolve server URL from app ID or configuration ID.""" - + if not app_id and not config_id: raise CLIError("Either app_id or config_id must be provided") - + # Determine the endpoint and payload based on identifier type if app_id: endpoint = "/mcp_app/get_app" @@ -57,38 +59,42 @@ async def resolve_server_url( no_url_msg = f"No server URL found for app configuration '{config_id}'" offline_msg = f"App configuration '{config_id}' server is offline" api_error_msg = "Failed to get app configuration" - + api_base = DEFAULT_API_BASE_URL headers = { "Authorization": f"Bearer {credentials.api_key}", "Content-Type": "application/json", } - + try: async with httpx.AsyncClient(timeout=30.0) as client: - response = await client.post(f"{api_base}{endpoint}", json=payload, headers=headers) - + response = await client.post( + f"{api_base}{endpoint}", json=payload, headers=headers + ) + if response.status_code == 404: raise CLIError(not_found_msg) elif response.status_code != 200: - raise CLIError(f"{api_error_msg}: {response.status_code} {response.text}") - + raise CLIError( + f"{api_error_msg}: {response.status_code} {response.text}" + ) + data = response.json() resource_info = data.get(response_key, {}) server_info = resource_info.get("appServerInfo") - + if not server_info: raise CLIError(not_deployed_msg) - + server_url = server_info.get("serverUrl") if not server_url: raise CLIError(no_url_msg) - + status = server_info.get("status", "APP_SERVER_STATUS_UNSPECIFIED") if status == "APP_SERVER_STATUS_OFFLINE": raise CLIError(offline_msg) - + return server_url - + except httpx.RequestError as e: - raise CLIError(f"Failed to connect to API: {e}") \ No newline at end of file + raise CLIError(f"Failed to connect to API: {e}") diff --git a/src/mcp_agent/mcp/mcp_aggregator.py b/src/mcp_agent/mcp/mcp_aggregator.py index ca36c0e0e..f57e67b48 100644 --- a/src/mcp_agent/mcp/mcp_aggregator.py +++ b/src/mcp_agent/mcp/mcp_aggregator.py @@ -345,26 +345,37 @@ async def load_server(self, server_name: str): # Process tools async with self._tool_map_lock: self._server_to_tool_map[server_name] = [] - + # Get server configuration to check for tool filtering allowed_tools = None disabled_tool_count = 0 - if (self.context is None or self.context.server_registry is None - or not hasattr(self.context.server_registry, "get_server_config")): - logger.warning(f"No config found for server '{server_name}', no tool filter will be applied...") + if ( + self.context is None + or self.context.server_registry is None + or not hasattr(self.context.server_registry, "get_server_config") + ): + logger.warning( + f"No config found for server '{server_name}', no tool filter will be applied..." 
+ ) else: - allowed_tools = self.context.server_registry.get_server_config(server_name).allowed_tools + allowed_tools = self.context.server_registry.get_server_config( + server_name + ).allowed_tools if allowed_tools is not None and len(allowed_tools) == 0: - logger.warning(f"Allowed tool list is explicitly empty for server '{server_name}'") - + logger.warning( + f"Allowed tool list is explicitly empty for server '{server_name}'" + ) + for tool in tools: # Apply tool filtering if configured - O(1) lookup with set if allowed_tools is not None and tool.name not in allowed_tools: - logger.debug(f"Filtering out tool '{tool.name}' from server '{server_name}' (not in allowed_tools)") + logger.debug( + f"Filtering out tool '{tool.name}' from server '{server_name}' (not in allowed_tools)" + ) disabled_tool_count += 1 continue - + namespaced_tool_name = f"{server_name}{SEP}{tool.name}" namespaced_tool = NamespacedTool( tool=tool, diff --git a/tests/mcp/test_mcp_aggregator.py b/tests/mcp/test_mcp_aggregator.py index d5ef11299..a592743fc 100644 --- a/tests/mcp/test_mcp_aggregator.py +++ b/tests/mcp/test_mcp_aggregator.py @@ -868,27 +868,28 @@ async def get_prompt(self, name, arguments=None): # ============================================================================= - class MockServerConfig: """Mock server configuration for testing""" + def __init__(self, allowed_tools=None): self.allowed_tools = allowed_tools class DummyContextWithServerRegistry: """Extended dummy context with server registry for tool filtering tests""" + def __init__(self, server_configs=None): self.tracer = None self.tracing_enabled = False self.server_configs = server_configs or {} - + class MockServerRegistry: def __init__(self, configs): self.configs = configs - + def get_server_config(self, server_name): return self.configs.get(server_name, MockServerConfig()) - + def start_server(self, server_name, client_session_factory=None): class DummyCtxMgr: async def __aenter__(self): @@ -896,13 +897,16 @@ class DummySession: async def initialize(self): class InitResult: capabilities = {"tools": True} + return InitResult() + return DummySession() - + async def __aexit__(self, exc_type, exc_val, exc_tb): pass + return DummyCtxMgr() - + self.server_registry = MockServerRegistry(self.server_configs) self._mcp_connection_manager_lock = asyncio.Lock() self._mcp_connection_manager_ref_count = 0 @@ -912,43 +916,59 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): async def test_tool_filtering_with_allowed_tools(): """Test that tools are filtered correctly when allowed_tools is configured""" # Setup server config with allowed tools - server_configs = { - "test_server": MockServerConfig(allowed_tools={"tool1", "tool3"}) - } + server_configs = {"test_server": MockServerConfig(allowed_tools={"tool1", "tool3"})} context = DummyContextWithServerRegistry(server_configs) - + aggregator = mcp_aggregator_mod.MCPAggregator( server_names=["test_server"], connection_persistence=False, context=context, name="test_agent", ) - + # Mock tools that would be returned from server mock_tools = [ - Tool(name="tool1", description="Description for tool1", inputSchema={"type": "object"}), # Should be included - Tool(name="tool2", description="Description for tool2", inputSchema={"type": "object"}), # Should be filtered out - Tool(name="tool3", description="Description for tool3", inputSchema={"type": "object"}), # Should be included - Tool(name="tool4", description="Description for tool4", inputSchema={"type": "object"}), # Should be filtered out + Tool( + 
name="tool1", + description="Description for tool1", + inputSchema={"type": "object"}, + ), # Should be included + Tool( + name="tool2", + description="Description for tool2", + inputSchema={"type": "object"}, + ), # Should be filtered out + Tool( + name="tool3", + description="Description for tool3", + inputSchema={"type": "object"}, + ), # Should be included + Tool( + name="tool4", + description="Description for tool4", + inputSchema={"type": "object"}, + ), # Should be filtered out ] - + # Mock _fetch_capabilities to return our test tools async def mock_fetch_capabilities(server_name): return (None, mock_tools, [], []) # capabilities, tools, prompts, resources - - with patch.object(aggregator, '_fetch_capabilities', side_effect=mock_fetch_capabilities): + + with patch.object( + aggregator, "_fetch_capabilities", side_effect=mock_fetch_capabilities + ): await aggregator.load_server("test_server") - + # Verify only allowed tools were added server_tools = aggregator._server_to_tool_map.get("test_server", []) assert len(server_tools) == 2 - + tool_names = [tool.tool.name for tool in server_tools] assert "tool1" in tool_names assert "tool3" in tool_names assert "tool2" not in tool_names assert "tool4" not in tool_names - + # Verify namespaced tools map assert "test_server_tool1" in aggregator._namespaced_tool_map assert "test_server_tool3" in aggregator._namespaced_tool_map @@ -960,34 +980,46 @@ async def mock_fetch_capabilities(server_name): async def test_tool_filtering_no_filtering_when_none(): """Test that all tools are included when allowed_tools is None""" # Setup server config with no filtering - server_configs = { - "test_server": MockServerConfig(allowed_tools=None) - } + server_configs = {"test_server": MockServerConfig(allowed_tools=None)} context = DummyContextWithServerRegistry(server_configs) - + aggregator = mcp_aggregator_mod.MCPAggregator( server_names=["test_server"], connection_persistence=False, context=context, name="test_agent", ) - + mock_tools = [ - Tool(name="tool1", description="Description for tool1", inputSchema={"type": "object"}), - Tool(name="tool2", description="Description for tool2", inputSchema={"type": "object"}), - Tool(name="tool3", description="Description for tool3", inputSchema={"type": "object"}), + Tool( + name="tool1", + description="Description for tool1", + inputSchema={"type": "object"}, + ), + Tool( + name="tool2", + description="Description for tool2", + inputSchema={"type": "object"}, + ), + Tool( + name="tool3", + description="Description for tool3", + inputSchema={"type": "object"}, + ), ] - + async def mock_fetch_capabilities(server_name): return (None, mock_tools, [], []) - - with patch.object(aggregator, '_fetch_capabilities', side_effect=mock_fetch_capabilities): + + with patch.object( + aggregator, "_fetch_capabilities", side_effect=mock_fetch_capabilities + ): await aggregator.load_server("test_server") - + # Verify all tools were added server_tools = aggregator._server_to_tool_map.get("test_server", []) assert len(server_tools) == 3 - + tool_names = [tool.tool.name for tool in server_tools] assert "tool1" in tool_names assert "tool2" in tool_names @@ -998,33 +1030,41 @@ async def mock_fetch_capabilities(server_name): async def test_tool_filtering_empty_allowed_tools(): """Test behavior when allowed_tools is empty set (should filter out all tools)""" # Setup server config with empty allowed tools - server_configs = { - "test_server": MockServerConfig(allowed_tools=set()) - } + server_configs = {"test_server": 
MockServerConfig(allowed_tools=set())} context = DummyContextWithServerRegistry(server_configs) - + aggregator = mcp_aggregator_mod.MCPAggregator( server_names=["test_server"], connection_persistence=False, context=context, name="test_agent", ) - + mock_tools = [ - Tool(name="tool1", description="Description for tool1", inputSchema={"type": "object"}), - Tool(name="tool2", description="Description for tool2", inputSchema={"type": "object"}), + Tool( + name="tool1", + description="Description for tool1", + inputSchema={"type": "object"}, + ), + Tool( + name="tool2", + description="Description for tool2", + inputSchema={"type": "object"}, + ), ] - + async def mock_fetch_capabilities(server_name): return (None, mock_tools, [], []) - - with patch.object(aggregator, '_fetch_capabilities', side_effect=mock_fetch_capabilities): + + with patch.object( + aggregator, "_fetch_capabilities", side_effect=mock_fetch_capabilities + ): await aggregator.load_server("test_server") - + # Verify no tools were added server_tools = aggregator._server_to_tool_map.get("test_server", []) assert len(server_tools) == 0 - + # Verify namespaced tools map is empty for this server assert "test_server_tool1" not in aggregator._namespaced_tool_map assert "test_server_tool2" not in aggregator._namespaced_tool_map @@ -1035,29 +1075,39 @@ async def test_tool_filtering_no_server_registry(): """Test fallback behavior when server registry is not available""" # Setup context without proper server registry context = DummyContext() # Original dummy context without server registry - + aggregator = mcp_aggregator_mod.MCPAggregator( server_names=["test_server"], connection_persistence=False, context=context, name="test_agent", ) - + mock_tools = [ - Tool(name="tool1", description="Description for tool1", inputSchema={"type": "object"}), - Tool(name="tool2", description="Description for tool2", inputSchema={"type": "object"}), + Tool( + name="tool1", + description="Description for tool1", + inputSchema={"type": "object"}, + ), + Tool( + name="tool2", + description="Description for tool2", + inputSchema={"type": "object"}, + ), ] - + async def mock_fetch_capabilities(server_name): return (None, mock_tools, [], []) - - with patch.object(aggregator, '_fetch_capabilities', side_effect=mock_fetch_capabilities): + + with patch.object( + aggregator, "_fetch_capabilities", side_effect=mock_fetch_capabilities + ): await aggregator.load_server("test_server") - + # Should include all tools when no server registry is available server_tools = aggregator._server_to_tool_map.get("test_server", []) assert len(server_tools) == 2 - + tool_names = [tool.tool.name for tool in server_tools] assert "tool1" in tool_names assert "tool2" in tool_names @@ -1073,40 +1123,70 @@ async def test_tool_filtering_multiple_servers(): "server3": MockServerConfig(allowed_tools=None), # No filtering } context = DummyContextWithServerRegistry(server_configs) - + aggregator = mcp_aggregator_mod.MCPAggregator( server_names=["server1", "server2", "server3"], connection_persistence=False, context=context, name="test_agent", ) - + # Different tools for each server server_tools = { "server1": [ - Tool(name="tool1", description="Description for tool1", inputSchema={"type": "object"}), - Tool(name="tool2", description="Description for tool2", inputSchema={"type": "object"}), - Tool(name="tool_extra", description="Description for tool_extra", inputSchema={"type": "object"}) + Tool( + name="tool1", + description="Description for tool1", + inputSchema={"type": "object"}, + ), + Tool( 
+ name="tool2", + description="Description for tool2", + inputSchema={"type": "object"}, + ), + Tool( + name="tool_extra", + description="Description for tool_extra", + inputSchema={"type": "object"}, + ), ], "server2": [ - Tool(name="tool3", description="Description for tool3", inputSchema={"type": "object"}), - Tool(name="tool_filtered", description="Description for tool_filtered", inputSchema={"type": "object"}) + Tool( + name="tool3", + description="Description for tool3", + inputSchema={"type": "object"}, + ), + Tool( + name="tool_filtered", + description="Description for tool_filtered", + inputSchema={"type": "object"}, + ), ], "server3": [ - Tool(name="toolA", description="Description for toolA", inputSchema={"type": "object"}), - Tool(name="toolB", description="Description for toolB", inputSchema={"type": "object"}) + Tool( + name="toolA", + description="Description for toolA", + inputSchema={"type": "object"}, + ), + Tool( + name="toolB", + description="Description for toolB", + inputSchema={"type": "object"}, + ), ], } - + async def mock_fetch_capabilities(server_name): tools = server_tools.get(server_name, []) return (None, tools, [], []) - - with patch.object(aggregator, '_fetch_capabilities', side_effect=mock_fetch_capabilities): + + with patch.object( + aggregator, "_fetch_capabilities", side_effect=mock_fetch_capabilities + ): await aggregator.load_server("server1") - await aggregator.load_server("server2") + await aggregator.load_server("server2") await aggregator.load_server("server3") - + # Check server1 filtering server1_tools = aggregator._server_to_tool_map.get("server1", []) assert len(server1_tools) == 2 @@ -1114,21 +1194,21 @@ async def mock_fetch_capabilities(server_name): assert "tool1" in server1_names assert "tool2" in server1_names assert "tool_extra" not in server1_names - + # Check server2 filtering server2_tools = aggregator._server_to_tool_map.get("server2", []) assert len(server2_tools) == 1 server2_names = [tool.tool.name for tool in server2_tools] assert "tool3" in server2_names assert "tool_filtered" not in server2_names - + # Check server3 (no filtering) server3_tools = aggregator._server_to_tool_map.get("server3", []) assert len(server3_tools) == 2 server3_names = [tool.tool.name for tool in server3_tools] assert "toolA" in server3_names assert "toolB" in server3_names - + # Check namespaced tools map assert "server1_tool1" in aggregator._namespaced_tool_map assert "server1_tool2" in aggregator._namespaced_tool_map @@ -1139,39 +1219,56 @@ async def mock_fetch_capabilities(server_name): assert "server3_toolB" in aggregator._namespaced_tool_map - -@pytest.mark.asyncio +@pytest.mark.asyncio async def test_tool_filtering_edge_case_exact_match(): """Test that tool filtering requires exact name matches""" server_configs = { "test_server": MockServerConfig(allowed_tools={"tool", "tool_exact"}) } context = DummyContextWithServerRegistry(server_configs) - + aggregator = mcp_aggregator_mod.MCPAggregator( server_names=["test_server"], connection_persistence=False, context=context, name="test_agent", ) - + mock_tools = [ - Tool(name="tool", description="Description for tool", inputSchema={"type": "object"}), # Should be included (exact match) - Tool(name="tool_exact", description="Description for tool_exact", inputSchema={"type": "object"}), # Should be included (exact match) - Tool(name="tool_similar", description="Description for tool_similar", inputSchema={"type": "object"}), # Should be filtered (not exact match) - Tool(name="my_tool", 
description="Description for my_tool", inputSchema={"type": "object"}), # Should be filtered (not exact match) + Tool( + name="tool", + description="Description for tool", + inputSchema={"type": "object"}, + ), # Should be included (exact match) + Tool( + name="tool_exact", + description="Description for tool_exact", + inputSchema={"type": "object"}, + ), # Should be included (exact match) + Tool( + name="tool_similar", + description="Description for tool_similar", + inputSchema={"type": "object"}, + ), # Should be filtered (not exact match) + Tool( + name="my_tool", + description="Description for my_tool", + inputSchema={"type": "object"}, + ), # Should be filtered (not exact match) ] - + async def mock_fetch_capabilities(server_name): return (None, mock_tools, [], []) - - with patch.object(aggregator, '_fetch_capabilities', side_effect=mock_fetch_capabilities): + + with patch.object( + aggregator, "_fetch_capabilities", side_effect=mock_fetch_capabilities + ): await aggregator.load_server("test_server") - + # Verify only exact matches were included server_tools = aggregator._server_to_tool_map.get("test_server", []) assert len(server_tools) == 2 - + tool_names = [tool.tool.name for tool in server_tools] assert "tool" in tool_names assert "tool_exact" in tool_names From 905b450755beff9abf26078613034f95c14e9f5c Mon Sep 17 00:00:00 2001 From: Sarmad Qadri Date: Fri, 5 Sep 2025 15:10:13 -0400 Subject: [PATCH 19/24] many fixes --- examples/mcp_agent_server/temporal/client.py | 17 +++ src/mcp_agent/app.py | 7 + src/mcp_agent/executor/temporal/__init__.py | 3 +- .../executor/temporal/interceptor.py | 6 +- .../executor/temporal/session_proxy.py | 7 +- .../executor/temporal/temporal_context.py | 37 ++++- src/mcp_agent/logging/listeners.py | 2 +- src/mcp_agent/logging/logger.py | 129 +++++++++++++++++- src/mcp_agent/server/app_server.py | 13 +- .../workflows/deep_orchestrator/README.md | 2 +- uv.lock | 2 +- 11 files changed, 202 insertions(+), 23 deletions(-) diff --git a/examples/mcp_agent_server/temporal/client.py b/examples/mcp_agent_server/temporal/client.py index 198ce0789..c945e1cd2 100644 --- a/examples/mcp_agent_server/temporal/client.py +++ b/examples/mcp_agent_server/temporal/client.py @@ -1,6 +1,7 @@ import asyncio import json import time +import argparse from mcp_agent.app import MCPApp from mcp_agent.config import MCPServerSettings from mcp_agent.executor.workflow import WorkflowExecution @@ -23,6 +24,14 @@ async def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--server-log-level", + type=str, + default=None, + help="Set server logging level (debug, info, notice, warning, error, critical, alert, emergency)", + ) + args = parser.parse_args() # Create MCPApp to get the server registry app = MCPApp(name="workflow_mcp_client") async with app.run() as client_app: @@ -69,6 +78,14 @@ def make_session( context.server_registry, client_session_factory=make_session, ) as server: + # Ask server to send logs at the requested level (default info) + level = (args.server_log_level or "info").lower() + print(f"[client] Setting server logging level to: {level}") + try: + await server.set_logging_level(level) + except Exception: + # Older servers may not support logging capability + print("[client] Server does not support logging/setLevel") # Call the BasicAgentWorkflow run_result = await server.call_tool( "workflows-BasicAgentWorkflow-run", diff --git a/src/mcp_agent/app.py b/src/mcp_agent/app.py index 87697bdd0..e72dc7636 100644 --- a/src/mcp_agent/app.py +++ 
b/src/mcp_agent/app.py @@ -15,6 +15,7 @@ from mcp_agent.executor.signal_registry import SignalRegistry from mcp_agent.logging.event_progress import ProgressAction from mcp_agent.logging.logger import get_logger +from mcp_agent.logging.logger import set_default_bound_context from mcp_agent.executor.decorator_registry import ( DecoratorRegistry, register_asyncio_decorators, @@ -233,6 +234,12 @@ async def initialize(self): # Store a reference to this app instance in the context for easier access self._context.app = self + # Provide a safe default bound context for loggers created after init without explicit context + try: + set_default_bound_context(self._context) + except Exception: + pass + # Auto-load subagents if enabled in settings try: subagents = self._config.agents diff --git a/src/mcp_agent/executor/temporal/__init__.py b/src/mcp_agent/executor/temporal/__init__.py index 4863f97bf..aa7af60ce 100644 --- a/src/mcp_agent/executor/temporal/__init__.py +++ b/src/mcp_agent/executor/temporal/__init__.py @@ -37,8 +37,8 @@ from mcp_agent.executor.workflow_signal import SignalHandler from mcp_agent.logging.logger import get_logger from mcp_agent.utils.common import unwrap -from mcp_agent.executor.temporal.system_activities import SystemActivities from mcp_agent.executor.temporal.interceptor import ContextPropagationInterceptor +from mcp_agent.executor.temporal.system_activities import SystemActivities if TYPE_CHECKING: from mcp_agent.app import MCPApp @@ -521,6 +521,7 @@ async def create_temporal_worker_for_app(app: "MCPApp"): task_queue=running_app.executor.config.task_queue, activities=activities, workflows=workflows, + interceptors=[ContextPropagationInterceptor()], ) try: diff --git a/src/mcp_agent/executor/temporal/interceptor.py b/src/mcp_agent/executor/temporal/interceptor.py index a085f731f..3e7ed5d0e 100644 --- a/src/mcp_agent/executor/temporal/interceptor.py +++ b/src/mcp_agent/executor/temporal/interceptor.py @@ -40,6 +40,7 @@ def set_header_from_context( def context_from_header( input: _InputWithHeaders, payload_converter: temporalio.converter.PayloadConverter ): + prev_exec_id = get_execution_id() execution_id_payload = input.headers.get(EXECUTION_ID_KEY) execution_id_from_header = ( payload_converter.from_payload(execution_id_payload, str) @@ -48,7 +49,10 @@ def context_from_header( ) set_execution_id(execution_id_from_header if execution_id_from_header else None) - yield + try: + yield + finally: + set_execution_id(prev_exec_id) class ContextPropagationInterceptor( diff --git a/src/mcp_agent/executor/temporal/session_proxy.py b/src/mcp_agent/executor/temporal/session_proxy.py index 18b30d1a0..2f6580ec2 100644 --- a/src/mcp_agent/executor/temporal/session_proxy.py +++ b/src/mcp_agent/executor/temporal/session_proxy.py @@ -11,7 +11,7 @@ from temporalio import workflow as _twf from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession as _BaseServerSession +from mcp.server.session import ServerSession from mcp.shared.message import ServerMessageMetadata from mcp_agent.core.context import Context @@ -19,7 +19,7 @@ from mcp_agent.executor.temporal.temporal_context import get_execution_id -class SessionProxy(_BaseServerSession): +class SessionProxy(ServerSession): """ SessionProxy acts like an MCP `ServerSession` for code running under the Temporal engine. 
It forwards server->client messages through the MCPApp @@ -108,9 +108,6 @@ async def request( return await self._executor.execute(act, exec_id, method, params or {}) return await self._sys_acts.relay_request(exec_id, method, params or {}) - # ---------------------- - # ServerSession-like API - # ---------------------- async def send_notification( self, notification: types.ServerNotification, diff --git a/src/mcp_agent/executor/temporal/temporal_context.py b/src/mcp_agent/executor/temporal/temporal_context.py index c34b53569..896df214d 100644 --- a/src/mcp_agent/executor/temporal/temporal_context.py +++ b/src/mcp_agent/executor/temporal/temporal_context.py @@ -2,7 +2,9 @@ EXECUTION_ID_KEY = "__execution_id" -_EXECUTION_ID: str | None = None +# Fallback global for non-Temporal contexts. This is best-effort only and +# used when neither workflow nor activity runtime is available. +_EXECUTION_ID: Optional[str] = None def set_execution_id(execution_id: Optional[str]) -> None: @@ -11,4 +13,37 @@ def set_execution_id(execution_id: Optional[str]) -> None: def get_execution_id() -> Optional[str]: + """Return the current Temporal run identifier to use for gateway routing. + + Priority: + - If inside a Temporal workflow, return workflow.info().run_id + - Else if inside a Temporal activity, return activity.info().workflow_run_id + - Else fall back to the process-scoped ContextVar (best-effort) + """ + # Try workflow runtime first + try: + from temporalio import workflow as _wf # type: ignore + + try: + if getattr(_wf, "_Runtime").current() is not None: # type: ignore[attr-defined] + return _wf.info().run_id + except Exception: + pass + except Exception: + pass + + # Then try activity runtime + try: + from temporalio import activity as _act # type: ignore + + try: + info = _act.info() + if info is not None and getattr(info, "workflow_run_id", None): + return info.workflow_run_id + except Exception: + pass + except Exception: + pass + + # Fallback to module-global (primarily for non-Temporal contexts) return _EXECUTION_ID diff --git a/src/mcp_agent/logging/listeners.py b/src/mcp_agent/logging/listeners.py index 0abc4329f..051baf075 100644 --- a/src/mcp_agent/logging/listeners.py +++ b/src/mcp_agent/logging/listeners.py @@ -249,7 +249,7 @@ async def handle_matched_event(self, event: Event) -> None: ) if upstream_session is None: - # No upstream_session available, event cannot be forwarded + # No upstream_session available; silently skip return # Map our EventType to MCP LoggingLevel; fold progress -> info diff --git a/src/mcp_agent/logging/logger.py b/src/mcp_agent/logging/logger.py index faf5b7117..6ef44b510 100644 --- a/src/mcp_agent/logging/logger.py +++ b/src/mcp_agent/logging/logger.py @@ -78,8 +78,111 @@ def _emit_event(self, event: Event): loop, temporalio.worker._workflow_instance._WorkflowInstanceImpl ): # Handle Temporal workflow environment where run_until_complete() is not implemented - # In Temporal, we can't block on async operations, so we'll need to avoid this - # Simply log to stdout/stderr as a fallback + # Prefer forwarding via the upstream session proxy using a workflow task, if available. 
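+                # Fallback chain implemented below:
+                #   1) forward via the bound upstream session (send_log_message)
+                #   2) else call the "mcp_forward_log" activity keyed by execution_id
+                #   3) else emit on the stderr transport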
+ try: + from temporalio import workflow as _wf # type: ignore + from mcp_agent.executor.temporal.temporal_context import ( + get_execution_id as _get_exec_id, + ) + + upstream = getattr(event, "upstream_session", None) + if ( + upstream is None + and getattr(self, "_bound_context", None) is not None + ): + try: + upstream = getattr( + self._bound_context, "upstream_session", None + ) + except Exception: + upstream = None + + # Construct payload + async def _forward_via_proxy(): + # If we have an upstream session, use it first + if upstream is not None: + try: + level_map = { + "debug": "debug", + "info": "info", + "warning": "warning", + "error": "error", + "progress": "info", + } + level = level_map.get(event.type, "info") + logger_name = ( + event.namespace + if not event.name + else f"{event.namespace}.{event.name}" + ) + data = { + "message": event.message, + "namespace": event.namespace, + "name": event.name, + "timestamp": event.timestamp.isoformat(), + } + if event.data: + data["data"] = event.data + if event.trace_id or event.span_id: + data["trace"] = { + "trace_id": event.trace_id, + "span_id": event.span_id, + } + if event.context is not None: + try: + data["context"] = event.context.dict() + except Exception: + pass + + await upstream.send_log_message( # type: ignore[attr-defined] + level=level, data=data, logger=logger_name + ) + return + except Exception: + pass + + # Fallback: use activity gateway directly if execution_id is available + try: + exec_id = _get_exec_id() + if exec_id: + level = { + "debug": "debug", + "info": "info", + "warning": "warning", + "error": "error", + "progress": "info", + }.get(event.type, "info") + ns = event.namespace + msg = event.message + data = event.data or {} + # Call by activity name to align with worker registration + await _wf.execute_activity( + "mcp_forward_log", + exec_id, + level, + ns, + msg, + data, + schedule_to_close_timeout=5, + ) + return + except Exception as _e: + pass + + # If all else fails, fall back to stderr transport + self.event_bus.emit_with_stderr_transport(event) + + try: + _wf.create_task(_forward_via_proxy()) + return + except Exception: + # Could not create workflow task, fall through to stderr transport + pass + except Exception: + # If Temporal workflow module unavailable or any error occurs, fall through + pass + + # As a last resort, log to stdout/stderr as a fallback self.event_bus.emit_with_stderr_transport(event) else: try: @@ -120,6 +223,19 @@ def event( extra_event_fields["upstream_session"] = upstream except Exception: pass + # Fallback to default bound context if logger wasn't explicitly bound + if "upstream_session" not in extra_event_fields: + try: + from mcp_agent.logging.logger import _default_bound_context as _dbc # type: ignore + + if _dbc is not None: + _up = getattr(_dbc, "upstream_session", None) + if _up is not None: + extra_event_fields["upstream_session"] = _up + except Exception: + pass + + # Do not use global context fallbacks here; they are unsafe under concurrency. 
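+        # (_default_bound_context is populated once during MCPApp.initialize(),
+        #  so the fallback above is app-scoped rather than per-request.)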
# No further fallbacks; upstream forwarding must be enabled by passing # a bound context when creating the logger or by server code attaching @@ -385,6 +501,7 @@ async def managed(cls, **config_kwargs): _logger_lock = threading.Lock() _loggers: Dict[str, Logger] = {} +_default_bound_context: Any | None = None def get_logger(namespace: str, session_id: str | None = None, context=None) -> Logger: @@ -404,7 +521,8 @@ def get_logger(namespace: str, session_id: str | None = None, context=None) -> L with _logger_lock: existing = _loggers.get(namespace) if existing is None: - logger = Logger(namespace, session_id, context) + bound_ctx = context if context is not None else _default_bound_context + logger = Logger(namespace, session_id, bound_ctx) _loggers[namespace] = logger return logger @@ -415,3 +533,8 @@ def get_logger(namespace: str, session_id: str | None = None, context=None) -> L existing._bound_context = context return existing + + +def set_default_bound_context(ctx: Any | None) -> None: + global _default_bound_context + _default_bound_context = ctx diff --git a/src/mcp_agent/server/app_server.py b/src/mcp_agent/server/app_server.py index 09226f165..3ac614276 100644 --- a/src/mcp_agent/server/app_server.py +++ b/src/mcp_agent/server/app_server.py @@ -8,7 +8,6 @@ from contextlib import asynccontextmanager from typing import Any, Dict, List, Optional, Set, Tuple, Type, TYPE_CHECKING import os -import uuid import asyncio from mcp.server.fastmcp import Context as MCPContext, FastMCP @@ -25,9 +24,7 @@ WorkflowRegistry, InMemoryWorkflowRegistry, ) -from mcp_agent.executor.temporal.temporal_context import ( - set_execution_id, -) + from mcp_agent.logging.logger import get_logger from mcp_agent.logging.logger import LoggingConfig from mcp_agent.mcp.mcp_server_registry import ServerRegistry @@ -1335,10 +1332,8 @@ async def _workflow_run( run_parameters: Dict[str, Any] | None = None, **kwargs: Any, ) -> Dict[str, str]: - # Generate a unique execution ID to track this run. We need to pass this to the workflow, and the run_id is only established - # after we create the workflow - execution_id = str(uuid.uuid4()) - set_execution_id(execution_id) + # Use Temporal run_id as the routing key for gateway callbacks. + # We don't have it until after the workflow is started; we'll register mapping post-start. # Resolve workflows and app context irrespective of startup mode # This now returns a context with upstream_session already set @@ -1420,7 +1415,6 @@ async def _workflow_run( workflow_memo = { "gateway_url": gateway_url, "gateway_token": gateway_token, - "execution_id": execution_id, } except Exception: workflow_memo = None @@ -1431,6 +1425,7 @@ async def _workflow_run( **run_parameters, ) + execution_id = execution.run_id logger.info( f"Workflow {workflow_name} started execution {execution_id} for workflow ID {execution.workflow_id}, " f"run ID {execution.run_id}. 
Parameters: {run_parameters}" diff --git a/src/mcp_agent/workflows/deep_orchestrator/README.md b/src/mcp_agent/workflows/deep_orchestrator/README.md index 132e8d195..f765fd214 100644 --- a/src/mcp_agent/workflows/deep_orchestrator/README.md +++ b/src/mcp_agent/workflows/deep_orchestrator/README.md @@ -35,11 +35,11 @@ flowchart TB B --> C{Execute Tasks} C --> D[Extract Knowledge] D --> E{Objective Complete?} + E -->|Yes| G E -->|No| F{Check Policy} F -->|Replan| B F -->|Continue| C F -->|Stop| G[Synthesize Results] - E -->|Yes| G G --> H[Final Result] style B fill:#e1f5fe diff --git a/uv.lock b/uv.lock index 218056688..63660bdb1 100644 --- a/uv.lock +++ b/uv.lock @@ -2040,7 +2040,7 @@ wheels = [ [[package]] name = "mcp-agent" -version = "0.1.15" +version = "0.1.16" source = { editable = "." } dependencies = [ { name = "aiohttp" }, From c1e7d1d5c9866247d328aa7889a76d4a19821824 Mon Sep 17 00:00:00 2001 From: Sarmad Qadri Date: Fri, 5 Sep 2025 15:26:40 -0400 Subject: [PATCH 20/24] add tests --- pyproject.toml | 2 +- .../test_execution_id_and_interceptor.py | 112 ++++++++++++++++++ uv.lock | 2 +- 3 files changed, 114 insertions(+), 2 deletions(-) create mode 100644 tests/executor/temporal/test_execution_id_and_interceptor.py diff --git a/pyproject.toml b/pyproject.toml index efa4f5ccc..2465cbe07 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "mcp-agent" -version = "0.1.16" +version = "0.1.17" description = "Build effective agents with Model Context Protocol (MCP) using simple, composable patterns." readme = "README.md" license = { file = "LICENSE" } diff --git a/tests/executor/temporal/test_execution_id_and_interceptor.py b/tests/executor/temporal/test_execution_id_and_interceptor.py new file mode 100644 index 000000000..cfee00089 --- /dev/null +++ b/tests/executor/temporal/test_execution_id_and_interceptor.py @@ -0,0 +1,112 @@ +import pytest +from unittest.mock import MagicMock, patch + + +@pytest.mark.asyncio +@patch("temporalio.workflow.info") +@patch("temporalio.workflow._Runtime.current", return_value=MagicMock()) +def test_get_execution_id_in_workflow(mock_runtime, mock_info): + from mcp_agent.executor.temporal.temporal_context import get_execution_id + + mock_info.return_value.run_id = "run-123" + assert get_execution_id() == "run-123" + + +@pytest.mark.asyncio +@patch("temporalio.activity.info") +def test_get_execution_id_in_activity(mock_act_info): + from mcp_agent.executor.temporal.temporal_context import get_execution_id + + mock_act_info.return_value.workflow_run_id = "run-aaa" + assert get_execution_id() == "run-aaa" + + +def test_interceptor_restores_prev_value(): + from mcp_agent.executor.temporal.interceptor import context_from_header + from mcp_agent.executor.temporal.temporal_context import ( + EXECUTION_ID_KEY, + set_execution_id, + get_execution_id, + ) + import temporalio.converter + + payload_converter = temporalio.converter.default().payload_converter + + class Input: + headers = {} + + set_execution_id("prev") + input = Input() + # simulate header with new value + input.headers[EXECUTION_ID_KEY] = payload_converter.to_payload("new") + + assert get_execution_id() == "prev" + with context_from_header(input, payload_converter): + # inside scope we should get header value + assert get_execution_id() == "new" + # restored + assert get_execution_id() == "prev" + + +@pytest.mark.asyncio +async def test_http_proxy_helpers_happy_and_error_paths(monkeypatch): + from mcp_agent.mcp import client_proxy + + class Resp: + def __init__(self, 
status_code, json_data=None, text=""): + self.status_code = status_code + self._json = json_data or {} + self.text = text + self.content = b"x" if json_data is not None else b"" + + def json(self): + return self._json + + class Client: + def __init__(self, rcodes_iter): + self._rcodes = rcodes_iter + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def post(self, url, json=None, headers=None): + code, body = next(self._rcodes) + if body is None: + return Resp(code) + return Resp(code, body) + + # log_via_proxy ok, then error + rcodes = iter( + [ + (200, {"ok": True}), + (500, None), + (200, {"ok": True}), + (401, None), + (200, {"ok": True}), + (400, None), + ] + ) + + monkeypatch.setattr( + client_proxy.httpx, "AsyncClient", lambda timeout: Client(rcodes) + ) + + ok = await client_proxy.log_via_proxy(None, "run", "info", "ns", "msg") + assert ok is True + ok = await client_proxy.log_via_proxy(None, "run", "info", "ns", "msg") + assert ok is False + + # notify ok, then error + ok = await client_proxy.notify_via_proxy(None, "run", "m", {}) + assert ok is True + ok = await client_proxy.notify_via_proxy(None, "run", "m", {}) + assert ok is False + + # request ok, then error + res = await client_proxy.request_via_proxy(None, "run", "m", {}) + assert isinstance(res, dict) and res.get("ok", True) in (True,) + res = await client_proxy.request_via_proxy(None, "run", "m", {}) + assert isinstance(res, dict) and "error" in res diff --git a/uv.lock b/uv.lock index 63660bdb1..09b0c7340 100644 --- a/uv.lock +++ b/uv.lock @@ -2040,7 +2040,7 @@ wheels = [ [[package]] name = "mcp-agent" -version = "0.1.16" +version = "0.1.17" source = { editable = "." } dependencies = [ { name = "aiohttp" }, From 7f26c7890514ac8015db44e406deed25d066a76d Mon Sep 17 00:00:00 2001 From: Sarmad Qadri Date: Fri, 5 Sep 2025 16:18:37 -0400 Subject: [PATCH 21/24] PR feedback and critical fixes --- src/mcp_agent/executor/temporal/__init__.py | 12 +-- .../executor/temporal/interceptor.py | 2 +- .../executor/temporal/session_proxy.py | 10 ++- src/mcp_agent/executor/workflow.py | 80 +++---------------- src/mcp_agent/logging/logger.py | 36 ++++++--- 5 files changed, 50 insertions(+), 90 deletions(-) diff --git a/src/mcp_agent/executor/temporal/__init__.py b/src/mcp_agent/executor/temporal/__init__.py index aa7af60ce..0e53bd9fa 100644 --- a/src/mcp_agent/executor/temporal/__init__.py +++ b/src/mcp_agent/executor/temporal/__init__.py @@ -504,11 +504,13 @@ async def create_temporal_worker_for_app(app: "MCPApp"): activity_registry = running_app.context.task_registry # Register system activities (logging, human input proxy, generic relays) - sys_acts = SystemActivities(context=running_app.context) - app.workflow_task(name="mcp_forward_log")(sys_acts.forward_log) - app.workflow_task(name="mcp_request_user_input")(sys_acts.request_user_input) - app.workflow_task(name="mcp_relay_notify")(sys_acts.relay_notify) - app.workflow_task(name="mcp_relay_request")(sys_acts.relay_request) + system_activities = SystemActivities(context=running_app.context) + app.workflow_task(name="mcp_forward_log")(system_activities.forward_log) + app.workflow_task(name="mcp_request_user_input")( + system_activities.request_user_input + ) + app.workflow_task(name="mcp_relay_notify")(system_activities.relay_notify) + app.workflow_task(name="mcp_relay_request")(system_activities.relay_request) for name in activity_registry.list_activities(): activities.append(activity_registry.get_activity(name)) diff 
--git a/src/mcp_agent/executor/temporal/interceptor.py b/src/mcp_agent/executor/temporal/interceptor.py index 3e7ed5d0e..c680fdaae 100644 --- a/src/mcp_agent/executor/temporal/interceptor.py +++ b/src/mcp_agent/executor/temporal/interceptor.py @@ -62,7 +62,7 @@ class ContextPropagationInterceptor( This interceptor implements methods `temporalio.client.Interceptor` and `temporalio.worker.Interceptor` so that - (1) a user ID key is taken from context by the client code and sent in a header field with outbound requests + (1) an execution ID key is taken from context by the client code and sent in a header field with outbound requests (2) workflows take this value from their task input, set it in context, and propagate it into the header field of their outbound calls (3) activities similarly take the value from their task input and set it in context so that it's available for their diff --git a/src/mcp_agent/executor/temporal/session_proxy.py b/src/mcp_agent/executor/temporal/session_proxy.py index 2f6580ec2..25e04c8f0 100644 --- a/src/mcp_agent/executor/temporal/session_proxy.py +++ b/src/mcp_agent/executor/temporal/session_proxy.py @@ -67,7 +67,7 @@ def __init__(self, *, executor, context: Context) -> None: self._executor = executor self._context = context # Local helper used when we're not inside a workflow runtime - self._sys_acts = SystemActivities(context) + self._system_activities = SystemActivities(context) # Provide a low-level RPC facade similar to real ServerSession self.rpc = _RPC(self) @@ -91,7 +91,9 @@ async def notify(self, method: str, params: Dict[str, Any] | None = None) -> boo except Exception: return False # Non-workflow (activity/asyncio) - return bool(await self._sys_acts.relay_notify(exec_id, method, params or {})) + return bool( + await self._system_activities.relay_notify(exec_id, method, params or {}) + ) async def request( self, method: str, params: Dict[str, Any] | None = None @@ -106,7 +108,9 @@ async def request( if _in_workflow_runtime(): act = self._context.task_registry.get_activity("mcp_relay_request") return await self._executor.execute(act, exec_id, method, params or {}) - return await self._sys_acts.relay_request(exec_id, method, params or {}) + return await self._system_activities.relay_request( + exec_id, method, params or {} + ) async def send_notification( self, diff --git a/src/mcp_agent/executor/workflow.py b/src/mcp_agent/executor/workflow.py index 35bcad1d5..b76c9a544 100644 --- a/src/mcp_agent/executor/workflow.py +++ b/src/mcp_agent/executor/workflow.py @@ -16,14 +16,11 @@ from pydantic import BaseModel, ConfigDict, Field from mcp_agent.core.context_dependent import ContextDependent -from mcp_agent.executor.temporal import TemporalExecutor -from mcp_agent.executor.temporal.workflow_signal import ( +from mcp_agent.executor.workflow_signal import ( Signal, SignalMailbox, ) -from mcp_agent.executor.temporal.session_proxy import SessionProxy from mcp_agent.logging.logger import get_logger -# (Temporal path now uses activities; HTTP proxy helpers unused here) if TYPE_CHECKING: from temporalio.client import WorkflowHandle @@ -256,6 +253,8 @@ async def run_async(self, *args, **kwargs) -> "WorkflowExecution": # Hint the logger with the current run_id for Temporal proxy fallback try: if self.context.config.execution_engine == "temporal": + from mcp_agent.executor.temporal.session_proxy import SessionProxy + setattr(self._logger, "_temporal_run_id", self._run_id) # Ensure upstream_session is a passthrough SessionProxy bound to this run upstream_session = 
getattr(self.context, "upstream_session", None) @@ -381,63 +380,6 @@ async def _execute_workflow(): workflow_id=self._workflow_id, ) - # Engine-aware helpers to unify upstream interactions - async def log_upstream( - self, - level: str, - namespace: str, - message: str, - data: Dict[str, Any] | None = None, - ): - if self.context.config.execution_engine == "temporal": - # Route via Temporal activity for determinism - try: - act = self.context.task_registry.get_activity("mcp_forward_log") - await self.executor.execute( - act, - self._run_id or "", - level, - namespace, - message, - data or {}, - ) - except Exception: - pass - return - # asyncio: use local logger - if level == "debug": - self._logger.debug(message, **(data or {})) - elif level == "warning": - self._logger.warning(message, **(data or {})) - elif level == "error": - self._logger.error(message, **(data or {})) - else: - self._logger.info(message, **(data or {})) - - async def ask_user( - self, prompt: str, metadata: Dict[str, Any] | None = None - ) -> Any: - if self.context.config.execution_engine == "temporal": - # Route via Temporal activity for determinism; returns request_id or error - try: - act = self.context.task_registry.get_activity("mcp_request_user_input") - return await self.executor.execute( - act, - self.context.session_id or "", - self.id or self.name, - self._run_id or "", - prompt, - (metadata or {}).get("signal_name", "human_input"), - ) - except Exception as e: - return {"error": str(e)} - handler = getattr(self.context, "human_input_handler", None) - if not handler: - return None - if asyncio.iscoroutinefunction(handler): # type: ignore[arg-type] - return await handler({"prompt": prompt, "metadata": metadata or {}}) - return handler({"prompt": prompt, "metadata": metadata or {}}) - async def resume( self, signal_name: str | None = "resume", payload: str | None = None ) -> bool: @@ -831,21 +773,21 @@ async def initialize(self): memo_map = None if isinstance(memo_map, dict): - gw = memo_map.get("gateway_url") - gt = memo_map.get("gateway_token") + gateway_url = memo_map.get("gateway_url") + gateway_token = memo_map.get("gateway_token") self._logger.debug( - f"Proxy parameters: gateway_url={gw}, gateway_token={gt}" + f"Proxy parameters: gateway_url={gateway_url}, gateway_token={gateway_token}" ) - if gw: + if gateway_url: try: - self.context.gateway_url = gw + self.context.gateway_url = gateway_url except Exception: pass - if gt: + if gateway_token: try: - self.context.gateway_token = gt + self.context.gateway_token = gateway_token except Exception: pass except Exception: @@ -855,6 +797,8 @@ async def initialize(self): # Expose a virtual upstream session (passthrough) bound to this run via activities # This lets any code use context.upstream_session like a real session. 
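+        # SessionProxy is imported inside the method (below) rather than at
+        # module top, so Temporal remains an optional dependency at import time.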
try: + from mcp_agent.executor.temporal.session_proxy import SessionProxy + upstream_session = getattr(self.context, "upstream_session", None) if upstream_session is None: diff --git a/src/mcp_agent/logging/logger.py b/src/mcp_agent/logging/logger.py index 6ef44b510..685526238 100644 --- a/src/mcp_agent/logging/logger.py +++ b/src/mcp_agent/logging/logger.py @@ -15,7 +15,6 @@ from contextlib import asynccontextmanager, contextmanager -import temporalio from mcp_agent.logging.events import ( Event, @@ -74,13 +73,24 @@ def _emit_event(self, event: Event): asyncio.create_task(self.event_bus.emit(event)) else: # If no loop is running, run it until the emit completes - if isinstance( - loop, temporalio.worker._workflow_instance._WorkflowInstanceImpl - ): - # Handle Temporal workflow environment where run_until_complete() is not implemented + # Detect Temporal workflow runtime without hard dependency + # If inside Temporal workflow sandbox, avoid run_until_complete and use workflow-safe forwarding + in_temporal_workflow = False + try: + from temporalio import workflow as _wf # type: ignore + + try: + # Detect active Temporal workflow runtime + if getattr(_wf, "_Runtime").current() is not None: # type: ignore[attr-defined] + in_temporal_workflow = True + except Exception: + in_temporal_workflow = False + except Exception: + in_temporal_workflow = False + + if in_temporal_workflow: # Prefer forwarding via the upstream session proxy using a workflow task, if available. try: - from temporalio import workflow as _wf # type: ignore from mcp_agent.executor.temporal.temporal_context import ( get_execution_id as _get_exec_id, ) @@ -166,18 +176,18 @@ async def _forward_via_proxy(): schedule_to_close_timeout=5, ) return - except Exception as _e: + except Exception: pass # If all else fails, fall back to stderr transport self.event_bus.emit_with_stderr_transport(event) - try: - _wf.create_task(_forward_via_proxy()) - return - except Exception: - # Could not create workflow task, fall through to stderr transport - pass + try: + _wf.create_task(_forward_via_proxy()) + return + except Exception: + # Could not create workflow task, fall through to stderr transport + pass except Exception: # If Temporal workflow module unavailable or any error occurs, fall through pass From 44dd1e027fb8d22f37315bcc79b4fee0120a452e Mon Sep 17 00:00:00 2001 From: Sarmad Qadri Date: Fri, 5 Sep 2025 16:41:36 -0400 Subject: [PATCH 22/24] More PR feedback, thanks AI --- src/mcp_agent/app.py | 21 +++- src/mcp_agent/executor/temporal/__init__.py | 8 +- .../executor/temporal/session_proxy.py | 2 +- .../executor/temporal/temporal_context.py | 12 +- .../executor/temporal/workflow_signal.py | 4 +- src/mcp_agent/executor/workflow.py | 17 --- src/mcp_agent/logging/logger.py | 18 ++- src/mcp_agent/mcp/client_proxy.py | 117 +++++++++++------- src/mcp_agent/tracing/token_counter.py | 2 +- .../tracing/token_tracking_decorator.py | 2 +- 10 files changed, 109 insertions(+), 94 deletions(-) diff --git a/src/mcp_agent/app.py b/src/mcp_agent/app.py index e72dc7636..b28a31f01 100644 --- a/src/mcp_agent/app.py +++ b/src/mcp_agent/app.py @@ -849,14 +849,15 @@ def decorator(target: Callable[..., R]) -> Callable[..., R]: ) if task_defn: - # prevent trying to decorate an already decorated function + # Prevent re-decoration of an already temporal-decorated function, + # but still register it with the app. 
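+                # The existing activity definition is reused as-is as the
+                # task callable, rather than being wrapped a second time.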
if hasattr(target, "__temporal_activity_definition"): self.logger.debug( - f"target {name} has __temporal_activity_definition" + "Skipping redecorate for already-temporal activity", + data={"activity_name": activity_name}, ) - return target # Already decorated with @activity - - if isinstance(target, MethodType): + task_callable = target + elif isinstance(target, MethodType): self_ref = target.__self__ @functools.wraps(func) @@ -919,7 +920,15 @@ def _register_global_workflow_tasks(self): ) if task_defn: # Engine-specific decorator available - if isinstance(target, MethodType): + # Prevent re-decoration of an already temporal-decorated function, + # but still register it with the app. + if hasattr(target, "__temporal_activity_definition"): + self.logger.debug( + "Skipping redecorate for already-temporal activity", + data={"activity_name": activity_name}, + ) + task_callable = target + elif isinstance(target, MethodType): self_ref = target.__self__ @functools.wraps(func) diff --git a/src/mcp_agent/executor/temporal/__init__.py b/src/mcp_agent/executor/temporal/__init__.py index 0e53bd9fa..3840d4944 100644 --- a/src/mcp_agent/executor/temporal/__init__.py +++ b/src/mcp_agent/executor/temporal/__init__.py @@ -122,7 +122,7 @@ async def run_task(task: Callable[..., R] | Coroutine[Any, Any, R]) -> R: return await task(*args, **kwargs) else: # Check if we're in a Temporal workflow context - if workflow._Runtime.current(): + if workflow.in_workflow(): wrapped_task = functools.partial(task, *args, **kwargs) result = wrapped_task() else: @@ -199,7 +199,7 @@ async def execute( """Execute multiple tasks (activities) in parallel.""" # Must be called from within a workflow - if not workflow._Runtime.current(): + if not workflow.in_workflow(): raise RuntimeError( "TemporalExecutor.execute must be called from within a workflow" ) @@ -217,7 +217,7 @@ async def execute_many( """Execute multiple tasks (activities) in parallel.""" # Must be called from within a workflow - if not workflow._Runtime.current(): + if not workflow.in_workflow(): raise RuntimeError( "TemporalExecutor.execute must be called from within a workflow" ) @@ -235,7 +235,7 @@ async def execute_streaming( *args, **kwargs, ) -> AsyncIterator[R | BaseException]: - if not workflow._Runtime.current(): + if not workflow.in_workflow(): raise RuntimeError( "TemporalExecutor.execute_streaming must be called from within a workflow" ) diff --git a/src/mcp_agent/executor/temporal/session_proxy.py b/src/mcp_agent/executor/temporal/session_proxy.py index 25e04c8f0..b0fa213a4 100644 --- a/src/mcp_agent/executor/temporal/session_proxy.py +++ b/src/mcp_agent/executor/temporal/session_proxy.py @@ -289,7 +289,7 @@ async def elicit( def _in_workflow_runtime() -> bool: """Return True if currently executing inside a Temporal workflow sandbox.""" try: - return _twf._Runtime.current() is not None # type: ignore[attr-defined] + return _twf.in_workflow() except Exception: return False diff --git a/src/mcp_agent/executor/temporal/temporal_context.py b/src/mcp_agent/executor/temporal/temporal_context.py index 896df214d..fa1cbf49b 100644 --- a/src/mcp_agent/executor/temporal/temporal_context.py +++ b/src/mcp_agent/executor/temporal/temporal_context.py @@ -18,15 +18,15 @@ def get_execution_id() -> Optional[str]: Priority: - If inside a Temporal workflow, return workflow.info().run_id - Else if inside a Temporal activity, return activity.info().workflow_run_id - - Else fall back to the process-scoped ContextVar (best-effort) + - Else fall back to the global (best-effort) 
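+
+    Outside any Temporal runtime this simply returns whatever
+    set_execution_id() last stored in the module-global fallback.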
""" # Try workflow runtime first try: - from temporalio import workflow as _wf # type: ignore + from temporalio import workflow # type: ignore try: - if getattr(_wf, "_Runtime").current() is not None: # type: ignore[attr-defined] - return _wf.info().run_id + if workflow.in_workflow(): + return workflow.info().run_id except Exception: pass except Exception: @@ -34,10 +34,10 @@ def get_execution_id() -> Optional[str]: # Then try activity runtime try: - from temporalio import activity as _act # type: ignore + from temporalio import activity # type: ignore try: - info = _act.info() + info = activity.info() if info is not None and getattr(info, "workflow_run_id", None): return info.workflow_run_id except Exception: diff --git a/src/mcp_agent/executor/temporal/workflow_signal.py b/src/mcp_agent/executor/temporal/workflow_signal.py index be4fb6801..30f6f4835 100644 --- a/src/mcp_agent/executor/temporal/workflow_signal.py +++ b/src/mcp_agent/executor/temporal/workflow_signal.py @@ -91,7 +91,7 @@ async def wait_for_signal( TimeoutError: If timeout is reached ValueError: If no value exists for the signal after waiting """ - if not workflow._Runtime.current(): + if not workflow.in_workflow(): raise RuntimeError("wait_for_signal must be called from within a workflow") # Get the mailbox safely from ContextVar @@ -156,7 +156,7 @@ async def signal(self, signal: Signal[SignalValueT]) -> None: # Validate the signal (already checks workflow_id is not None) self.validate_signal(signal) - if workflow._Runtime.current() is not None: + if workflow.in_workflow(): workflow_info = workflow.info() if ( signal.workflow_id == workflow_info.workflow_id diff --git a/src/mcp_agent/executor/workflow.py b/src/mcp_agent/executor/workflow.py index b76c9a544..7e0eed92d 100644 --- a/src/mcp_agent/executor/workflow.py +++ b/src/mcp_agent/executor/workflow.py @@ -250,23 +250,6 @@ async def run_async(self, *args, **kwargs) -> "WorkflowExecution": f"Workflow started with workflow ID: {self._workflow_id}, run ID: {self._run_id}" ) - # Hint the logger with the current run_id for Temporal proxy fallback - try: - if self.context.config.execution_engine == "temporal": - from mcp_agent.executor.temporal.session_proxy import SessionProxy - - setattr(self._logger, "_temporal_run_id", self._run_id) - # Ensure upstream_session is a passthrough SessionProxy bound to this run - upstream_session = getattr(self.context, "upstream_session", None) - - if upstream_session is None: - self.context.upstream_session = SessionProxy( - executor=self.executor, - context=self.context, - ) - except Exception: - pass - # Define the workflow execution function async def _execute_workflow(): try: diff --git a/src/mcp_agent/logging/logger.py b/src/mcp_agent/logging/logger.py index 685526238..5c70851ac 100644 --- a/src/mcp_agent/logging/logger.py +++ b/src/mcp_agent/logging/logger.py @@ -80,9 +80,7 @@ def _emit_event(self, event: Event): from temporalio import workflow as _wf # type: ignore try: - # Detect active Temporal workflow runtime - if getattr(_wf, "_Runtime").current() is not None: # type: ignore[attr-defined] - in_temporal_workflow = True + in_temporal_workflow = bool(_wf.in_workflow()) except Exception: in_temporal_workflow = False except Exception: @@ -234,14 +232,14 @@ def event( except Exception: pass # Fallback to default bound context if logger wasn't explicitly bound - if "upstream_session" not in extra_event_fields: + if ( + "upstream_session" not in extra_event_fields + and _default_bound_context is not None + ): try: - from 
mcp_agent.logging.logger import _default_bound_context as _dbc # type: ignore - - if _dbc is not None: - _up = getattr(_dbc, "upstream_session", None) - if _up is not None: - extra_event_fields["upstream_session"] = _up + upstream = getattr(_default_bound_context, "upstream_session", None) + if upstream is not None: + extra_event_fields["upstream_session"] = upstream except Exception: pass diff --git a/src/mcp_agent/mcp/client_proxy.py b/src/mcp_agent/mcp/client_proxy.py index e07db0d80..af7d8f34b 100644 --- a/src/mcp_agent/mcp/client_proxy.py +++ b/src/mcp_agent/mcp/client_proxy.py @@ -4,6 +4,7 @@ import httpx from mcp_agent.mcp.mcp_server_registry import ServerRegistry +from urllib.parse import quote def _resolve_gateway_url( @@ -22,7 +23,7 @@ def _resolve_gateway_url( # Next: a registry entry (if provided) if server_registry and server_name: - cfg = server_registry.get_server_context(server_name) + cfg = server_registry.get_server_config(server_name) if cfg and getattr(cfg, "url", None): return cfg.url.rstrip("/") @@ -49,22 +50,28 @@ async def log_via_proxy( if tok: headers["X-MCP-Gateway-Token"] = tok timeout = float(os.environ.get("MCP_GATEWAY_TIMEOUT", "10")) - async with httpx.AsyncClient(timeout=timeout) as client: - r = await client.post( - url, - json={ - "execution_id": execution_id, - "level": level, - "namespace": namespace, - "message": message, - "data": data or {}, - }, - headers=headers, - ) - if r.status_code >= 400: - return False - resp = r.json() - return bool(resp.get("ok", False)) + try: + async with httpx.AsyncClient(timeout=timeout) as client: + r = await client.post( + url, + json={ + "execution_id": execution_id, + "level": level, + "namespace": namespace, + "message": message, + "data": data or {}, + }, + headers=headers, + ) + except httpx.RequestError: + return False + if r.status_code >= 400: + return False + try: + resp = r.json() if r.content else {"ok": True} + except ValueError: + resp = {"ok": True} + return bool(resp.get("ok", True)) async def ask_via_proxy( @@ -84,19 +91,25 @@ async def ask_via_proxy( if tok: headers["X-MCP-Gateway-Token"] = tok timeout = float(os.environ.get("MCP_GATEWAY_TIMEOUT", "10")) - async with httpx.AsyncClient(timeout=timeout) as client: - r = await client.post( - url, - json={ - "execution_id": execution_id, - "prompt": {"text": prompt}, - "metadata": metadata or {}, - }, - headers=headers, - ) - if r.status_code >= 400: - return {"error": r.text} - return r.json() + try: + async with httpx.AsyncClient(timeout=timeout) as client: + r = await client.post( + url, + json={ + "execution_id": execution_id, + "prompt": {"text": prompt}, + "metadata": metadata or {}, + }, + headers=headers, + ) + except httpx.RequestError: + return {"error": "request_failed"} + if r.status_code >= 400: + return {"error": r.text} + try: + return r.json() if r.content else {"error": "invalid_response"} + except ValueError: + return {"error": "invalid_response"} async def notify_via_proxy( @@ -110,21 +123,27 @@ async def notify_via_proxy( gateway_token: Optional[str] = None, ) -> bool: base = _resolve_gateway_url(server_registry, server_name, gateway_url) - url = f"{base}/internal/session/by-run/{execution_id}/notify" + url = f"{base}/internal/session/by-run/{quote(execution_id, safe='')}/notify" headers: Dict[str, str] = {} tok = gateway_token or os.environ.get("MCP_GATEWAY_TOKEN") if tok: headers["X-MCP-Gateway-Token"] = tok timeout = float(os.environ.get("MCP_GATEWAY_TIMEOUT", "10")) - async with httpx.AsyncClient(timeout=timeout) as client: - r = 
await client.post( - url, json={"method": method, "params": params or {}}, headers=headers - ) - if r.status_code >= 400: - return False + try: + async with httpx.AsyncClient(timeout=timeout) as client: + r = await client.post( + url, json={"method": method, "params": params or {}}, headers=headers + ) + except httpx.RequestError: + return False + if r.status_code >= 400: + return False + try: resp = r.json() if r.content else {"ok": True} - return bool(resp.get("ok", True)) + except ValueError: + resp = {"ok": True} + return bool(resp.get("ok", True)) async def request_via_proxy( @@ -138,16 +157,22 @@ async def request_via_proxy( gateway_token: Optional[str] = None, ) -> Dict[str, Any]: base = _resolve_gateway_url(server_registry, server_name, gateway_url) - url = f"{base}/internal/session/by-run/{execution_id}/request" + url = f"{base}/internal/session/by-run/{quote(execution_id, safe='')}/request" headers: Dict[str, str] = {} tok = gateway_token or os.environ.get("MCP_GATEWAY_TOKEN") if tok: headers["X-MCP-Gateway-Token"] = tok timeout = float(os.environ.get("MCP_GATEWAY_TIMEOUT", "20")) - async with httpx.AsyncClient(timeout=timeout) as client: - r = await client.post( - url, json={"method": method, "params": params or {}}, headers=headers - ) - if r.status_code >= 400: - return {"error": r.text} - return r.json() + try: + async with httpx.AsyncClient(timeout=timeout) as client: + r = await client.post( + url, json={"method": method, "params": params or {}}, headers=headers + ) + except httpx.RequestError: + return {"error": "request_failed"} + if r.status_code >= 400: + return {"error": r.text} + try: + return r.json() if r.content else {"error": "invalid_response"} + except ValueError: + return {"error": "invalid_response"} diff --git a/src/mcp_agent/tracing/token_counter.py b/src/mcp_agent/tracing/token_counter.py index ec794a746..4a1bb3728 100644 --- a/src/mcp_agent/tracing/token_counter.py +++ b/src/mcp_agent/tracing/token_counter.py @@ -827,7 +827,7 @@ async def record_usage( try: from temporalio import workflow as _twf # type: ignore - if _twf._Runtime.current(): # type: ignore[attr-defined] + if _twf.in_workflow(): if _twf.unsafe.is_replaying(): # type: ignore[attr-defined] return except Exception: diff --git a/src/mcp_agent/tracing/token_tracking_decorator.py b/src/mcp_agent/tracing/token_tracking_decorator.py index 491050e41..35833397e 100644 --- a/src/mcp_agent/tracing/token_tracking_decorator.py +++ b/src/mcp_agent/tracing/token_tracking_decorator.py @@ -38,7 +38,7 @@ async def wrapper(self, *args, **kwargs) -> T: try: from temporalio import workflow as _twf # type: ignore - if _twf._Runtime.current(): # type: ignore[attr-defined] + if _twf.in_workflow(): is_temporal_replay = _twf.unsafe.is_replaying() # type: ignore[attr-defined] except Exception: is_temporal_replay = False From 766de68d2b27225a9a52e5894709e2c71dd58af9 Mon Sep 17 00:00:00 2001 From: Sarmad Qadri Date: Fri, 5 Sep 2025 16:48:47 -0400 Subject: [PATCH 23/24] fix tests --- .../temporal/test_execution_id_and_interceptor.py | 6 +++--- tests/executor/temporal/test_signal_handler.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/executor/temporal/test_execution_id_and_interceptor.py b/tests/executor/temporal/test_execution_id_and_interceptor.py index cfee00089..7aa5f5cb5 100644 --- a/tests/executor/temporal/test_execution_id_and_interceptor.py +++ b/tests/executor/temporal/test_execution_id_and_interceptor.py @@ -1,11 +1,11 @@ import pytest -from unittest.mock import MagicMock, 
patch +from unittest.mock import patch @pytest.mark.asyncio @patch("temporalio.workflow.info") -@patch("temporalio.workflow._Runtime.current", return_value=MagicMock()) -def test_get_execution_id_in_workflow(mock_runtime, mock_info): +@patch("temporalio.workflow.in_workflow", return_value=True) +def test_get_execution_id_in_workflow(_mock_in_wf, mock_info): from mcp_agent.executor.temporal.temporal_context import get_execution_id mock_info.return_value.run_id = "run-123" diff --git a/tests/executor/temporal/test_signal_handler.py b/tests/executor/temporal/test_signal_handler.py index 2b898bc30..6459aea09 100644 --- a/tests/executor/temporal/test_signal_handler.py +++ b/tests/executor/temporal/test_signal_handler.py @@ -53,8 +53,8 @@ def test_attach_to_workflow(handler, mock_workflow): @pytest.mark.asyncio -@patch("temporalio.workflow._Runtime.current", return_value=MagicMock()) -async def test_wait_for_signal(mock_runtime, handler, mock_workflow): +@patch("temporalio.workflow.in_workflow", return_value=True) +async def test_wait_for_signal(_mock_in_wf, handler, mock_workflow): handler.attach_to_workflow(mock_workflow) # Patch the handler's ContextVar to point to the mock_workflow's mailbox handler._mailbox_ref.set(mock_workflow._signal_mailbox) @@ -66,7 +66,7 @@ async def test_wait_for_signal(mock_runtime, handler, mock_workflow): @pytest.mark.asyncio -@patch("temporalio.workflow._Runtime.current", return_value=None) +@patch("temporalio.workflow.in_workflow", return_value=False) @patch( "temporalio.workflow.get_external_workflow_handle", side_effect=__import__("temporalio.workflow").workflow._NotInWorkflowEventLoopError( @@ -74,7 +74,7 @@ async def test_wait_for_signal(mock_runtime, handler, mock_workflow): ), ) async def test_signal_outside_workflow( - mock_get_external, mock_runtime, handler, mock_executor + mock_get_external, _mock_in_wf, handler, mock_executor ): signal = Signal( name="test_signal", From 7b00d957ea72024ae4b519c5ccaf53bd93333192 Mon Sep 17 00:00:00 2001 From: Sarmad Qadri Date: Fri, 5 Sep 2025 16:58:44 -0400 Subject: [PATCH 24/24] More fixes --- src/mcp_agent/logging/logger.py | 3 ++- src/mcp_agent/server/app_server.py | 13 ++++++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/mcp_agent/logging/logger.py b/src/mcp_agent/logging/logger.py index 5c70851ac..194c0cc45 100644 --- a/src/mcp_agent/logging/logger.py +++ b/src/mcp_agent/logging/logger.py @@ -8,6 +8,7 @@ """ import asyncio +from datetime import timedelta import threading import time @@ -171,7 +172,7 @@ async def _forward_via_proxy(): ns, msg, data, - schedule_to_close_timeout=5, + schedule_to_close_timeout=timedelta(seconds=5), ) return except Exception: diff --git a/src/mcp_agent/server/app_server.py b/src/mcp_agent/server/app_server.py index 3ac614276..9d0a38433 100644 --- a/src/mcp_agent/server/app_server.py +++ b/src/mcp_agent/server/app_server.py @@ -8,6 +8,7 @@ from contextlib import asynccontextmanager from typing import Any, Dict, List, Optional, Set, Tuple, Type, TYPE_CHECKING import os +import secrets import asyncio from mcp.server.fastmcp import Context as MCPContext, FastMCP @@ -334,7 +335,9 @@ async def _relay_notify(request: Request): # Optional shared-secret auth gw_token = os.environ.get("MCP_GATEWAY_TOKEN") - if gw_token and request.headers.get("X-MCP-Gateway-Token") != gw_token: + if gw_token and not secrets.compare_digest( + request.headers.get("X-MCP-Gateway-Token", ""), gw_token + ): return JSONResponse( {"ok": False, "error": "unauthorized"}, status_code=401 ) @@ 
-496,7 +499,9 @@ async def _internal_workflows_log(request: Request): # Optional shared-secret auth gw_token = os.environ.get("MCP_GATEWAY_TOKEN") - if gw_token and request.headers.get("X-MCP-Gateway-Token") != gw_token: + if gw_token and not secrets.compare_digest( + request.headers.get("X-MCP-Gateway-Token", ""), gw_token + ): return JSONResponse( {"ok": False, "error": "unauthorized"}, status_code=401 ) @@ -533,7 +538,9 @@ async def _internal_human_prompts(request: Request): # Optional shared-secret auth gw_token = os.environ.get("MCP_GATEWAY_TOKEN") - if gw_token and request.headers.get("X-MCP-Gateway-Token") != gw_token: + if gw_token and not secrets.compare_digest( + request.headers.get("X-MCP-Gateway-Token", ""), gw_token + ): return JSONResponse({"error": "unauthorized"}, status_code=401) session = await _get_session(execution_id)
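
A note on the client_proxy.py hunks earlier in this series: every proxy call
is now fallible at three layers (transport, HTTP status, body parse) without
raising. A condensed sketch of that shape, with a hypothetical post_json
helper standing in for the four real functions:

    import httpx

    async def post_json(
        url: str, payload: dict, headers: dict, timeout: float
    ) -> dict:
        # Transport failures (DNS errors, refused connections, timeouts)
        # degrade to an error dict instead of propagating an exception.
        try:
            async with httpx.AsyncClient(timeout=timeout) as client:
                r = await client.post(url, json=payload, headers=headers)
        except httpx.RequestError:
            return {"error": "request_failed"}
        if r.status_code >= 400:
            return {"error": r.text}
        # Empty and non-JSON bodies are tolerated as well; httpx raises a
        # ValueError subclass (json.JSONDecodeError) on malformed JSON.
        try:
            return r.json() if r.content else {"error": "invalid_response"}
        except ValueError:
            return {"error": "invalid_response"}

The quote(execution_id, safe="") change in the same file is related
hardening: an execution id containing "/", "?" or "#" can no longer rewrite
the /internal/session/by-run/... path.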
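
The Temporal guards in the tracing modules swap the private
_twf._Runtime.current() for the public temporalio.workflow.in_workflow()
predicate, which is also what the updated tests patch. A minimal sketch of
the replay guard both modules now share (is_replaying_workflow is a
hypothetical name, not part of the patch):

    from temporalio import workflow

    def is_replaying_workflow() -> bool:
        # in_workflow() is simply False outside a Temporal workflow, so the
        # unsafe.is_replaying() call is only reached inside the sandbox.
        try:
            return workflow.in_workflow() and workflow.unsafe.is_replaying()
        except Exception:
            return False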
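
On the app_server.py auth hunks above: the header value is defaulted to ""
because secrets.compare_digest requires two str (or two bytes) arguments and
raises TypeError on None. A standalone sketch of the guard
(check_gateway_token is a hypothetical helper; the handlers inline it):

    import os
    import secrets

    def check_gateway_token(header_value: str | None) -> bool:
        gw_token = os.environ.get("MCP_GATEWAY_TOKEN")
        if not gw_token:
            # No configured token keeps auth disabled, matching the
            # `if gw_token and ...` guard in the three handlers.
            return True
        # Unlike `!=`, compare_digest does not short-circuit on the first
        # mismatched byte, so response timing no longer reveals how much
        # of the token a caller has guessed correctly.
        return secrets.compare_digest(header_value or "", gw_token)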