Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 13 additions & 15 deletions examples/human_input/temporal/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from mcp_agent.mcp.mcp_agent_client_session import MCPAgentClientSession
from mcp.types import CallToolResult, LoggingMessageNotificationParams
from mcp_agent.human_input.console_handler import console_input_callback

try:
from exceptiongroup import ExceptionGroup as _ExceptionGroup # Python 3.10 backport
except Exception: # pragma: no cover
Expand Down Expand Up @@ -119,10 +120,10 @@ async def _received_notification(self, notification): # type: ignore[override]
return await super()._received_notification(notification)

def make_session(
read_stream: MemoryObjectReceiveStream,
write_stream: MemoryObjectSendStream,
read_timeout_seconds: timedelta | None,
context: Context | None = None,
read_stream: MemoryObjectReceiveStream,
write_stream: MemoryObjectSendStream,
read_timeout_seconds: timedelta | None,
context: Context | None = None,
) -> ClientSession:
return ConsolePrintingClientSession(
read_stream=read_stream,
Expand All @@ -134,9 +135,9 @@ def make_session(

# Connect to the workflow server
async with gen_client(
"basic_agent_server",
context.server_registry,
client_session_factory=make_session,
"basic_agent_server",
context.server_registry,
client_session_factory=make_session,
) as server:
# Ask server to send logs at the requested level (default info)
level = "info"
Expand All @@ -148,25 +149,22 @@ def make_session(
print("[client] Server does not support logging/setLevel")

# Call the `greet` tool defined via `@app.tool`
run_result = await server.call_tool(
"greet",
arguments={}
)
run_result = await server.call_tool("greet", arguments={})
print(f"[client] Workflow run result: {run_result}")
except Exception as e:
# Tolerate benign shutdown races from SSE client (BrokenResourceError within ExceptionGroup)
if _ExceptionGroup is not None and isinstance(e, _ExceptionGroup):
subs = getattr(e, "exceptions", []) or []
if (
_BrokenResourceError is not None
and subs
and all(isinstance(se, _BrokenResourceError) for se in subs)
_BrokenResourceError is not None
and subs
and all(isinstance(se, _BrokenResourceError) for se in subs)
):
logger.debug("Ignored BrokenResourceError from SSE shutdown")
else:
raise
elif _BrokenResourceError is not None and isinstance(
e, _BrokenResourceError
e, _BrokenResourceError
):
logger.debug("Ignored BrokenResourceError from SSE shutdown")
elif "BrokenResourceError" in str(e):
Expand Down
1 change: 1 addition & 0 deletions examples/human_input/temporal/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
when running in Temporal workflows by routing requests through the MCP
elicitation framework instead of direct console I/O.
"""

import asyncio
from mcp_agent.app import MCPApp
from mcp_agent.human_input.elicitation_handler import elicitation_input_callback
Expand Down
45 changes: 18 additions & 27 deletions examples/mcp/mcp_elicitation/temporal/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,10 @@ async def _received_notification(self, notification): # type: ignore[override]
return await super()._received_notification(notification)

def make_session(
read_stream: MemoryObjectReceiveStream,
write_stream: MemoryObjectSendStream,
read_timeout_seconds: timedelta | None,
context: Context | None = None,
read_stream: MemoryObjectReceiveStream,
write_stream: MemoryObjectSendStream,
read_timeout_seconds: timedelta | None,
context: Context | None = None,
) -> ClientSession:
return ConsolePrintingClientSession(
read_stream=read_stream,
Expand All @@ -138,9 +138,9 @@ def make_session(

# Connect to the workflow server
async with gen_client(
"basic_agent_server",
context.server_registry,
client_session_factory=make_session,
"basic_agent_server",
context.server_registry,
client_session_factory=make_session,
) as server:
# Ask server to send logs at the requested level (default info)
level = "info"
Expand All @@ -154,42 +154,33 @@ def make_session(
# Call the `book_table` tool defined via `@app.tool`
run_result = await server.call_tool(
"book_table",
arguments={
"date": "today",
"party_size": 2,
"topic": "autumn"
},
arguments={"date": "today", "party_size": 2, "topic": "autumn"},
)
print(f"[client] Workflow run result: {run_result}")

# Run the `TestWorkflow` workflow...
run_result = await server.call_tool(
"workflows-TestWorkflow-run",
arguments={
"run_parameters":{
"args":{
"run_parameters": {
"args": {
"date": "today",
"party_size": 2,
"topic": "autumn"
"topic": "autumn",
}
}
}
},
)

execution = WorkflowExecution(
**json.loads(run_result.content[0].text)
)
execution = WorkflowExecution(**json.loads(run_result.content[0].text))
run_id = execution.run_id
workflow_id = execution.workflow_id

# and wait for execution to complete
while True:
get_status_result = await server.call_tool(
"workflows-get_status",
arguments={
"run_id": run_id,
"workflow_id": workflow_id
},
arguments={"run_id": run_id, "workflow_id": workflow_id},
)

workflow_status = _tool_result_to_json(get_status_result)
Expand Down Expand Up @@ -248,15 +239,15 @@ def make_session(
if _ExceptionGroup is not None and isinstance(e, _ExceptionGroup):
subs = getattr(e, "exceptions", []) or []
if (
_BrokenResourceError is not None
and subs
and all(isinstance(se, _BrokenResourceError) for se in subs)
_BrokenResourceError is not None
and subs
and all(isinstance(se, _BrokenResourceError) for se in subs)
):
logger.debug("Ignored BrokenResourceError from SSE shutdown")
else:
raise
elif _BrokenResourceError is not None and isinstance(
e, _BrokenResourceError
e, _BrokenResourceError
):
logger.debug("Ignored BrokenResourceError from SSE shutdown")
elif "BrokenResourceError" in str(e):
Expand Down
6 changes: 1 addition & 5 deletions examples/mcp/mcp_elicitation/temporal/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,7 @@
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = MCPApp(
name="elicitation_demo",
description="Demo of workflow with elicitation"
)
app = MCPApp(name="elicitation_demo", description="Demo of workflow with elicitation")


@app.tool()
Expand Down Expand Up @@ -61,7 +58,6 @@ class ConfirmBooking(BaseModel):

@app.workflow
class TestWorkflow(Workflow[str]):

@app.workflow_run
async def run(self, args: Dict[str, Any]) -> WorkflowResult[str]:
app_ctx = app.context
Expand Down
91 changes: 91 additions & 0 deletions src/mcp_agent/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,42 @@ def _should_include_non_namespaced_tool(
# No non_namespaced_tools key and no wildcard - include by default (no filter for non-namespaced)
return True, None

async def _sync_with_aggregator_state(self) -> None:
if self._agent_tasks is None:
return

executor = self.context.executor if self.context else None
if executor is None:
return

response = await executor.execute(
self._agent_tasks.get_aggregator_state_task,
GetAggregatorStateRequest(agent_name=self.name),
)

if isinstance(response, BaseException): # pragma: no cover - defensive
raise response

self.initialized = response.initialized

self._namespaced_tool_map.clear()
self._namespaced_tool_map.update(response.namespaced_tool_map)

self._server_to_tool_map.clear()
self._server_to_tool_map.update(response.server_to_tool_map)

self._namespaced_prompt_map.clear()
self._namespaced_prompt_map.update(response.namespaced_prompt_map)

self._server_to_prompt_map.clear()
self._server_to_prompt_map.update(response.server_to_prompt_map)

self._namespaced_resource_map.clear()
self._namespaced_resource_map.update(response.namespaced_resource_map)

self._server_to_resource_map.clear()
self._server_to_resource_map.update(response.server_to_resource_map)

async def list_tools(
self,
server_name: str | None = None,
Expand All @@ -508,6 +544,8 @@ async def list_tools(
if not self.initialized:
await self.initialize()

await self._sync_with_aggregator_state()

tracer = get_tracer(self.context)
with tracer.start_as_current_span(
f"{self.__class__.__name__}.{self.name}.list_tools"
Expand Down Expand Up @@ -731,6 +769,8 @@ async def list_resources(
if not self.initialized:
await self.initialize()

await self._sync_with_aggregator_state()

tracer = get_tracer(self.context)
with tracer.start_as_current_span(
f"{self.__class__.__name__}.{self.name}.list_resources"
Expand All @@ -754,6 +794,8 @@ async def read_resource(self, uri: str, server_name: str | None = None):
if not self.initialized:
await self.initialize()

await self._sync_with_aggregator_state()

tracer = get_tracer(self.context)
with tracer.start_as_current_span(
f"{self.__class__.__name__}.{self.name}.read_resource"
Expand Down Expand Up @@ -871,6 +913,8 @@ async def list_prompts(self, server_name: str | None = None) -> ListPromptsResul
if not self.initialized:
await self.initialize()

await self._sync_with_aggregator_state()

tracer = get_tracer(self.context)
with tracer.start_as_current_span(
f"{self.__class__.__name__}.{self.name}.list_prompts"
Expand Down Expand Up @@ -919,6 +963,8 @@ async def get_prompt(
if not self.initialized:
await self.initialize()

await self._sync_with_aggregator_state()

tracer = get_tracer(self.context)
with tracer.start_as_current_span(
f"{self.__class__.__name__}.{self.name}.get_prompt"
Expand Down Expand Up @@ -1077,6 +1123,8 @@ async def call_tool(
if not self.initialized:
await self.initialize()

await self._sync_with_aggregator_state()

tracer = get_tracer(self.context)
with tracer.start_as_current_span(
f"{self.__class__.__name__}.{self.name}.call_tool"
Expand Down Expand Up @@ -1194,6 +1242,14 @@ class InitAggregatorResponse(BaseModel):
)


class GetAggregatorStateRequest(BaseModel):
"""
Request to fetch the current cached state from an agent's aggregator.
"""

agent_name: str


class ListToolsRequest(BaseModel):
"""
Request to list tools for an agent.
Expand Down Expand Up @@ -1435,6 +1491,41 @@ async def initialize_aggregator_task(
server_to_resource_map=aggregator._server_to_resource_map,
)

async def get_aggregator_state_task(
self, request: GetAggregatorStateRequest
) -> InitAggregatorResponse:
async with self._with_aggregator(request.agent_name) as aggregator:
async with aggregator._tool_map_lock:
namespaced_tool_map = dict(aggregator._namespaced_tool_map)
server_to_tool_map = {
server: list(tools)
for server, tools in aggregator._server_to_tool_map.items()
}

async with aggregator._prompt_map_lock:
namespaced_prompt_map = dict(aggregator._namespaced_prompt_map)
server_to_prompt_map = {
server: list(prompts)
for server, prompts in aggregator._server_to_prompt_map.items()
}

async with aggregator._resource_map_lock:
namespaced_resource_map = dict(aggregator._namespaced_resource_map)
server_to_resource_map = {
server: list(resources)
for server, resources in aggregator._server_to_resource_map.items()
}

return InitAggregatorResponse(
initialized=aggregator.initialized,
namespaced_tool_map=namespaced_tool_map,
server_to_tool_map=server_to_tool_map,
namespaced_prompt_map=namespaced_prompt_map,
server_to_prompt_map=server_to_prompt_map,
namespaced_resource_map=namespaced_resource_map,
server_to_resource_map=server_to_resource_map,
)

async def shutdown_aggregator_task(self, agent_name: str) -> bool:
"""
Shutdown the agent's servers.
Expand Down
4 changes: 2 additions & 2 deletions src/mcp_agent/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,9 +707,9 @@ async def _run(self, *args, **kwargs): # type: ignore[no-redef]
# decorate the run method with the engine-specific run decorator.
if engine_type == "temporal":
try:
run_decorator = (self._decorator_registry.get_workflow_run_decorator(
run_decorator = self._decorator_registry.get_workflow_run_decorator(
engine_type
))
)
if run_decorator:
fn_run = getattr(auto_cls, "run")
# Ensure method appears as top-level for Temporal
Expand Down
12 changes: 6 additions & 6 deletions src/mcp_agent/executor/temporal/session_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ async def notify(self, method: str, params: Dict[str, Any] | None = None) -> boo
return True

async def request(
self, method: str, params: Dict[str, Any] | None = None
self, method: str, params: Dict[str, Any] | None = None
) -> Dict[str, Any]:
"""Send a server->client request and return the client's response.
The result is a plain JSON-serializable dict.
Expand Down Expand Up @@ -322,10 +322,10 @@ async def create_message(
raise RuntimeError(f"sampling/createMessage returned invalid result: {e}")

async def elicit(
self,
message: str,
requestedSchema: types.ElicitRequestedSchema,
related_request_id: types.RequestId | None = None,
self,
message: str,
requestedSchema: types.ElicitRequestedSchema,
related_request_id: types.RequestId | None = None,
) -> types.ElicitResult:
params: Dict[str, Any] = {
"message": message,
Expand Down Expand Up @@ -358,6 +358,6 @@ async def notify(self, method: str, params: Dict[str, Any] | None = None) -> Non
await self._proxy.notify(method, params or {})

async def request(
self, method: str, params: Dict[str, Any] | None = None
self, method: str, params: Dict[str, Any] | None = None
) -> Dict[str, Any]:
return await self._proxy.request(method, params or {})
Loading
Loading