Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions durabletask/entities/entity_instance_id.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
from typing import Optional


class EntityInstanceId:
def __init__(self, entity: str, key: str):
self.entity = entity
Expand All @@ -20,7 +17,7 @@ def __lt__(self, other):
return str(self) < str(other)

@staticmethod
def parse(entity_id: str) -> Optional["EntityInstanceId"]:
def parse(entity_id: str) -> "EntityInstanceId":
"""Parse a string representation of an entity ID into an EntityInstanceId object.

Parameters
Expand All @@ -30,8 +27,13 @@ def parse(entity_id: str) -> Optional["EntityInstanceId"]:

Returns
-------
Optional[EntityInstanceId]
The parsed EntityInstanceId object, or None if the input is None.
EntityInstanceId
The parsed EntityInstanceId object.

Raises
------
ValueError
If the input string is not in the correct format.
"""
try:
_, entity, key = entity_id.split("@", 2)
Expand Down
5 changes: 3 additions & 2 deletions durabletask/entities/entity_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ def __init__(self,

@staticmethod
def from_entity_response(entity_response: pb.GetEntityResponse, includes_state: bool):
entity_id = EntityInstanceId.parse(entity_response.entity.instanceId)
if not entity_id:
try:
entity_id = EntityInstanceId.parse(entity_response.entity.instanceId)
except ValueError:
raise ValueError("Invalid entity instance ID in entity response.")
entity_state = None
if includes_state:
Expand Down
12 changes: 7 additions & 5 deletions durabletask/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,9 +750,10 @@ def _execute_entity_batch(
for operation in req.operations:
start_time = datetime.now(timezone.utc)
executor = _EntityExecutor(self._registry, self._logger)
entity_instance_id = EntityInstanceId.parse(instance_id)
if not entity_instance_id:
raise RuntimeError(f"Invalid entity instance ID '{operation.requestId}' in entity operation request.")
try:
entity_instance_id = EntityInstanceId.parse(instance_id)
except ValueError:
raise RuntimeError(f"Invalid entity instance ID '{instance_id}' in entity operation request.")

operation_result = None

Expand Down Expand Up @@ -1656,8 +1657,9 @@ def process_event(
raise _get_wrong_action_type_error(
entity_call_id, expected_method_name, action
)
entity_id = EntityInstanceId.parse(event.entityOperationCalled.targetInstanceId.value)
if not entity_id:
try:
entity_id = EntityInstanceId.parse(event.entityOperationCalled.targetInstanceId.value)
except ValueError:
raise RuntimeError(f"Could not parse entity ID from targetInstanceId '{event.entityOperationCalled.targetInstanceId.value}'")
ctx._entity_task_id_map[event.entityOperationCalled.requestId] = (entity_id, entity_call_id)
elif event.HasField("entityOperationSignaled"):
Expand Down
Loading