Skip to content

Commit 070eef7

Browse files
authored
Allow retrieve entity metadata from client (#77)
* Allow retrieve entity metadata from client * Lint * Better state formatting * Add tests * Improve locked_by parsing * Parse last_modified as UTC * Fix tests
1 parent 901c63d commit 070eef7

File tree

6 files changed

+175
-2
lines changed

6 files changed

+175
-2
lines changed

durabletask/client.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from google.protobuf import wrappers_pb2
1313

1414
from durabletask.entities import EntityInstanceId
15+
from durabletask.entities.entity_metadata import EntityMetadata
1516
import durabletask.internal.helpers as helpers
1617
import durabletask.internal.orchestrator_service_pb2 as pb
1718
import durabletask.internal.orchestrator_service_pb2_grpc as stubs
@@ -241,3 +242,15 @@ def signal_entity(self, entity_instance_id: EntityInstanceId, operation_name: st
241242
)
242243
self._logger.info(f"Signaling entity '{entity_instance_id}' operation '{operation_name}'.")
243244
self._stub.SignalEntity(req, None) # TODO: Cancellation timeout?
245+
246+
def get_entity(self,
247+
entity_instance_id: EntityInstanceId,
248+
include_state: bool = True
249+
) -> Optional[EntityMetadata]:
250+
req = pb.GetEntityRequest(instanceId=str(entity_instance_id), includeState=include_state)
251+
self._logger.info(f"Getting entity '{entity_instance_id}'.")
252+
res: pb.GetEntityResponse = self._stub.GetEntity(req)
253+
if not res.exists:
254+
return None
255+
256+
return EntityMetadata.from_entity_response(res, include_state)

durabletask/entities/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from durabletask.entities.durable_entity import DurableEntity
88
from durabletask.entities.entity_lock import EntityLock
99
from durabletask.entities.entity_context import EntityContext
10+
from durabletask.entities.entity_metadata import EntityMetadata
1011

11-
__all__ = ["EntityInstanceId", "DurableEntity", "EntityLock", "EntityContext"]
12+
__all__ = ["EntityInstanceId", "DurableEntity", "EntityLock", "EntityContext", "EntityMetadata"]
1213

1314
PACKAGE_NAME = "durabletask.entities"

durabletask/entities/entity_instance_id.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,4 @@ def parse(entity_id: str) -> Optional["EntityInstanceId"]:
3737
_, entity, key = entity_id.split("@", 2)
3838
return EntityInstanceId(entity=entity, key=key)
3939
except ValueError as ex:
40-
raise ValueError("Invalid entity ID format", ex)
40+
raise ValueError(f"Invalid entity ID format: {entity_id}", ex)
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
from datetime import datetime, timezone
2+
from typing import Any, Optional, Type, TypeVar, Union, overload
3+
from durabletask.entities.entity_instance_id import EntityInstanceId
4+
5+
import durabletask.internal.orchestrator_service_pb2 as pb
6+
7+
TState = TypeVar("TState")
8+
9+
10+
class EntityMetadata:
11+
"""Class representing the metadata of a durable entity.
12+
13+
This class encapsulates the metadata information of a durable entity, allowing for
14+
easy access and manipulation of the entity's metadata within the Durable Task
15+
Framework.
16+
17+
Attributes:
18+
id (EntityInstanceId): The unique identifier of the entity instance.
19+
last_modified (datetime): The timestamp of the last modification to the entity.
20+
backlog_queue_size (int): The size of the backlog queue for the entity.
21+
locked_by (str): The identifier of the worker that currently holds the lock on the entity.
22+
includes_state (bool): Indicates whether the metadata includes the state of the entity.
23+
state (Optional[Any]): The current state of the entity, if included.
24+
"""
25+
26+
def __init__(self,
27+
id: EntityInstanceId,
28+
last_modified: datetime,
29+
backlog_queue_size: int,
30+
locked_by: str,
31+
includes_state: bool,
32+
state: Optional[Any]):
33+
"""Initializes a new instance of the EntityMetadata class.
34+
35+
Args:
36+
value: The initial state value of the entity.
37+
"""
38+
self.id = id
39+
self.last_modified = last_modified
40+
self.backlog_queue_size = backlog_queue_size
41+
self._locked_by = locked_by
42+
self.includes_state = includes_state
43+
self._state = state
44+
45+
@staticmethod
46+
def from_entity_response(entity_response: pb.GetEntityResponse, includes_state: bool):
47+
entity_id = EntityInstanceId.parse(entity_response.entity.instanceId)
48+
if not entity_id:
49+
raise ValueError("Invalid entity instance ID in entity response.")
50+
entity_state = None
51+
if includes_state:
52+
entity_state = entity_response.entity.serializedState.value
53+
return EntityMetadata(
54+
id=entity_id,
55+
last_modified=entity_response.entity.lastModifiedTime.ToDatetime(timezone.utc),
56+
backlog_queue_size=entity_response.entity.backlogQueueSize,
57+
locked_by=entity_response.entity.lockedBy.value,
58+
includes_state=includes_state,
59+
state=entity_state
60+
)
61+
62+
@overload
63+
def get_state(self, intended_type: Type[TState]) -> Optional[TState]:
64+
...
65+
66+
@overload
67+
def get_state(self, intended_type: None = None) -> Any:
68+
...
69+
70+
def get_state(self, intended_type: Optional[Type[TState]] = None) -> Union[None, TState, Any]:
71+
"""Get the current state of the entity, optionally converting it to a specified type."""
72+
if intended_type is None or self._state is None:
73+
return self._state
74+
75+
if isinstance(self._state, intended_type):
76+
return self._state
77+
78+
try:
79+
return intended_type(self._state) # type: ignore[call-arg]
80+
except Exception as ex:
81+
raise TypeError(
82+
f"Could not convert state of type '{type(self._state).__name__}' to '{intended_type.__name__}'"
83+
) from ex
84+
85+
def get_locked_by(self) -> Optional[EntityInstanceId]:
86+
"""Get the identifier of the worker that currently holds the lock on the entity.
87+
88+
Returns
89+
-------
90+
str
91+
The identifier of the worker that currently holds the lock on the entity.
92+
"""
93+
if not self._locked_by:
94+
return None
95+
96+
# Will throw ValueError if the format is invalid
97+
return EntityInstanceId.parse(self._locked_by)

tests/durabletask-azuremanaged/test_dts_class_based_entities_e2e.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from datetime import datetime, timezone
12
import os
23
import time
34

@@ -39,6 +40,36 @@ def do_nothing(self, _):
3940
assert invoked
4041

4142

43+
def test_client_get_class_entity():
44+
invoked = False
45+
46+
class EmptyEntity(entities.DurableEntity):
47+
def do_nothing(self, _):
48+
self.set_state(1)
49+
nonlocal invoked # don't do this in a real app!
50+
invoked = True
51+
52+
# Start a worker, which will connect to the sidecar in a background thread
53+
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
54+
taskhub=taskhub_name, token_credential=None) as w:
55+
w.add_entity(EmptyEntity)
56+
w.start()
57+
58+
c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
59+
taskhub=taskhub_name, token_credential=None)
60+
entity_id = entities.EntityInstanceId("EmptyEntity", "testEntity")
61+
c.signal_entity(entity_id, "do_nothing")
62+
time.sleep(2) # wait for the signal to be processed
63+
state = c.get_entity(entity_id)
64+
assert state is not None
65+
assert state.id == entity_id
66+
assert state.get_locked_by() is None
67+
assert state.last_modified < datetime.now(timezone.utc)
68+
assert state.get_state(int) == 1
69+
70+
assert invoked
71+
72+
4273
def test_orchestration_signal_class_entity():
4374
invoked = False
4475

tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from datetime import datetime, timezone
12
import os
23
import time
34

@@ -39,6 +40,36 @@ def empty_entity(ctx: entities.EntityContext, _):
3940
assert invoked
4041

4142

43+
def test_client_get_entity():
44+
invoked = False
45+
46+
def empty_entity(ctx: entities.EntityContext, _):
47+
nonlocal invoked # don't do this in a real app!
48+
if ctx.operation == "do_nothing":
49+
invoked = True
50+
ctx.set_state(1)
51+
52+
# Start a worker, which will connect to the sidecar in a background thread
53+
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
54+
taskhub=taskhub_name, token_credential=None) as w:
55+
w.add_entity(empty_entity)
56+
w.start()
57+
58+
c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
59+
taskhub=taskhub_name, token_credential=None)
60+
entity_id = entities.EntityInstanceId("empty_entity", "testEntity")
61+
c.signal_entity(entity_id, "do_nothing")
62+
time.sleep(2) # wait for the signal to be processed
63+
state = c.get_entity(entity_id)
64+
assert state is not None
65+
assert state.id == entity_id
66+
assert state.get_locked_by() is None
67+
assert state.last_modified < datetime.now(timezone.utc)
68+
assert state.get_state(int) == 1
69+
70+
assert invoked
71+
72+
4273
def test_orchestration_signal_entity():
4374
invoked = False
4475

0 commit comments

Comments
 (0)