diff --git a/durabletask/entities/entity_instance_id.py b/durabletask/entities/entity_instance_id.py index c3b76c1..02a2595 100644 --- a/durabletask/entities/entity_instance_id.py +++ b/durabletask/entities/entity_instance_id.py @@ -1,6 +1,3 @@ -from typing import Optional - - class EntityInstanceId: def __init__(self, entity: str, key: str): self.entity = entity @@ -20,7 +17,7 @@ def __lt__(self, other): return str(self) < str(other) @staticmethod - def parse(entity_id: str) -> Optional["EntityInstanceId"]: + def parse(entity_id: str) -> "EntityInstanceId": """Parse a string representation of an entity ID into an EntityInstanceId object. Parameters @@ -30,8 +27,13 @@ def parse(entity_id: str) -> Optional["EntityInstanceId"]: Returns ------- - Optional[EntityInstanceId] - The parsed EntityInstanceId object, or None if the input is None. + EntityInstanceId + The parsed EntityInstanceId object. + + Raises + ------ + ValueError + If the input string is not in the correct format. """ try: _, entity, key = entity_id.split("@", 2) diff --git a/durabletask/entities/entity_metadata.py b/durabletask/entities/entity_metadata.py index 6800939..3e04206 100644 --- a/durabletask/entities/entity_metadata.py +++ b/durabletask/entities/entity_metadata.py @@ -44,8 +44,9 @@ def __init__(self, @staticmethod def from_entity_response(entity_response: pb.GetEntityResponse, includes_state: bool): - entity_id = EntityInstanceId.parse(entity_response.entity.instanceId) - if not entity_id: + try: + entity_id = EntityInstanceId.parse(entity_response.entity.instanceId) + except ValueError: raise ValueError("Invalid entity instance ID in entity response.") entity_state = None if includes_state: diff --git a/durabletask/worker.py b/durabletask/worker.py index fae345c..0c05430 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -750,9 +750,10 @@ def _execute_entity_batch( for operation in req.operations: start_time = datetime.now(timezone.utc) executor = _EntityExecutor(self._registry, self._logger) - entity_instance_id = EntityInstanceId.parse(instance_id) - if not entity_instance_id: - raise RuntimeError(f"Invalid entity instance ID '{operation.requestId}' in entity operation request.") + try: + entity_instance_id = EntityInstanceId.parse(instance_id) + except ValueError: + raise RuntimeError(f"Invalid entity instance ID '{instance_id}' in entity operation request.") operation_result = None @@ -1656,8 +1657,9 @@ def process_event( raise _get_wrong_action_type_error( entity_call_id, expected_method_name, action ) - entity_id = EntityInstanceId.parse(event.entityOperationCalled.targetInstanceId.value) - if not entity_id: + try: + entity_id = EntityInstanceId.parse(event.entityOperationCalled.targetInstanceId.value) + except ValueError: raise RuntimeError(f"Could not parse entity ID from targetInstanceId '{event.entityOperationCalled.targetInstanceId.value}'") ctx._entity_task_id_map[event.entityOperationCalled.requestId] = (entity_id, entity_call_id) elif event.HasField("entityOperationSignaled"):