Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions examples/usecases/reliable_conversation/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,20 @@
from models.conversation_models import ConversationState
from utils.test_runner import create_test_runner
from utils.progress_reporter import ProgressReporter, set_progress_reporter
from mcp_agent.executor.workflow import WorkflowResult
from typing import Dict, Any


# Define TestConversationWorkflow at module level to avoid "local class" error
# This will be registered dynamically with specific apps in tests
class TestConversationWorkflowTemplate(ConversationWorkflow):
"""Test workflow template - can be dynamically registered"""

def __init__(self, app):
super().__init__(app)

async def run(self, args: Dict[str, Any]) -> WorkflowResult[Dict[str, Any]]:
return await super().run(args)


def patch_llm_interactions():
Expand Down Expand Up @@ -91,12 +105,11 @@ async def test_rcm_with_real_calls():
# Create app using canonical mcp-agent pattern (loads config files automatically)
app = MCPApp(name="rcm_test")

# Register workflow
@app.workflow
class TestConversationWorkflow(ConversationWorkflow):
"""Test workflow registered with app"""

pass
# Set execution engine to asyncio to avoid Temporal decoration requirements
app.config.execution_engine = "asyncio"

# Register the workflow using the simple decorator approach for AsyncIO
TestConversationWorkflow = app.workflow(TestConversationWorkflowTemplate)

try:
async with app.run() as test_app:
Expand Down
21 changes: 21 additions & 0 deletions src/mcp_agent/executor/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,27 @@ async def wait_for_signal(
)
return await self.signal_bus.wait_for_signal(signal)

async def wait_for_any_signal(
self,
signal_names: List[str],
workflow_id: str,
run_id: str | None = None,
timeout_seconds: int | None = None,
) -> Signal:
"""
Waits for any of a list of signals. This is a convenience wrapper
around the signal bus implementation.
"""
if not self.signal_bus:
raise RuntimeError("Signal bus is not initialized for this executor.")

return await self.signal_bus.wait_for_any_signal(
signal_names=signal_names,
workflow_id=workflow_id,
run_id=run_id,
timeout_seconds=timeout_seconds,
)

def uuid(self) -> uuid.UUID:
"""
Generate a UUID. Some executors enforce deterministic UUIDs, so this is an
Expand Down
69 changes: 66 additions & 3 deletions src/mcp_agent/executor/temporal/workflow_signal.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import asyncio
from contextvars import ContextVar
from dataclasses import dataclass
from datetime import timedelta
from typing import Any, Callable, Dict, Generic, Optional, TYPE_CHECKING
from typing import Any, Callable, Dict, Generic, List, Optional, TYPE_CHECKING

from temporalio import exceptions, workflow

Expand Down Expand Up @@ -177,6 +178,68 @@ async def wait_for_signal(
except exceptions.TimeoutError as e:
raise TimeoutError(f"Timeout waiting for signal {signal.name}") from e

async def wait_for_any_signal(
self,
signal_names: List[str],
workflow_id: str,
run_id: str | None = None,
timeout_seconds: int | None = None
) -> Signal[SignalValueT]:
"""
Waits for any of the specified signals using Temporal-safe primitives.
"""
if not workflow._Runtime.current():
raise RuntimeError("wait_for_any_signal must be called from within a Temporal workflow")

# Get the mailbox safely from ContextVar
mailbox = self._mailbox_ref.get()
if mailbox is None:
raise RuntimeError(
"Signal mailbox not initialized for this workflow. Please call attach_to_workflow first."
)

# Get current versions for all signals
current_versions = {name: mailbox.version(name) for name in signal_names}

logger.debug(
f"SignalMailbox.wait_for_any_signal: signal_names={signal_names}, current_versions={current_versions}"
)

# Wait for any signal to have a new version
def any_signal_updated():
for name in signal_names:
if mailbox.version(name) > current_versions[name]:
return True
return False

try:
await workflow.wait_condition(
any_signal_updated,
timeout=timedelta(seconds=timeout_seconds) if timeout_seconds else None,
)

# Find which signal was updated
for name in signal_names:
if mailbox.version(name) > current_versions[name]:
# Just get the value directly like wait_for_signal does
payload = mailbox.value(name)

logger.debug(
f"SignalMailbox.wait_for_any_signal returned: name={name}, val={payload}"
)
return Signal(
name=name,
payload=payload,
workflow_id=workflow_id,
run_id=run_id or workflow.info().run_id
)

# Should not reach here
raise RuntimeError("wait_condition returned but no signal was found")

except exceptions.TimeoutError as e:
raise asyncio.TimeoutError(f"Timeout waiting for signals: {signal_names}") from e

def on_signal(self, signal_name: str):
"""
Decorator that registers a callback for a signal.
Expand Down Expand Up @@ -228,7 +291,7 @@ async def signal(self, signal: Signal[SignalValueT]) -> None:
def validate_signal(self, signal):
super().validate_signal(signal)
# Add TemporalSignalHandler-specific validation
if signal.workflow_id is None or signal.run_id is None:
if not signal.workflow_id:
raise ValueError(
"No workflow_id or run_id provided on Signal. That is required for Temporal signals"
"A workflow_id must be provided on a Signal for Temporal signals"
)
26 changes: 24 additions & 2 deletions src/mcp_agent/executor/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,10 +371,32 @@ async def _signal_receiver(self, name: str, args: Sequence[RawValue]):
self._logger.debug(f"Dynamic signal received: name={name}, args={args}")

# Extract payload and update mailbox
payload = args[0] if args else None
raw_payload = args[0] if args else None

# Deserialize the RawValue to get the actual content
actual_payload = raw_payload
if raw_payload and hasattr(raw_payload, 'payload'):
try:
from temporalio.converter import default_converter
# Use Temporal's converter to deserialize the payload
actual_payload = default_converter.from_payloads([raw_payload.payload])[0]
except Exception as e:
self._logger.error(f"Failed to deserialize signal payload: {e}")
# Fallback: try to extract JSON data directly
try:
import json
if hasattr(raw_payload.payload, 'data'):
# Decode the raw bytes and parse as JSON
json_str = raw_payload.payload.data.decode('utf-8')
actual_payload = json.loads(json_str)
else:
actual_payload = str(raw_payload)
except Exception as e2:
self._logger.error(f"Fallback deserialization also failed: {e2}")
actual_payload = raw_payload

Comment on lines +374 to 397
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Undefined variable payload breaks the build

payload was removed in favour of actual_payload, but it is still referenced later when constructing Signal, causing F821 (payload undefined).

-                sig_obj = Signal(
-                    name=name,
-                    payload=payload,
+                sig_obj = Signal(
+                    name=name,
+                    payload=actual_payload,

Apply the change above (and any similar occurrences) to restore type-checking & runtime correctness.

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In src/mcp_agent/executor/workflow.py around lines 374 to 397, the variable
`payload` is referenced after being removed and replaced by `actual_payload`,
causing an undefined variable error. Replace all occurrences of `payload` with
`actual_payload` in this section and any similar places to ensure the code uses
the correct variable and passes type-checking and runs correctly.

if hasattr(self, "_signal_mailbox"):
self._signal_mailbox.push(name, payload)
self._signal_mailbox.push(name, actual_payload)
self._logger.debug(f"Updated mailbox for signal {name}")
else:
self._logger.warning("No _signal_mailbox found on workflow instance")
Expand Down
144 changes: 144 additions & 0 deletions src/mcp_agent/executor/workflow_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,34 @@ async def wait_for_signal(
) -> SignalValueT:
"""Wait for a signal to be emitted."""

@abstractmethod
async def wait_for_any_signal(
self,
signal_names: List[str],
workflow_id: str,
run_id: str | None = None,
timeout_seconds: int | None = None,
) -> Signal[SignalValueT]:
"""
Waits for any of a list of signals and returns the one that fired.

This method is essential for workflows that need to react to multiple
different events concurrently.

Args:
signal_names: A list of signal names to wait for.
workflow_id: The ID of the workflow instance to listen on.
run_id: Optional specific run ID of the workflow.
timeout_seconds: Optional timeout for waiting.

Returns:
A Signal object containing the name and payload of the first signal received.

Raises:
asyncio.TimeoutError: If the timeout is reached.
"""
...

def on_signal(self, signal_name: str) -> Callable:
"""
Decorator to register a handler for a signal.
Expand Down Expand Up @@ -201,6 +229,45 @@ async def wrapped(value: SignalValueT):

return decorator

async def wait_for_any_signal(
self,
signal_names: List[str],
workflow_id: str,
run_id: str | None = None,
timeout_seconds: int | None = None
) -> Signal[SignalValueT]:
"""
Wait for any of the specified signals using console input.
Note: This is a simplified implementation for console-based workflows.
"""
# For console handler, we'll just wait for the first signal name entered
loop = asyncio.get_event_loop()
if timeout_seconds is not None:
try:
signal_name = await asyncio.wait_for(
loop.run_in_executor(None, input, f"Enter signal name ({', '.join(signal_names)}): "),
timeout_seconds
)
except asyncio.TimeoutError:
print("\nTimeout waiting for input")
raise
else:
signal_name = await loop.run_in_executor(None, input, f"Enter signal name ({', '.join(signal_names)}): ")

# Validate the signal name
if signal_name not in signal_names:
raise ValueError(f"Invalid signal name: {signal_name}. Expected one of: {signal_names}")

# Get the payload
payload = await loop.run_in_executor(None, input, f"Enter payload for {signal_name}: ")

return Signal(
name=signal_name,
payload=payload,
workflow_id=workflow_id,
run_id=run_id
)

async def signal(self, signal):
print(f"[SIGNAL SENT: {signal.name}] Value: {signal.payload}")

Expand Down Expand Up @@ -279,6 +346,83 @@ async def wrapped(value: SignalValueT):

return decorator

async def wait_for_any_signal(
self,
signal_names: List[str],
workflow_id: str,
run_id: str | None = None,
timeout_seconds: int | None = None
) -> Signal[SignalValueT]:
"""
Waits for any of a list of signals using asyncio primitives.
"""
# Create an event and a registration for each signal
pending_signals: List[PendingSignal] = []
waiter_tasks: List[asyncio.Task] = []

async with self._lock:
for name in signal_names:
event = asyncio.Event()
unique_name = f"{name}_{uuid.uuid4()}"
registration = SignalRegistration(
signal_name=name,
unique_name=unique_name,
workflow_id=workflow_id,
run_id=run_id,
)
pending = PendingSignal(registration=registration, event=event)
pending_signals.append(pending)
self._pending_signals.setdefault(name, []).append(pending)
waiter_tasks.append(asyncio.create_task(event.wait()))

try:
# Wait for any of the events to be set
done, pending = await asyncio.wait(
waiter_tasks,
return_when=asyncio.FIRST_COMPLETED,
timeout=timeout_seconds,
)

if not done:
raise asyncio.TimeoutError(f"Timeout waiting for signals: {signal_names}")

# Find which pending signal corresponds to the completed task
completed_task = done.pop()
triggered_pending_signal = None
for i, task in enumerate(waiter_tasks):
if task is completed_task:
triggered_pending_signal = pending_signals[i]
break

if not triggered_pending_signal:
# Should not happen
raise RuntimeError("Could not identify which signal was triggered.")

return Signal(
name=triggered_pending_signal.registration.signal_name,
payload=triggered_pending_signal.value,
workflow_id=workflow_id,
run_id=run_id
)

Comment on lines +349 to +407
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Edge-case: empty signal_names yields confusing error.

If signal_names is empty we fall through to asyncio.wait([], …)ValueError("Set of coroutines/Futures is empty").
Consider raising a domain-specific ValueError earlier:

+        if not signal_names:
+            raise ValueError("signal_names must contain at least one signal to wait for")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
async def wait_for_any_signal(
self,
signal_names: List[str],
workflow_id: str,
run_id: str | None = None,
timeout_seconds: int | None = None
) -> Signal[SignalValueT]:
"""
Waits for any of a list of signals using asyncio primitives.
"""
# Create an event and a registration for each signal
pending_signals: List[PendingSignal] = []
waiter_tasks: List[asyncio.Task] = []
async with self._lock:
for name in signal_names:
event = asyncio.Event()
unique_name = f"{name}_{uuid.uuid4()}"
registration = SignalRegistration(
signal_name=name,
unique_name=unique_name,
workflow_id=workflow_id,
run_id=run_id,
)
pending = PendingSignal(registration=registration, event=event)
pending_signals.append(pending)
self._pending_signals.setdefault(name, []).append(pending)
waiter_tasks.append(asyncio.create_task(event.wait()))
try:
# Wait for any of the events to be set
done, pending = await asyncio.wait(
waiter_tasks,
return_when=asyncio.FIRST_COMPLETED,
timeout=timeout_seconds,
)
if not done:
raise asyncio.TimeoutError(f"Timeout waiting for signals: {signal_names}")
# Find which pending signal corresponds to the completed task
completed_task = done.pop()
triggered_pending_signal = None
for i, task in enumerate(waiter_tasks):
if task is completed_task:
triggered_pending_signal = pending_signals[i]
break
if not triggered_pending_signal:
# Should not happen
raise RuntimeError("Could not identify which signal was triggered.")
return Signal(
name=triggered_pending_signal.registration.signal_name,
payload=triggered_pending_signal.value,
workflow_id=workflow_id,
run_id=run_id
)
async def wait_for_any_signal(
self,
signal_names: List[str],
workflow_id: str,
run_id: str | None = None,
timeout_seconds: int | None = None
) -> Signal[SignalValueT]:
"""
Waits for any of a list of signals using asyncio primitives.
"""
if not signal_names:
raise ValueError("signal_names must contain at least one signal to wait for")
# Create an event and a registration for each signal
pending_signals: List[PendingSignal] = []
waiter_tasks: List[asyncio.Task] = []
async with self._lock:
for name in signal_names:
event = asyncio.Event()
unique_name = f"{name}_{uuid.uuid4()}"
registration = SignalRegistration(
signal_name=name,
unique_name=unique_name,
workflow_id=workflow_id,
run_id=run_id,
)
pending = PendingSignal(registration=registration, event=event)
pending_signals.append(pending)
self._pending_signals.setdefault(name, []).append(pending)
waiter_tasks.append(asyncio.create_task(event.wait()))
try:
# Wait for any of the events to be set
done, pending = await asyncio.wait(
waiter_tasks,
return_when=asyncio.FIRST_COMPLETED,
timeout=timeout_seconds,
)
if not done:
raise asyncio.TimeoutError(f"Timeout waiting for signals: {signal_names}")
# Find which pending signal corresponds to the completed task
completed_task = done.pop()
triggered_pending_signal = None
for i, task in enumerate(waiter_tasks):
if task is completed_task:
triggered_pending_signal = pending_signals[i]
break
if not triggered_pending_signal:
# Should not happen
raise RuntimeError("Could not identify which signal was triggered.")
return Signal(
name=triggered_pending_signal.registration.signal_name,
payload=triggered_pending_signal.value,
workflow_id=workflow_id,
run_id=run_id
)
🧰 Tools
🪛 Pylint (3.3.7)

[refactor] 349-349: Too many local variables (18/15)

(R0914)

🤖 Prompt for AI Agents
In src/mcp_agent/executor/workflow_signal.py around lines 349 to 407, the method
wait_for_any_signal does not handle the case when the input list signal_names is
empty, causing asyncio.wait to raise a ValueError with a confusing message. To
fix this, add an explicit check at the start of the method to raise a ValueError
with a clear, domain-specific message if signal_names is empty, preventing the
confusing error from asyncio.wait.

finally:
# Cleanup all waiters for this call
for task in waiter_tasks:
if not task.done():
task.cancel()

async with self._lock:
for pending_signal in pending_signals:
name = pending_signal.registration.signal_name
unique_name = pending_signal.registration.unique_name
if name in self._pending_signals:
self._pending_signals[name] = [
p for p in self._pending_signals[name]
if p.registration.unique_name != unique_name
]
if not self._pending_signals[name]:
del self._pending_signals[name]

async def signal(self, signal):
async with self._lock:
# Notify any waiting coroutines
Expand Down
5 changes: 3 additions & 2 deletions src/mcp_agent/logging/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,8 +389,9 @@ async def emit(self, event: Event):
except Exception as e:
print(f"Error in transport.send_event: {e}")

# Then queue for listeners
await self._queue.put(event)
# Then queue for listeners only if the event bus has been started
if hasattr(self, '_queue') and self._queue is not None:
await self._queue.put(event)

def add_listener(self, name: str, listener: EventListener):
"""Add a listener to the event bus."""
Expand Down
3 changes: 2 additions & 1 deletion tests/executor/temporal/test_signal_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ def mailbox():
return SignalMailbox()


def test_push_and_version(mailbox):
@patch('mcp_agent.executor.temporal.workflow_signal.logger')
def test_push_and_version(mock_logger, mailbox):
mailbox.push("signal1", "value1")
assert mailbox.version("signal1") == 1
assert mailbox.value("signal1") == "value1"
Expand Down
Loading
Loading