Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 13 additions & 15 deletions examples/human_input/temporal/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from mcp_agent.mcp.mcp_agent_client_session import MCPAgentClientSession
from mcp.types import CallToolResult, LoggingMessageNotificationParams
from mcp_agent.human_input.console_handler import console_input_callback

try:
from exceptiongroup import ExceptionGroup as _ExceptionGroup # Python 3.10 backport
except Exception: # pragma: no cover
Expand Down Expand Up @@ -119,10 +120,10 @@ async def _received_notification(self, notification): # type: ignore[override]
return await super()._received_notification(notification)

def make_session(
read_stream: MemoryObjectReceiveStream,
write_stream: MemoryObjectSendStream,
read_timeout_seconds: timedelta | None,
context: Context | None = None,
read_stream: MemoryObjectReceiveStream,
write_stream: MemoryObjectSendStream,
read_timeout_seconds: timedelta | None,
context: Context | None = None,
) -> ClientSession:
return ConsolePrintingClientSession(
read_stream=read_stream,
Expand All @@ -134,9 +135,9 @@ def make_session(

# Connect to the workflow server
async with gen_client(
"basic_agent_server",
context.server_registry,
client_session_factory=make_session,
"basic_agent_server",
context.server_registry,
client_session_factory=make_session,
) as server:
# Ask server to send logs at the requested level (default info)
level = "info"
Expand All @@ -148,25 +149,22 @@ def make_session(
print("[client] Server does not support logging/setLevel")

# Call the `greet` tool defined via `@app.tool`
run_result = await server.call_tool(
"greet",
arguments={}
)
run_result = await server.call_tool("greet", arguments={})
print(f"[client] Workflow run result: {run_result}")
except Exception as e:
# Tolerate benign shutdown races from SSE client (BrokenResourceError within ExceptionGroup)
if _ExceptionGroup is not None and isinstance(e, _ExceptionGroup):
subs = getattr(e, "exceptions", []) or []
if (
_BrokenResourceError is not None
and subs
and all(isinstance(se, _BrokenResourceError) for se in subs)
_BrokenResourceError is not None
and subs
and all(isinstance(se, _BrokenResourceError) for se in subs)
):
logger.debug("Ignored BrokenResourceError from SSE shutdown")
else:
raise
elif _BrokenResourceError is not None and isinstance(
e, _BrokenResourceError
e, _BrokenResourceError
):
logger.debug("Ignored BrokenResourceError from SSE shutdown")
elif "BrokenResourceError" in str(e):
Expand Down
1 change: 1 addition & 0 deletions examples/human_input/temporal/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
when running in Temporal workflows by routing requests through the MCP
elicitation framework instead of direct console I/O.
"""

import asyncio
from mcp_agent.app import MCPApp
from mcp_agent.human_input.elicitation_handler import elicitation_input_callback
Expand Down
45 changes: 18 additions & 27 deletions examples/mcp/mcp_elicitation/temporal/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,10 @@ async def _received_notification(self, notification): # type: ignore[override]
return await super()._received_notification(notification)

def make_session(
read_stream: MemoryObjectReceiveStream,
write_stream: MemoryObjectSendStream,
read_timeout_seconds: timedelta | None,
context: Context | None = None,
read_stream: MemoryObjectReceiveStream,
write_stream: MemoryObjectSendStream,
read_timeout_seconds: timedelta | None,
context: Context | None = None,
) -> ClientSession:
return ConsolePrintingClientSession(
read_stream=read_stream,
Expand All @@ -138,9 +138,9 @@ def make_session(

# Connect to the workflow server
async with gen_client(
"basic_agent_server",
context.server_registry,
client_session_factory=make_session,
"basic_agent_server",
context.server_registry,
client_session_factory=make_session,
) as server:
# Ask server to send logs at the requested level (default info)
level = "info"
Expand All @@ -154,42 +154,33 @@ def make_session(
# Call the `book_table` tool defined via `@app.tool`
run_result = await server.call_tool(
"book_table",
arguments={
"date": "today",
"party_size": 2,
"topic": "autumn"
},
arguments={"date": "today", "party_size": 2, "topic": "autumn"},
)
print(f"[client] Workflow run result: {run_result}")

# Run the `TestWorkflow` workflow...
run_result = await server.call_tool(
"workflows-TestWorkflow-run",
arguments={
"run_parameters":{
"args":{
"run_parameters": {
"args": {
"date": "today",
"party_size": 2,
"topic": "autumn"
"topic": "autumn",
}
}
}
},
)

execution = WorkflowExecution(
**json.loads(run_result.content[0].text)
)
execution = WorkflowExecution(**json.loads(run_result.content[0].text))
run_id = execution.run_id
workflow_id = execution.workflow_id

# and wait for execution to complete
while True:
get_status_result = await server.call_tool(
"workflows-get_status",
arguments={
"run_id": run_id,
"workflow_id": workflow_id
},
arguments={"run_id": run_id, "workflow_id": workflow_id},
)

workflow_status = _tool_result_to_json(get_status_result)
Expand Down Expand Up @@ -248,15 +239,15 @@ def make_session(
if _ExceptionGroup is not None and isinstance(e, _ExceptionGroup):
subs = getattr(e, "exceptions", []) or []
if (
_BrokenResourceError is not None
and subs
and all(isinstance(se, _BrokenResourceError) for se in subs)
_BrokenResourceError is not None
and subs
and all(isinstance(se, _BrokenResourceError) for se in subs)
):
logger.debug("Ignored BrokenResourceError from SSE shutdown")
else:
raise
elif _BrokenResourceError is not None and isinstance(
e, _BrokenResourceError
e, _BrokenResourceError
):
logger.debug("Ignored BrokenResourceError from SSE shutdown")
elif "BrokenResourceError" in str(e):
Expand Down
6 changes: 1 addition & 5 deletions examples/mcp/mcp_elicitation/temporal/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,7 @@
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = MCPApp(
name="elicitation_demo",
description="Demo of workflow with elicitation"
)
app = MCPApp(name="elicitation_demo", description="Demo of workflow with elicitation")


@app.tool()
Expand Down Expand Up @@ -61,7 +58,6 @@ class ConfirmBooking(BaseModel):

@app.workflow
class TestWorkflow(Workflow[str]):

@app.workflow_run
async def run(self, args: Dict[str, Any]) -> WorkflowResult[str]:
app_ctx = app.context
Expand Down
4 changes: 2 additions & 2 deletions src/mcp_agent/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,9 +707,9 @@ async def _run(self, *args, **kwargs): # type: ignore[no-redef]
# decorate the run method with the engine-specific run decorator.
if engine_type == "temporal":
try:
run_decorator = (self._decorator_registry.get_workflow_run_decorator(
run_decorator = self._decorator_registry.get_workflow_run_decorator(
engine_type
))
)
if run_decorator:
fn_run = getattr(auto_cls, "run")
# Ensure method appears as top-level for Temporal
Expand Down
12 changes: 6 additions & 6 deletions src/mcp_agent/executor/temporal/session_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ async def notify(self, method: str, params: Dict[str, Any] | None = None) -> boo
return True

async def request(
self, method: str, params: Dict[str, Any] | None = None
self, method: str, params: Dict[str, Any] | None = None
) -> Dict[str, Any]:
"""Send a server->client request and return the client's response.
The result is a plain JSON-serializable dict.
Expand Down Expand Up @@ -322,10 +322,10 @@ async def create_message(
raise RuntimeError(f"sampling/createMessage returned invalid result: {e}")

async def elicit(
self,
message: str,
requestedSchema: types.ElicitRequestedSchema,
related_request_id: types.RequestId | None = None,
self,
message: str,
requestedSchema: types.ElicitRequestedSchema,
related_request_id: types.RequestId | None = None,
) -> types.ElicitResult:
params: Dict[str, Any] = {
"message": message,
Expand Down Expand Up @@ -358,6 +358,6 @@ async def notify(self, method: str, params: Dict[str, Any] | None = None) -> Non
await self._proxy.notify(method, params or {})

async def request(
self, method: str, params: Dict[str, Any] | None = None
self, method: str, params: Dict[str, Any] | None = None
) -> Dict[str, Any]:
return await self._proxy.request(method, params or {})
6 changes: 5 additions & 1 deletion src/mcp_agent/executor/temporal/system_activities.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,11 @@ async def relay_notify(

@activity.defn(name="mcp_relay_request")
async def relay_request(
self, make_async_call: bool, execution_id: str, method: str, params: Dict[str, Any] | None = None
self,
make_async_call: bool,
execution_id: str,
method: str,
params: Dict[str, Any] | None = None,
) -> Dict[str, Any]:
gateway_url = getattr(self.context, "gateway_url", None)
gateway_token = getattr(self.context, "gateway_token", None)
Expand Down
27 changes: 14 additions & 13 deletions src/mcp_agent/executor/workflow.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import sys

from abc import ABC, abstractmethod
from datetime import datetime, timezone
Expand All @@ -9,6 +8,7 @@
Generic,
Literal,
Optional,
Sequence,
TypeVar,
TYPE_CHECKING,
)
Expand All @@ -27,6 +27,13 @@
from mcp_agent.core.context import Context
from mcp_agent.executor.temporal import TemporalExecutor

try:
from temporalio import workflow as temporal_workflow
from temporalio.common import RawValue
except ImportError: # Temporal not installed or available in this environment
temporal_workflow = None # type: ignore[assignment]
RawValue = None # type: ignore[assignment]

T = TypeVar("T")


Expand Down Expand Up @@ -423,17 +430,11 @@ async def cancel(self) -> bool:
self._logger.error(f"Error cancelling workflow {self._run_id}: {e}")
return False

# Add the dynamic signal handler method in the case that the workflow is running under Temporal
if "temporalio.workflow" in sys.modules:
from temporalio import workflow
from temporalio.common import RawValue
from typing import Sequence
if temporal_workflow is not None:

@workflow.signal(dynamic=True)
@temporal_workflow.signal(dynamic=True)
async def _signal_receiver(self, name: str, args: Sequence[RawValue]):
"""Dynamic signal handler for Temporal workflows."""
from temporalio import workflow

self._logger.debug(f"Dynamic signal received: name={name}, args={args}")

# Extract payload and update mailbox
Expand All @@ -450,8 +451,8 @@ async def _signal_receiver(self, name: str, args: Sequence[RawValue]):
sig_obj = Signal(
name=name,
payload=payload,
workflow_id=workflow.info().workflow_id,
run_id=workflow.info().run_id,
workflow_id=temporal_workflow.info().workflow_id,
run_id=temporal_workflow.info().run_id,
)

# Live lookup of handlers (enables callbacks added after attach_to_workflow)
Expand All @@ -461,7 +462,7 @@ async def _signal_receiver(self, name: str, args: Sequence[RawValue]):
else:
cb(sig_obj)

@workflow.query(name="token_tree")
@temporal_workflow.query(name="token_tree")
def _query_token_tree(self) -> str:
"""Return a best-effort token usage tree string from the workflow process.

Expand All @@ -481,7 +482,7 @@ def _query_token_tree(self) -> str:
except Exception:
return "(no token usage)"

@workflow.query(name="token_summary")
@temporal_workflow.query(name="token_summary")
def _query_token_summary(self) -> Dict[str, Any]:
"""Return a JSON-serializable token usage summary from the workflow process.

Expand Down
Loading
Loading