Skip to content

Commit 440ae8d

Browse files
Copilotberndverst
andcommitted
Implement entity locking functionality with comprehensive tests
Co-authored-by: berndverst <[email protected]>
1 parent e2cff34 commit 440ae8d

File tree

5 files changed

+526
-6
lines changed

5 files changed

+526
-6
lines changed

durabletask/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from durabletask.worker import ConcurrencyOptions
77
from durabletask.task import (
88
EntityContext, EntityState, EntityQuery, EntityQueryResult,
9-
EntityInstanceId, EntityOperationFailedException, EntityBase, dispatch_to_entity_method
9+
EntityInstanceId, EntityOperationFailedException, EntityBase, dispatch_to_entity_method,
10+
OrchestrationContext
1011
)
1112

1213
__all__ = [
@@ -18,7 +19,8 @@
1819
"EntityInstanceId",
1920
"EntityOperationFailedException",
2021
"EntityBase",
21-
"dispatch_to_entity_method"
22+
"dispatch_to_entity_method",
23+
"OrchestrationContext"
2224
]
2325

2426
PACKAGE_NAME = "durabletask"

durabletask/task.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,25 @@ def call_entity(self, entity_id: Union[str, 'EntityInstanceId'], operation_name:
223223
"""
224224
pass
225225

226+
@abstractmethod
227+
def lock_entities(self, *entity_ids: Union[str, 'EntityInstanceId']) -> 'EntityLockContext':
228+
"""Create a context manager for locking multiple entities.
229+
230+
This allows orchestrations to lock entities before performing operations
231+
on them, preventing race conditions with other orchestrations.
232+
233+
Parameters
234+
----------
235+
*entity_ids : Union[str, EntityInstanceId]
236+
Variable number of entity IDs to lock
237+
238+
Returns
239+
-------
240+
EntityLockContext
241+
A context manager that handles locking and unlocking
242+
"""
243+
pass
244+
226245

227246
class FailureDetails:
228247
def __init__(self, message: str, error_type: str, stack_trace: Optional[str]):
@@ -537,6 +556,20 @@ def from_string(cls, instance_id: str) -> 'EntityInstanceId':
537556
return cls(name=parts[0], key=parts[1])
538557

539558

559+
class EntityLockContext(ABC):
560+
"""Abstract base class for entity locking context managers."""
561+
562+
@abstractmethod
563+
def __enter__(self) -> 'EntityLockContext':
564+
"""Enter the entity lock context."""
565+
pass
566+
567+
@abstractmethod
568+
def __exit__(self, exc_type, exc_val, exc_tb):
569+
"""Exit the entity lock context."""
570+
pass
571+
572+
540573
class EntityOperationFailedException(Exception):
541574
"""Exception raised when an entity operation fails."""
542575

durabletask/worker.py

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -934,17 +934,18 @@ def signal_entity(self, entity_id: Union[str, task.EntityInstanceId], operation_
934934
entity_id_str = str(entity_id) if hasattr(entity_id, '__str__') else entity_id
935935

936936
action = pb.OrchestratorAction()
937-
action.sendEntitySignal.CopyFrom(pb.SendSignalAction(
938-
instanceId=entity_id_str,
937+
action.sendEvent.CopyFrom(pb.SendEventAction(
938+
instance=pb.OrchestrationInstance(instanceId=entity_id_str),
939939
name=operation_name,
940-
input=ph.get_string_value(shared.to_json(input)) if input is not None else None
940+
data=ph.get_string_value(shared.to_json(input)) if input is not None else None
941941
))
942942

943943
# Entity signals don't return values, so we create a completed task
944944
signal_task = task.CompletableTask()
945945

946946
# Store the action to be executed
947-
task_id = self._next_task_id()
947+
task_id = self.next_sequence_number()
948+
action.id = task_id
948949
self._pending_actions[task_id] = action
949950
self._pending_tasks[task_id] = signal_task
950951

@@ -960,6 +961,53 @@ def call_entity(self, entity_id: Union[str, task.EntityInstanceId], operation_na
960961
# This would require additional protobuf support
961962
raise NotImplementedError("Direct entity calls from orchestrations are not yet supported. Use signal_entity instead.")
962963

964+
def lock_entities(self, *entity_ids: Union[str, task.EntityInstanceId]) -> 'EntityLockContext':
965+
"""Create a context manager for locking multiple entities.
966+
967+
This allows orchestrations to lock entities before performing operations
968+
on them, preventing race conditions with other orchestrations.
969+
970+
Args:
971+
*entity_ids: Variable number of entity IDs to lock
972+
973+
Returns:
974+
EntityLockContext: A context manager that handles locking and unlocking
975+
976+
Example:
977+
with ctx.lock_entities("Counter@global", "ShoppingCart@user1"):
978+
# Perform operations on locked entities
979+
yield ctx.signal_entity("Counter@global", "increment", input=1)
980+
yield ctx.signal_entity("ShoppingCart@user1", "add_item", input=item)
981+
"""
982+
return EntityLockContext(self, entity_ids)
983+
984+
985+
class EntityLockContext(task.EntityLockContext):
986+
"""Context manager for entity locking in orchestrations.
987+
988+
This class provides a context manager that handles locking and unlocking
989+
of entities during orchestration execution to prevent race conditions.
990+
"""
991+
992+
def __init__(self, ctx: '_RuntimeOrchestrationContext', entity_ids: tuple):
993+
self._ctx = ctx
994+
self._entity_ids = [str(eid) if hasattr(eid, '__str__') else eid for eid in entity_ids]
995+
self._lock_instance_id = f"__lock__{ctx.instance_id}_{ctx.next_sequence_number()}"
996+
997+
def __enter__(self) -> 'EntityLockContext':
998+
"""Enter the entity lock context by acquiring locks on all specified entities."""
999+
# Signal each entity to acquire a lock
1000+
for entity_id in self._entity_ids:
1001+
self._ctx.signal_entity(entity_id, "__acquire_lock__", input=self._lock_instance_id)
1002+
return self
1003+
1004+
def __exit__(self, exc_type, exc_val, exc_tb):
1005+
"""Exit the entity lock context by releasing locks on all specified entities."""
1006+
# Signal each entity to release the lock
1007+
for entity_id in self._entity_ids:
1008+
self._ctx.signal_entity(entity_id, "__release_lock__", input=self._lock_instance_id)
1009+
return False # Don't suppress exceptions
1010+
9631011

9641012
class ExecutionResults:
9651013
actions: list[pb.OrchestratorAction]

examples/entity_locking_example.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
#!/usr/bin/env python3
2+
3+
# Copyright (c) Microsoft Corporation.
4+
# Licensed under the MIT License.
5+
6+
"""
7+
Example demonstrating entity locking in durable task orchestrations.
8+
9+
This example shows how to use entity locking to prevent race conditions
10+
when multiple orchestrations need to modify the same entities.
11+
"""
12+
13+
import durabletask as dt
14+
from typing import Any, Optional
15+
16+
17+
def counter_entity(ctx: dt.EntityContext, input: Any) -> Optional[Any]:
18+
"""A counter entity that supports locking and counting operations."""
19+
operation = ctx.operation_name
20+
21+
if operation == "__acquire_lock__":
22+
# Store the lock ID to track who has the lock
23+
lock_id = input
24+
current_lock = ctx.get_state(key="__lock__")
25+
if current_lock is not None:
26+
raise ValueError(f"Entity {ctx.instance_id} is already locked by {current_lock}")
27+
ctx.set_state(lock_id, key="__lock__")
28+
return None
29+
30+
elif operation == "__release_lock__":
31+
# Release the lock if it matches the provided lock ID
32+
lock_id = input
33+
current_lock = ctx.get_state(key="__lock__")
34+
if current_lock is None:
35+
raise ValueError(f"Entity {ctx.instance_id} is not locked")
36+
if current_lock != lock_id:
37+
raise ValueError(f"Lock ID mismatch for entity {ctx.instance_id}")
38+
ctx.set_state(None, key="__lock__")
39+
return None
40+
41+
elif operation == "increment":
42+
# Only allow increment if entity is locked
43+
current_lock = ctx.get_state(key="__lock__")
44+
if current_lock is None:
45+
raise ValueError(f"Entity {ctx.instance_id} must be locked before increment")
46+
47+
current_count = ctx.get_state(key="count") or 0
48+
new_count = current_count + (input or 1)
49+
ctx.set_state(new_count, key="count")
50+
return new_count
51+
52+
elif operation == "get":
53+
# Get can be called without locking
54+
return ctx.get_state(key="count") or 0
55+
56+
elif operation == "reset":
57+
# Reset requires locking
58+
current_lock = ctx.get_state(key="__lock__")
59+
if current_lock is None:
60+
raise ValueError(f"Entity {ctx.instance_id} must be locked before reset")
61+
62+
ctx.set_state(0, key="count")
63+
return 0
64+
65+
66+
def bank_account_entity(ctx: dt.EntityContext, input: Any) -> Optional[Any]:
67+
"""A bank account entity that supports locking for safe transfers."""
68+
operation = ctx.operation_name
69+
70+
if operation == "__acquire_lock__":
71+
lock_id = input
72+
current_lock = ctx.get_state(key="__lock__")
73+
if current_lock is not None:
74+
raise ValueError(f"Account {ctx.instance_id} is already locked by {current_lock}")
75+
ctx.set_state(lock_id, key="__lock__")
76+
return None
77+
78+
elif operation == "__release_lock__":
79+
lock_id = input
80+
current_lock = ctx.get_state(key="__lock__")
81+
if current_lock is None:
82+
raise ValueError(f"Account {ctx.instance_id} is not locked")
83+
if current_lock != lock_id:
84+
raise ValueError(f"Lock ID mismatch for account {ctx.instance_id}")
85+
ctx.set_state(None, key="__lock__")
86+
return None
87+
88+
elif operation == "deposit":
89+
current_lock = ctx.get_state(key="__lock__")
90+
if current_lock is None:
91+
raise ValueError(f"Account {ctx.instance_id} must be locked before deposit")
92+
93+
amount = input.get("amount", 0)
94+
current_balance = ctx.get_state(key="balance") or 0
95+
new_balance = current_balance + amount
96+
ctx.set_state(new_balance, key="balance")
97+
return new_balance
98+
99+
elif operation == "withdraw":
100+
current_lock = ctx.get_state(key="__lock__")
101+
if current_lock is None:
102+
raise ValueError(f"Account {ctx.instance_id} must be locked before withdraw")
103+
104+
amount = input.get("amount", 0)
105+
current_balance = ctx.get_state(key="balance") or 0
106+
if current_balance < amount:
107+
raise ValueError("Insufficient funds")
108+
new_balance = current_balance - amount
109+
ctx.set_state(new_balance, key="balance")
110+
return new_balance
111+
112+
elif operation == "get_balance":
113+
return ctx.get_state(key="balance") or 0
114+
115+
116+
def transfer_money_orchestration(ctx: dt.OrchestrationContext, input: Any) -> Any:
117+
"""Orchestration that safely transfers money between accounts using entity locking."""
118+
from_account = input["from_account"]
119+
to_account = input["to_account"]
120+
amount = input["amount"]
121+
122+
# Lock both accounts to prevent race conditions during transfer
123+
with ctx.lock_entities(from_account, to_account):
124+
# First, withdraw from source account
125+
yield ctx.signal_entity(from_account, "withdraw", input={"amount": amount})
126+
127+
# Then, deposit to destination account
128+
yield ctx.signal_entity(to_account, "deposit", input={"amount": amount})
129+
130+
# Return confirmation that transfer is complete
131+
return {
132+
"transfer_completed": True,
133+
"from_account": from_account,
134+
"to_account": to_account,
135+
"amount": amount
136+
}
137+
138+
139+
def batch_counter_update_orchestration(ctx: dt.OrchestrationContext, input: Any) -> Any:
140+
"""Orchestration that safely updates multiple counters in a batch."""
141+
counter_ids = input.get("counter_ids", [])
142+
increment_value = input.get("increment_value", 1)
143+
144+
# Lock all counters to ensure atomic batch operation
145+
with ctx.lock_entities(*counter_ids):
146+
results = []
147+
for counter_id in counter_ids:
148+
# Signal each counter to increment
149+
task = yield ctx.signal_entity(counter_id, "increment", input=increment_value)
150+
results.append(task)
151+
152+
# After all operations are complete, get final values
153+
final_values = {}
154+
for counter_id in counter_ids:
155+
value_task = yield ctx.signal_entity(counter_id, "get")
156+
final_values[counter_id] = value_task
157+
158+
return {
159+
"updated_counters": counter_ids,
160+
"increment_value": increment_value,
161+
"final_values": final_values
162+
}
163+
164+
165+
if __name__ == "__main__":
166+
print("Entity Locking Example")
167+
print("======================")
168+
print()
169+
print("This example demonstrates entity locking patterns:")
170+
print("1. Counter entity with locking support")
171+
print("2. Bank account entity with locking for transfers")
172+
print("3. Transfer orchestration using entity locking")
173+
print("4. Batch counter update orchestration")
174+
print()
175+
print("Key concepts:")
176+
print("- Entities handle __acquire_lock__ and __release_lock__ operations")
177+
print("- Orchestrations use ctx.lock_entities() context manager")
178+
print("- Locks prevent race conditions during multi-entity operations")
179+
print("- Locks are automatically released even if exceptions occur")
180+
print()
181+
print("To use these patterns in your own code:")
182+
print("1. Implement lock handling in your entity functions")
183+
print("2. Use 'with ctx.lock_entities(*entity_ids):' in orchestrations")
184+
print("3. Perform all related entity operations within the lock context")

0 commit comments

Comments
 (0)