Skip to content

Commit 87ab941

Browse files
authored
Break unnecessary Temporal dependency in Workflow and Executor (#405)
* Break Temporal runtime dependency in Workflow and Executor * Fix errorneous import
1 parent baa8547 commit 87ab941

File tree

4 files changed

+87
-78
lines changed

4 files changed

+87
-78
lines changed

src/mcp_agent/executor/temporal/workflow_signal.py

Lines changed: 7 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
from contextvars import ContextVar
2-
from dataclasses import dataclass
32
from datetime import timedelta
4-
from typing import Any, Callable, Dict, Generic, Optional, TYPE_CHECKING
3+
from typing import Any, Callable, Optional, TYPE_CHECKING
54

65
from temporalio import exceptions, workflow
76

8-
from mcp_agent.executor.workflow_signal import BaseSignalHandler, Signal, SignalValueT
7+
from mcp_agent.executor.workflow_signal import (
8+
BaseSignalHandler,
9+
Signal,
10+
SignalValueT,
11+
SignalMailbox,
12+
)
913
from mcp_agent.logging.logger import get_logger
1014

1115
if TYPE_CHECKING:
@@ -15,62 +19,6 @@
1519
logger = get_logger(__name__)
1620

1721

18-
@dataclass(slots=True)
19-
class _Record(Generic[SignalValueT]):
20-
"""Record for tracking signal values with versioning for broadcast semantics"""
21-
22-
value: Optional[SignalValueT] = None
23-
version: int = 0 # monotonic counter
24-
25-
26-
class SignalMailbox(Generic[SignalValueT]):
27-
"""
28-
Deterministic broadcast mailbox that stores signal values with versioning.
29-
Each workflow run has its own mailbox instance.
30-
"""
31-
32-
def __init__(self) -> None:
33-
self._store: Dict[str, _Record[SignalValueT]] = {}
34-
35-
def push(self, name: str, value: SignalValueT) -> None:
36-
"""
37-
Store a signal value and increment its version counter.
38-
This enables broadcast semantics where all waiters see the same value.
39-
"""
40-
rec = self._store.setdefault(name, _Record())
41-
rec.value = value
42-
rec.version += 1
43-
44-
logger.debug(
45-
f"SignalMailbox.push: name={name}, value={value}, version={rec.version}"
46-
)
47-
48-
def version(self, name: str) -> int:
49-
"""Get the current version counter for a signal name"""
50-
return self._store.get(name, _Record()).version
51-
52-
def value(self, name: str) -> SignalValueT:
53-
"""
54-
Get the current value for a signal name
55-
56-
Returns:
57-
The signal value
58-
59-
Raises:
60-
ValueError: If no value exists for the signal
61-
"""
62-
value = self._store.get(name, _Record()).value
63-
64-
if value is None:
65-
raise ValueError(f"No value for signal {name}")
66-
67-
logger.debug(
68-
f"SignalMailbox.value: name={name}, value={value}, version={self._store.get(name, _Record()).version}"
69-
)
70-
71-
return value
72-
73-
7422
class TemporalSignalHandler(BaseSignalHandler[SignalValueT]):
7523
"""
7624
Temporal-based signal handling using workflow signals.

src/mcp_agent/executor/workflow.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,13 @@
1616

1717
from pydantic import BaseModel, ConfigDict, Field
1818
from mcp_agent.core.context_dependent import ContextDependent
19-
from mcp_agent.executor.temporal import TemporalExecutor
20-
from mcp_agent.executor.temporal.workflow_signal import (
21-
SignalMailbox,
22-
TemporalSignalHandler,
23-
)
24-
from mcp_agent.executor.workflow_signal import Signal
19+
from mcp_agent.executor.workflow_signal import Signal, SignalMailbox
2520
from mcp_agent.logging.logger import get_logger
2621

2722
if TYPE_CHECKING:
2823
from temporalio.client import WorkflowHandle
2924
from mcp_agent.core.context import Context
25+
from mcp_agent.executor.temporal import TemporalExecutor
3026

3127
T = TypeVar("T")
3228

@@ -230,7 +226,7 @@ async def run_async(self, *args, **kwargs) -> "WorkflowExecution":
230226
self._run_id = str(self.executor.uuid())
231227
elif self.context.config.execution_engine == "temporal":
232228
# For Temporal workflows, we'll start the workflow immediately
233-
executor: TemporalExecutor = self.executor
229+
executor: "TemporalExecutor" = self.executor
234230
handle = await executor.start_workflow(
235231
self.name,
236232
*args,
@@ -720,12 +716,22 @@ async def initialize(self):
720716
self._logger.debug(f"Initializing workflow {self.name}")
721717

722718
if self.context.config.execution_engine == "temporal":
723-
if isinstance(self.executor.signal_bus, TemporalSignalHandler):
724-
# Attach the signal handler to the workflow
725-
self.executor.signal_bus.attach_to_workflow(self)
726-
else:
719+
# Lazy import to avoid requiring Temporal unless engine is set to temporal
720+
try:
721+
from mcp_agent.executor.temporal.workflow_signal import (
722+
TemporalSignalHandler,
723+
)
724+
725+
if isinstance(self.executor.signal_bus, TemporalSignalHandler):
726+
# Attach the signal handler to the workflow
727+
self.executor.signal_bus.attach_to_workflow(self)
728+
else:
729+
self._logger.warning(
730+
"Signal handler not attached: executor.signal_bus is not a TemporalSignalHandler"
731+
)
732+
except Exception:
727733
self._logger.warning(
728-
"Signal handler not attached: executor.signal_bus is not a TemporalSignalHandler"
734+
"Signal handler not attached: Temporal support unavailable"
729735
)
730736

731737
self._initialized = True

src/mcp_agent/executor/workflow_signal.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
import asyncio
22
import uuid
33
from abc import abstractmethod, ABC
4-
from typing import Any, Callable, Dict, Generic, List, Protocol, TypeVar
4+
from dataclasses import dataclass
5+
from typing import Any, Callable, Dict, Generic, List, Optional, Protocol, TypeVar
56

67
from pydantic import BaseModel, ConfigDict
8+
from mcp_agent.logging.logger import get_logger
79

810
SignalValueT = TypeVar("SignalValueT")
911

10-
# TODO: saqadri - handle signals properly that works with other execution backends like Temporal as well
12+
logger = get_logger(__name__)
1113

1214

1315
class Signal(BaseModel, Generic[SignalValueT]):
@@ -95,6 +97,62 @@ class PendingSignal(BaseModel):
9597
model_config = ConfigDict(arbitrary_types_allowed=True)
9698

9799

100+
@dataclass(slots=True)
101+
class _Record(Generic[SignalValueT]):
102+
"""Record for tracking signal values with versioning for broadcast semantics"""
103+
104+
value: Optional[SignalValueT] = None
105+
version: int = 0 # monotonic counter
106+
107+
108+
class SignalMailbox(Generic[SignalValueT]):
109+
"""
110+
Deterministic broadcast mailbox that stores signal values with versioning.
111+
Each workflow run has its own mailbox instance.
112+
"""
113+
114+
def __init__(self) -> None:
115+
self._store: Dict[str, _Record[SignalValueT]] = {}
116+
117+
def push(self, name: str, value: SignalValueT) -> None:
118+
"""
119+
Store a signal value and increment its version counter.
120+
This enables broadcast semantics where all waiters see the same value.
121+
"""
122+
rec = self._store.setdefault(name, _Record())
123+
rec.value = value
124+
rec.version += 1
125+
126+
logger.debug(
127+
f"SignalMailbox.push: name={name}, value={value}, version={rec.version}"
128+
)
129+
130+
def version(self, name: str) -> int:
131+
"""Get the current version counter for a signal name"""
132+
return self._store.get(name, _Record()).version
133+
134+
def value(self, name: str) -> SignalValueT:
135+
"""
136+
Get the current value for a signal name
137+
138+
Returns:
139+
The signal value
140+
141+
Raises:
142+
ValueError: If no value exists for the signal
143+
"""
144+
value = self._store.get(name, _Record()).value
145+
146+
if value is None:
147+
raise ValueError(f"No value for signal {name}")
148+
149+
logger.debug(
150+
f"SignalMailbox.value: name={name}, value={value}, version={self._store.get(name, _Record()).version}"
151+
)
152+
153+
return value
154+
155+
98156
class BaseSignalHandler(ABC, Generic[SignalValueT]):
99157
"""Base class implementing common signal handling functionality."""
100158

tests/executor/temporal/test_signal_handler.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
import pytest
22
from unittest.mock import AsyncMock, MagicMock, patch
3-
from mcp_agent.executor.temporal.workflow_signal import (
4-
TemporalSignalHandler,
5-
SignalMailbox,
6-
)
7-
from mcp_agent.executor.workflow_signal import Signal
3+
from mcp_agent.executor.temporal.workflow_signal import TemporalSignalHandler
4+
from mcp_agent.executor.workflow_signal import Signal, SignalMailbox
85

96

107
@pytest.fixture

0 commit comments

Comments
 (0)