Skip to content

Commit 01fa7a3

Browse files
committed
respond to comment 1
Signed-off-by: Tim Li <[email protected]>
1 parent 17a7316 commit 01fa7a3

File tree

5 files changed

+86
-15
lines changed

5 files changed

+86
-15
lines changed
File renamed without changes.

cadence/worker/_decision_task_handler.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import logging
2+
import threading
3+
from typing import Dict, Tuple
24

35
from cadence.api.v1.common_pb2 import Payload
46
from cadence.api.v1.service_worker_pb2 import (
@@ -19,7 +21,8 @@ class DecisionTaskHandler(BaseTaskHandler[PollForDecisionTaskResponse]):
1921
"""
2022
Task handler for processing decision tasks.
2123
22-
This handler processes decision tasks and generates decisions using the workflow engine.
24+
This handler processes decision tasks and generates decisions using workflow engines.
25+
Uses a thread-safe cache to hold workflow engines for concurrent decision task handling.
2326
"""
2427

2528
def __init__(self, client: Client, task_list: str, registry: Registry, identity: str = "unknown", **options):
@@ -35,7 +38,9 @@ def __init__(self, client: Client, task_list: str, registry: Registry, identity:
3538
"""
3639
super().__init__(client, task_list, identity, **options)
3740
self._registry = registry
38-
self._workflow_engine: WorkflowEngine
41+
# Thread-safe cache to hold workflow engines keyed by (workflow_id, run_id)
42+
self._workflow_engines: Dict[Tuple[str, str], WorkflowEngine] = {}
43+
self._cache_lock = threading.RLock()
3944

4045

4146
async def _handle_task_implementation(self, task: PollForDecisionTaskResponse) -> None:
@@ -84,21 +89,41 @@ async def _handle_task_implementation(self, task: PollForDecisionTaskResponse) -
8489
)
8590
raise KeyError(f"Workflow type '{workflow_type_name}' not found")
8691

87-
# Create workflow info and engine
92+
# Create workflow info and get or create workflow engine from cache
8893
workflow_info = WorkflowInfo(
8994
workflow_type=workflow_type_name,
9095
workflow_domain=self._client.domain,
9196
workflow_id=workflow_id,
9297
workflow_run_id=run_id
9398
)
9499

95-
self._workflow_engine = WorkflowEngine(
96-
info=workflow_info,
97-
client=self._client,
98-
workflow_func=workflow_func
99-
)
100+
# Use thread-safe cache to get or create workflow engine
101+
cache_key = (workflow_id, run_id)
102+
with self._cache_lock:
103+
workflow_engine = self._workflow_engines.get(cache_key)
104+
if workflow_engine is None:
105+
workflow_engine = WorkflowEngine(
106+
info=workflow_info,
107+
client=self._client,
108+
workflow_func=workflow_func
109+
)
110+
self._workflow_engines[cache_key] = workflow_engine
100111

101-
decision_result = await self._workflow_engine.process_decision(task)
112+
decision_result = await workflow_engine.process_decision(task)
113+
114+
# Clean up completed workflows from cache to prevent memory leaks
115+
# Use getattr with default False to handle mocked engines in tests
116+
if getattr(workflow_engine, '_is_workflow_complete', False):
117+
with self._cache_lock:
118+
self._workflow_engines.pop(cache_key, None)
119+
logger.debug(
120+
"Removed completed workflow from cache",
121+
extra={
122+
"workflow_id": workflow_id,
123+
"run_id": run_id,
124+
"cache_size": len(self._workflow_engines)
125+
}
126+
)
102127

103128
# Respond with the decisions
104129
await self._respond_decision_task_completed(task, decision_result)

cadence/worker/_worker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from cadence.client import Client
66
from cadence.worker._registry import Registry
77
from cadence.worker._activity import ActivityWorker
8-
from cadence.worker._decision_worker import DecisionWorker
8+
from cadence.worker._decision import DecisionWorker
99
from cadence.worker._types import WorkerOptions, _DEFAULT_WORKER_OPTIONS
1010

1111

tests/cadence/worker/test_decision_task_handler.py

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,8 @@ async def test_handle_task_implementation_workflow_not_found(self, handler, samp
135135
await handler._handle_task_implementation(sample_decision_task)
136136

137137
@pytest.mark.asyncio
138-
async def test_handle_task_implementation_creates_new_engines(self, handler, sample_decision_task, mock_registry):
139-
"""Test that decision task handler creates new workflow engines for each task."""
138+
async def test_handle_task_implementation_caches_engines(self, handler, sample_decision_task, mock_registry):
139+
"""Test that decision task handler caches workflow engines for same workflow execution."""
140140
# Mock workflow function
141141
mock_workflow_func = Mock()
142142
mock_registry.get_workflow.return_value = mock_workflow_func
@@ -151,14 +151,60 @@ async def test_handle_task_implementation_creates_new_engines(self, handler, sam
151151
# First call - should create new engine
152152
await handler._handle_task_implementation(sample_decision_task)
153153

154-
# Second call - should create another new engine
154+
# Second call with same workflow_id and run_id - should reuse cached engine
155155
await handler._handle_task_implementation(sample_decision_task)
156156

157+
# Registry should be called for each task (to get workflow function)
158+
assert mock_registry.get_workflow.call_count == 2
159+
160+
# Engine should be created only once (cached for second call)
161+
assert mock_engine_class.call_count == 1
162+
163+
# But process_decision should be called twice
164+
assert mock_engine.process_decision.call_count == 2
165+
166+
@pytest.mark.asyncio
167+
async def test_handle_task_implementation_different_executions_get_separate_engines(self, handler, mock_registry):
168+
"""Test that different workflow executions get separate engines."""
169+
# Mock workflow function
170+
mock_workflow_func = Mock()
171+
mock_registry.get_workflow.return_value = mock_workflow_func
172+
173+
# Create two different decision tasks
174+
task1 = Mock(spec=PollForDecisionTaskResponse)
175+
task1.task_token = b"test_task_token_1"
176+
task1.workflow_execution = Mock()
177+
task1.workflow_execution.workflow_id = "workflow_1"
178+
task1.workflow_execution.run_id = "run_1"
179+
task1.workflow_type = Mock()
180+
task1.workflow_type.name = "TestWorkflow"
181+
182+
task2 = Mock(spec=PollForDecisionTaskResponse)
183+
task2.task_token = b"test_task_token_2"
184+
task2.workflow_execution = Mock()
185+
task2.workflow_execution.workflow_id = "workflow_2" # Different workflow
186+
task2.workflow_execution.run_id = "run_2" # Different run
187+
task2.workflow_type = Mock()
188+
task2.workflow_type.name = "TestWorkflow"
189+
190+
# Mock workflow engine
191+
mock_engine = Mock(spec=WorkflowEngine)
192+
mock_decision_result = Mock(spec=DecisionResult)
193+
mock_decision_result.decisions = []
194+
mock_engine.process_decision = AsyncMock(return_value=mock_decision_result)
195+
196+
with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine) as mock_engine_class:
197+
# Process different workflow executions
198+
await handler._handle_task_implementation(task1)
199+
await handler._handle_task_implementation(task2)
200+
157201
# Registry should be called for each task
158202
assert mock_registry.get_workflow.call_count == 2
159203

160-
# Engine should be created twice and called twice
204+
# Engine should be created twice (different executions)
161205
assert mock_engine_class.call_count == 2
206+
207+
# Process_decision should be called twice
162208
assert mock_engine.process_decision.call_count == 2
163209

164210
@pytest.mark.asyncio

tests/cadence/worker/test_decision_worker_integration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from cadence.api.v1.service_worker_pb2 import PollForDecisionTaskResponse
1010
from cadence.api.v1.common_pb2 import Payload, WorkflowExecution, WorkflowType
1111
from cadence.api.v1.history_pb2 import History, HistoryEvent, WorkflowExecutionStartedEventAttributes
12-
from cadence.worker._decision_worker import DecisionWorker
12+
from cadence.worker._decision import DecisionWorker
1313
from cadence.worker._registry import Registry
1414
from cadence.client import Client
1515

0 commit comments

Comments
 (0)