From 07d48760324882dcbf27658762e118d5175ac20a Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Fri, 5 Dec 2025 11:37:24 -0700 Subject: [PATCH 1/2] Add support for new event types --- durabletask/worker.py | 103 ++++++++++++++++++++++++++++++++---------- 1 file changed, 80 insertions(+), 23 deletions(-) diff --git a/durabletask/worker.py b/durabletask/worker.py index fae345c..d0414c0 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -828,6 +828,7 @@ def __init__(self, instance_id: str, registry: _Registry): self._pending_tasks: dict[int, task.CompletableTask] = {} # Maps entity ID to task ID self._entity_task_id_map: dict[str, tuple[EntityInstanceId, int]] = {} + self._entity_lock_task_id_map: dict[str, tuple[EntityInstanceId, int]] = {} # Maps criticalSectionId to task ID self._entity_lock_id_map: dict[str, int] = {} self._sequence_number = 0 @@ -1590,33 +1591,70 @@ def process_event( else: raise TypeError("Unexpected sub-orchestration task type") elif event.HasField("eventRaised"): - # event names are case-insensitive - event_name = event.eventRaised.name.casefold() - if not ctx.is_replaying: - self._logger.info(f"{ctx.instance_id} Event raised: {event_name}") - task_list = ctx._pending_events.get(event_name, None) - decoded_result: Optional[Any] = None - if task_list: - event_task = task_list.pop(0) + if event.eventRaised.name in ctx._entity_task_id_map: + # This eventRaised represents the result of an entity operation after being translated to the old + # entity protocol by the Durable WebJobs extension + entity_id, task_id = ctx._entity_task_id_map.get(event.eventRaised.name, (None, None)) + if entity_id is None: + raise RuntimeError(f"Could not retrieve entity ID for entity-related eventRaised with ID '{event.eventId}'") + if task_id is None: + raise RuntimeError(f"Could not retrieve task ID for entity-related eventRaised with ID '{event.eventId}'") + entity_task = ctx._pending_tasks.pop(task_id, None) + if not entity_task: + raise RuntimeError(f"Could not retrieve entity task for entity-related eventRaised with ID '{event.eventId}'") + result = None if not ph.is_empty(event.eventRaised.input): - decoded_result = shared.from_json(event.eventRaised.input.value) - event_task.complete(decoded_result) - if not task_list: - del ctx._pending_events[event_name] + # TODO: Investigate why the event result is wrapped in a dict with "result" key + result = shared.from_json(event.eventRaised.input.value)["result"] + ctx._entity_context.recover_lock_after_call(entity_id) + entity_task.complete(result) ctx.resume() - else: - # buffer the event - event_list = ctx._received_events.get(event_name, None) - if not event_list: - event_list = [] - ctx._received_events[event_name] = event_list + elif event.eventRaised.name in ctx._entity_lock_task_id_map: + # This eventRaised represents the result of an entity operation after being translated to the old + # entity protocol by the Durable WebJobs extension + entity_id, task_id = ctx._entity_lock_task_id_map.get(event.eventRaised.name, (None, None)) + if entity_id is None: + raise RuntimeError(f"Could not retrieve entity ID for entity-related eventRaised with ID '{event.eventId}'") + if task_id is None: + raise RuntimeError(f"Could not retrieve task ID for entity-related eventRaised with ID '{event.eventId}'") + entity_task = ctx._pending_tasks.pop(task_id, None) + if not entity_task: + raise RuntimeError(f"Could not retrieve entity task for entity-related eventRaised with ID '{event.eventId}'") + result = None if not ph.is_empty(event.eventRaised.input): - decoded_result = shared.from_json(event.eventRaised.input.value) - event_list.append(decoded_result) + # TODO: Investigate why the event result is wrapped in a dict with "result" key + result = shared.from_json(event.eventRaised.input.value)["result"] + ctx._entity_context.complete_acquire(event.eventRaised.name) + entity_task.complete(EntityLock(ctx)) + ctx.resume() + else: + # event names are case-insensitive + event_name = event.eventRaised.name.casefold() if not ctx.is_replaying: - self._logger.info( - f"{ctx.instance_id}: Event '{event_name}' has been buffered as there are no tasks waiting for it." - ) + self._logger.info(f"{ctx.instance_id} Event raised: {event_name}") + task_list = ctx._pending_events.get(event_name, None) + decoded_result: Optional[Any] = None + if task_list: + event_task = task_list.pop(0) + if not ph.is_empty(event.eventRaised.input): + decoded_result = shared.from_json(event.eventRaised.input.value) + event_task.complete(decoded_result) + if not task_list: + del ctx._pending_events[event_name] + ctx.resume() + else: + # buffer the event + event_list = ctx._received_events.get(event_name, None) + if not event_list: + event_list = [] + ctx._received_events[event_name] = event_list + if not ph.is_empty(event.eventRaised.input): + decoded_result = shared.from_json(event.eventRaised.input.value) + event_list.append(decoded_result) + if not ctx.is_replaying: + self._logger.info( + f"{ctx.instance_id}: Event '{event_name}' has been buffered as there are no tasks waiting for it." + ) elif event.HasField("executionSuspended"): if not self._is_suspended and not ctx.is_replaying: self._logger.info(f"{ctx.instance_id}: Execution suspended.") @@ -1743,6 +1781,25 @@ def process_event( self._logger.info(f"{ctx.instance_id}: Entity operation failed.") self._logger.info(f"Data: {json.dumps(event.entityOperationFailed)}") pass + elif event.HasField("orchestratorCompleted"): + # Added in Functions only (for some reason) and does not affect orchestrator flow + pass + elif event.HasField("eventSent"): + # Check if this eventSent corresponds to an entity operation call after being translated to the old + # entity protocol by the Durable WebJobs extension. If so, treat this message similarly to + # entityOperationCalled and remove the pending action. Also store the entity id and event id for later + action = ctx._pending_actions.pop(event.eventId, None) + if action and action.HasField("sendEntityMessage"): + if action.sendEntityMessage.HasField("entityOperationCalled"): + entity_id = EntityInstanceId.parse(event.eventSent.instanceId) + event_id = json.loads(event.eventSent.input.value)["id"] + ctx._entity_task_id_map[event_id] = (entity_id, event.eventId) + elif action.sendEntityMessage.HasField("entityLockRequested"): + entity_id = EntityInstanceId.parse(event.eventSent.instanceId) + event_id = json.loads(event.eventSent.input.value)["id"] + ctx._entity_lock_task_id_map[event_id] = (entity_id, event.eventId) + else: + return else: eventType = event.WhichOneof("eventType") raise task.OrchestrationStateError( From 0647489fa3c838721e6f92533ea05a78763c60cf Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Fri, 5 Dec 2025 12:29:21 -0700 Subject: [PATCH 2/2] Extract shared code, add additional conditionals --- durabletask/worker.py | 82 +++++++++++++++++++++++-------------------- 1 file changed, 43 insertions(+), 39 deletions(-) diff --git a/durabletask/worker.py b/durabletask/worker.py index d0414c0..0ce89a4 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -12,7 +12,7 @@ from threading import Event, Thread from types import GeneratorType from enum import Enum -from typing import Any, Generator, Optional, Sequence, TypeVar, Union +from typing import Any, Generator, Optional, Sequence, Tuple, TypeVar, Union from packaging.version import InvalidVersion, parse import grpc @@ -1592,41 +1592,11 @@ def process_event( raise TypeError("Unexpected sub-orchestration task type") elif event.HasField("eventRaised"): if event.eventRaised.name in ctx._entity_task_id_map: - # This eventRaised represents the result of an entity operation after being translated to the old - # entity protocol by the Durable WebJobs extension entity_id, task_id = ctx._entity_task_id_map.get(event.eventRaised.name, (None, None)) - if entity_id is None: - raise RuntimeError(f"Could not retrieve entity ID for entity-related eventRaised with ID '{event.eventId}'") - if task_id is None: - raise RuntimeError(f"Could not retrieve task ID for entity-related eventRaised with ID '{event.eventId}'") - entity_task = ctx._pending_tasks.pop(task_id, None) - if not entity_task: - raise RuntimeError(f"Could not retrieve entity task for entity-related eventRaised with ID '{event.eventId}'") - result = None - if not ph.is_empty(event.eventRaised.input): - # TODO: Investigate why the event result is wrapped in a dict with "result" key - result = shared.from_json(event.eventRaised.input.value)["result"] - ctx._entity_context.recover_lock_after_call(entity_id) - entity_task.complete(result) - ctx.resume() + self._handle_entity_event_raised(ctx, event, entity_id, task_id, False) elif event.eventRaised.name in ctx._entity_lock_task_id_map: - # This eventRaised represents the result of an entity operation after being translated to the old - # entity protocol by the Durable WebJobs extension entity_id, task_id = ctx._entity_lock_task_id_map.get(event.eventRaised.name, (None, None)) - if entity_id is None: - raise RuntimeError(f"Could not retrieve entity ID for entity-related eventRaised with ID '{event.eventId}'") - if task_id is None: - raise RuntimeError(f"Could not retrieve task ID for entity-related eventRaised with ID '{event.eventId}'") - entity_task = ctx._pending_tasks.pop(task_id, None) - if not entity_task: - raise RuntimeError(f"Could not retrieve entity task for entity-related eventRaised with ID '{event.eventId}'") - result = None - if not ph.is_empty(event.eventRaised.input): - # TODO: Investigate why the event result is wrapped in a dict with "result" key - result = shared.from_json(event.eventRaised.input.value)["result"] - ctx._entity_context.complete_acquire(event.eventRaised.name) - entity_task.complete(EntityLock(ctx)) - ctx.resume() + self._handle_entity_event_raised(ctx, event, entity_id, task_id, True) else: # event names are case-insensitive event_name = event.eventRaised.name.casefold() @@ -1791,15 +1761,11 @@ def process_event( action = ctx._pending_actions.pop(event.eventId, None) if action and action.HasField("sendEntityMessage"): if action.sendEntityMessage.HasField("entityOperationCalled"): - entity_id = EntityInstanceId.parse(event.eventSent.instanceId) - event_id = json.loads(event.eventSent.input.value)["id"] + entity_id, event_id = self._parse_entity_event_sent_input(event) ctx._entity_task_id_map[event_id] = (entity_id, event.eventId) elif action.sendEntityMessage.HasField("entityLockRequested"): - entity_id = EntityInstanceId.parse(event.eventSent.instanceId) - event_id = json.loads(event.eventSent.input.value)["id"] + entity_id, event_id = self._parse_entity_event_sent_input(event) ctx._entity_lock_task_id_map[event_id] = (entity_id, event.eventId) - else: - return else: eventType = event.WhichOneof("eventType") raise task.OrchestrationStateError( @@ -1809,6 +1775,44 @@ def process_event( # The orchestrator generator function completed ctx.set_complete(generatorStopped.value, pb.ORCHESTRATION_STATUS_COMPLETED) + def _parse_entity_event_sent_input(self, event: pb.HistoryEvent) -> Tuple[EntityInstanceId, str]: + try: + entity_id = EntityInstanceId.parse(event.eventSent.instanceId) + except ValueError: + raise RuntimeError(f"Could not parse entity ID from instanceId '{event.eventSent.instanceId}'") + try: + event_id = json.loads(event.eventSent.input.value)["id"] + except (json.JSONDecodeError, KeyError, TypeError) as ex: + raise RuntimeError(f"Could not parse event ID from eventSent input '{event.eventSent.input.value}'") from ex + return entity_id, event_id + + def _handle_entity_event_raised(self, + ctx: _RuntimeOrchestrationContext, + event: pb.HistoryEvent, + entity_id: Optional[EntityInstanceId], + task_id: Optional[int], + is_lock_event: bool): + # This eventRaised represents the result of an entity operation after being translated to the old + # entity protocol by the Durable WebJobs extension + if entity_id is None: + raise RuntimeError(f"Could not retrieve entity ID for entity-related eventRaised with ID '{event.eventId}'") + if task_id is None: + raise RuntimeError(f"Could not retrieve task ID for entity-related eventRaised with ID '{event.eventId}'") + entity_task = ctx._pending_tasks.pop(task_id, None) + if not entity_task: + raise RuntimeError(f"Could not retrieve entity task for entity-related eventRaised with ID '{event.eventId}'") + result = None + if not ph.is_empty(event.eventRaised.input): + # TODO: Investigate why the event result is wrapped in a dict with "result" key + result = shared.from_json(event.eventRaised.input.value)["result"] + if is_lock_event: + ctx._entity_context.complete_acquire(event.eventRaised.name) + entity_task.complete(EntityLock(ctx)) + else: + ctx._entity_context.recover_lock_after_call(entity_id) + entity_task.complete(result) + ctx.resume() + def evaluate_orchestration_versioning(self, versioning: Optional[VersioningOptions], orchestration_version: Optional[str]) -> Optional[pb.TaskFailureDetails]: if versioning is None: return None