diff --git a/durabletask/client.py b/durabletask/client.py index c150822..7d03758 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -12,6 +12,7 @@ from google.protobuf import wrappers_pb2 from durabletask.entities import EntityInstanceId +from durabletask.entities.entity_metadata import EntityMetadata import durabletask.internal.helpers as helpers import durabletask.internal.orchestrator_service_pb2 as pb import durabletask.internal.orchestrator_service_pb2_grpc as stubs @@ -241,3 +242,15 @@ def signal_entity(self, entity_instance_id: EntityInstanceId, operation_name: st ) self._logger.info(f"Signaling entity '{entity_instance_id}' operation '{operation_name}'.") self._stub.SignalEntity(req, None) # TODO: Cancellation timeout? + + def get_entity(self, + entity_instance_id: EntityInstanceId, + include_state: bool = True + ) -> Optional[EntityMetadata]: + req = pb.GetEntityRequest(instanceId=str(entity_instance_id), includeState=include_state) + self._logger.info(f"Getting entity '{entity_instance_id}'.") + res: pb.GetEntityResponse = self._stub.GetEntity(req) + if not res.exists: + return None + + return EntityMetadata.from_entity_response(res, include_state) diff --git a/durabletask/entities/__init__.py b/durabletask/entities/__init__.py index 4ab03c0..46f059b 100644 --- a/durabletask/entities/__init__.py +++ b/durabletask/entities/__init__.py @@ -7,7 +7,8 @@ from durabletask.entities.durable_entity import DurableEntity from durabletask.entities.entity_lock import EntityLock from durabletask.entities.entity_context import EntityContext +from durabletask.entities.entity_metadata import EntityMetadata -__all__ = ["EntityInstanceId", "DurableEntity", "EntityLock", "EntityContext"] +__all__ = ["EntityInstanceId", "DurableEntity", "EntityLock", "EntityContext", "EntityMetadata"] PACKAGE_NAME = "durabletask.entities" diff --git a/durabletask/entities/entity_instance_id.py b/durabletask/entities/entity_instance_id.py index 53c1171..c3b76c1 100644 --- a/durabletask/entities/entity_instance_id.py +++ b/durabletask/entities/entity_instance_id.py @@ -37,4 +37,4 @@ def parse(entity_id: str) -> Optional["EntityInstanceId"]: _, entity, key = entity_id.split("@", 2) return EntityInstanceId(entity=entity, key=key) except ValueError as ex: - raise ValueError("Invalid entity ID format", ex) + raise ValueError(f"Invalid entity ID format: {entity_id}", ex) diff --git a/durabletask/entities/entity_metadata.py b/durabletask/entities/entity_metadata.py new file mode 100644 index 0000000..6800939 --- /dev/null +++ b/durabletask/entities/entity_metadata.py @@ -0,0 +1,97 @@ +from datetime import datetime, timezone +from typing import Any, Optional, Type, TypeVar, Union, overload +from durabletask.entities.entity_instance_id import EntityInstanceId + +import durabletask.internal.orchestrator_service_pb2 as pb + +TState = TypeVar("TState") + + +class EntityMetadata: + """Class representing the metadata of a durable entity. + + This class encapsulates the metadata information of a durable entity, allowing for + easy access and manipulation of the entity's metadata within the Durable Task + Framework. + + Attributes: + id (EntityInstanceId): The unique identifier of the entity instance. + last_modified (datetime): The timestamp of the last modification to the entity. + backlog_queue_size (int): The size of the backlog queue for the entity. + locked_by (str): The identifier of the worker that currently holds the lock on the entity. + includes_state (bool): Indicates whether the metadata includes the state of the entity. + state (Optional[Any]): The current state of the entity, if included. + """ + + def __init__(self, + id: EntityInstanceId, + last_modified: datetime, + backlog_queue_size: int, + locked_by: str, + includes_state: bool, + state: Optional[Any]): + """Initializes a new instance of the EntityMetadata class. + + Args: + value: The initial state value of the entity. + """ + self.id = id + self.last_modified = last_modified + self.backlog_queue_size = backlog_queue_size + self._locked_by = locked_by + self.includes_state = includes_state + self._state = state + + @staticmethod + def from_entity_response(entity_response: pb.GetEntityResponse, includes_state: bool): + entity_id = EntityInstanceId.parse(entity_response.entity.instanceId) + if not entity_id: + raise ValueError("Invalid entity instance ID in entity response.") + entity_state = None + if includes_state: + entity_state = entity_response.entity.serializedState.value + return EntityMetadata( + id=entity_id, + last_modified=entity_response.entity.lastModifiedTime.ToDatetime(timezone.utc), + backlog_queue_size=entity_response.entity.backlogQueueSize, + locked_by=entity_response.entity.lockedBy.value, + includes_state=includes_state, + state=entity_state + ) + + @overload + def get_state(self, intended_type: Type[TState]) -> Optional[TState]: + ... + + @overload + def get_state(self, intended_type: None = None) -> Any: + ... + + def get_state(self, intended_type: Optional[Type[TState]] = None) -> Union[None, TState, Any]: + """Get the current state of the entity, optionally converting it to a specified type.""" + if intended_type is None or self._state is None: + return self._state + + if isinstance(self._state, intended_type): + return self._state + + try: + return intended_type(self._state) # type: ignore[call-arg] + except Exception as ex: + raise TypeError( + f"Could not convert state of type '{type(self._state).__name__}' to '{intended_type.__name__}'" + ) from ex + + def get_locked_by(self) -> Optional[EntityInstanceId]: + """Get the identifier of the worker that currently holds the lock on the entity. + + Returns + ------- + str + The identifier of the worker that currently holds the lock on the entity. + """ + if not self._locked_by: + return None + + # Will throw ValueError if the format is invalid + return EntityInstanceId.parse(self._locked_by) diff --git a/tests/durabletask-azuremanaged/test_dts_class_based_entities_e2e.py b/tests/durabletask-azuremanaged/test_dts_class_based_entities_e2e.py index 19e8e5b..6075029 100644 --- a/tests/durabletask-azuremanaged/test_dts_class_based_entities_e2e.py +++ b/tests/durabletask-azuremanaged/test_dts_class_based_entities_e2e.py @@ -1,3 +1,4 @@ +from datetime import datetime, timezone import os import time @@ -39,6 +40,36 @@ def do_nothing(self, _): assert invoked +def test_client_get_class_entity(): + invoked = False + + class EmptyEntity(entities.DurableEntity): + def do_nothing(self, _): + self.set_state(1) + nonlocal invoked # don't do this in a real app! + invoked = True + + # Start a worker, which will connect to the sidecar in a background thread + with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) as w: + w.add_entity(EmptyEntity) + w.start() + + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) + entity_id = entities.EntityInstanceId("EmptyEntity", "testEntity") + c.signal_entity(entity_id, "do_nothing") + time.sleep(2) # wait for the signal to be processed + state = c.get_entity(entity_id) + assert state is not None + assert state.id == entity_id + assert state.get_locked_by() is None + assert state.last_modified < datetime.now(timezone.utc) + assert state.get_state(int) == 1 + + assert invoked + + def test_orchestration_signal_class_entity(): invoked = False diff --git a/tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e.py b/tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e.py index 4220655..6b857be 100644 --- a/tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e.py +++ b/tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e.py @@ -1,3 +1,4 @@ +from datetime import datetime, timezone import os import time @@ -39,6 +40,36 @@ def empty_entity(ctx: entities.EntityContext, _): assert invoked +def test_client_get_entity(): + invoked = False + + def empty_entity(ctx: entities.EntityContext, _): + nonlocal invoked # don't do this in a real app! + if ctx.operation == "do_nothing": + invoked = True + ctx.set_state(1) + + # Start a worker, which will connect to the sidecar in a background thread + with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) as w: + w.add_entity(empty_entity) + w.start() + + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=None) + entity_id = entities.EntityInstanceId("empty_entity", "testEntity") + c.signal_entity(entity_id, "do_nothing") + time.sleep(2) # wait for the signal to be processed + state = c.get_entity(entity_id) + assert state is not None + assert state.id == entity_id + assert state.get_locked_by() is None + assert state.last_modified < datetime.now(timezone.utc) + assert state.get_state(int) == 1 + + assert invoked + + def test_orchestration_signal_entity(): invoked = False