Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions durabletask/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from google.protobuf import wrappers_pb2

from durabletask.entities import EntityInstanceId
from durabletask.entities.entity_metadata import EntityMetadata
import durabletask.internal.helpers as helpers
import durabletask.internal.orchestrator_service_pb2 as pb
import durabletask.internal.orchestrator_service_pb2_grpc as stubs
Expand Down Expand Up @@ -241,3 +242,15 @@ def signal_entity(self, entity_instance_id: EntityInstanceId, operation_name: st
)
self._logger.info(f"Signaling entity '{entity_instance_id}' operation '{operation_name}'.")
self._stub.SignalEntity(req, None) # TODO: Cancellation timeout?

def get_entity(self,
entity_instance_id: EntityInstanceId,
include_state: bool = True
) -> Optional[EntityMetadata]:
req = pb.GetEntityRequest(instanceId=str(entity_instance_id), includeState=include_state)
self._logger.info(f"Getting entity '{entity_instance_id}'.")
res: pb.GetEntityResponse = self._stub.GetEntity(req)
if not res.exists:
return None

return EntityMetadata.from_entity_response(res, include_state)
3 changes: 2 additions & 1 deletion durabletask/entities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from durabletask.entities.durable_entity import DurableEntity
from durabletask.entities.entity_lock import EntityLock
from durabletask.entities.entity_context import EntityContext
from durabletask.entities.entity_metadata import EntityMetadata

__all__ = ["EntityInstanceId", "DurableEntity", "EntityLock", "EntityContext"]
__all__ = ["EntityInstanceId", "DurableEntity", "EntityLock", "EntityContext", "EntityMetadata"]

PACKAGE_NAME = "durabletask.entities"
2 changes: 1 addition & 1 deletion durabletask/entities/entity_instance_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,4 @@ def parse(entity_id: str) -> Optional["EntityInstanceId"]:
_, entity, key = entity_id.split("@", 2)
return EntityInstanceId(entity=entity, key=key)
except ValueError as ex:
raise ValueError("Invalid entity ID format", ex)
raise ValueError(f"Invalid entity ID format: {entity_id}", ex)
97 changes: 97 additions & 0 deletions durabletask/entities/entity_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from datetime import datetime, timezone
from typing import Any, Optional, Type, TypeVar, Union, overload
from durabletask.entities.entity_instance_id import EntityInstanceId

import durabletask.internal.orchestrator_service_pb2 as pb

TState = TypeVar("TState")


class EntityMetadata:
"""Class representing the metadata of a durable entity.

This class encapsulates the metadata information of a durable entity, allowing for
easy access and manipulation of the entity's metadata within the Durable Task
Framework.

Attributes:
id (EntityInstanceId): The unique identifier of the entity instance.
last_modified (datetime): The timestamp of the last modification to the entity.
backlog_queue_size (int): The size of the backlog queue for the entity.
locked_by (str): The identifier of the worker that currently holds the lock on the entity.
includes_state (bool): Indicates whether the metadata includes the state of the entity.
state (Optional[Any]): The current state of the entity, if included.
"""

def __init__(self,
id: EntityInstanceId,
last_modified: datetime,
backlog_queue_size: int,
locked_by: str,
includes_state: bool,
state: Optional[Any]):
"""Initializes a new instance of the EntityMetadata class.

Args:
value: The initial state value of the entity.
"""
self.id = id
self.last_modified = last_modified
self.backlog_queue_size = backlog_queue_size
self._locked_by = locked_by
self.includes_state = includes_state
self._state = state

@staticmethod
def from_entity_response(entity_response: pb.GetEntityResponse, includes_state: bool):
entity_id = EntityInstanceId.parse(entity_response.entity.instanceId)
if not entity_id:
raise ValueError("Invalid entity instance ID in entity response.")
entity_state = None
if includes_state:
entity_state = entity_response.entity.serializedState.value
return EntityMetadata(
id=entity_id,
last_modified=entity_response.entity.lastModifiedTime.ToDatetime(timezone.utc),
backlog_queue_size=entity_response.entity.backlogQueueSize,
locked_by=entity_response.entity.lockedBy.value,
includes_state=includes_state,
state=entity_state
)

@overload
def get_state(self, intended_type: Type[TState]) -> Optional[TState]:
...

@overload
def get_state(self, intended_type: None = None) -> Any:
...

def get_state(self, intended_type: Optional[Type[TState]] = None) -> Union[None, TState, Any]:
"""Get the current state of the entity, optionally converting it to a specified type."""
if intended_type is None or self._state is None:
return self._state

if isinstance(self._state, intended_type):
return self._state

try:
return intended_type(self._state) # type: ignore[call-arg]
except Exception as ex:
raise TypeError(
f"Could not convert state of type '{type(self._state).__name__}' to '{intended_type.__name__}'"
) from ex

def get_locked_by(self) -> Optional[EntityInstanceId]:
"""Get the identifier of the worker that currently holds the lock on the entity.

Returns
-------
str
The identifier of the worker that currently holds the lock on the entity.
"""
if not self._locked_by:
return None

# Will throw ValueError if the format is invalid
return EntityInstanceId.parse(self._locked_by)
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from datetime import datetime, timezone
import os
import time

Expand Down Expand Up @@ -39,6 +40,36 @@ def do_nothing(self, _):
assert invoked


def test_client_get_class_entity():
invoked = False

class EmptyEntity(entities.DurableEntity):
def do_nothing(self, _):
self.set_state(1)
nonlocal invoked # don't do this in a real app!
invoked = True

# Start a worker, which will connect to the sidecar in a background thread
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
taskhub=taskhub_name, token_credential=None) as w:
w.add_entity(EmptyEntity)
w.start()

c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
taskhub=taskhub_name, token_credential=None)
entity_id = entities.EntityInstanceId("EmptyEntity", "testEntity")
c.signal_entity(entity_id, "do_nothing")
time.sleep(2) # wait for the signal to be processed
state = c.get_entity(entity_id)
assert state is not None
assert state.id == entity_id
assert state.get_locked_by() is None
assert state.last_modified < datetime.now(timezone.utc)
assert state.get_state(int) == 1

assert invoked


def test_orchestration_signal_class_entity():
invoked = False

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from datetime import datetime, timezone
import os
import time

Expand Down Expand Up @@ -39,6 +40,36 @@ def empty_entity(ctx: entities.EntityContext, _):
assert invoked


def test_client_get_entity():
invoked = False

def empty_entity(ctx: entities.EntityContext, _):
nonlocal invoked # don't do this in a real app!
if ctx.operation == "do_nothing":
invoked = True
ctx.set_state(1)

# Start a worker, which will connect to the sidecar in a background thread
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
taskhub=taskhub_name, token_credential=None) as w:
w.add_entity(empty_entity)
w.start()

c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
taskhub=taskhub_name, token_credential=None)
entity_id = entities.EntityInstanceId("empty_entity", "testEntity")
c.signal_entity(entity_id, "do_nothing")
time.sleep(2) # wait for the signal to be processed
state = c.get_entity(entity_id)
assert state is not None
assert state.id == entity_id
assert state.get_locked_by() is None
assert state.last_modified < datetime.now(timezone.utc)
assert state.get_state(int) == 1

assert invoked


def test_orchestration_signal_entity():
invoked = False

Expand Down
Loading