Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 87 additions & 26 deletions durabletask/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from threading import Event, Thread
from types import GeneratorType
from enum import Enum
from typing import Any, Generator, Optional, Sequence, TypeVar, Union
from typing import Any, Generator, Optional, Sequence, Tuple, TypeVar, Union
from packaging.version import InvalidVersion, parse

import grpc
Expand Down Expand Up @@ -828,6 +828,7 @@ def __init__(self, instance_id: str, registry: _Registry):
self._pending_tasks: dict[int, task.CompletableTask] = {}
# Maps entity ID to task ID
self._entity_task_id_map: dict[str, tuple[EntityInstanceId, int]] = {}
self._entity_lock_task_id_map: dict[str, tuple[EntityInstanceId, int]] = {}
# Maps criticalSectionId to task ID
self._entity_lock_id_map: dict[str, int] = {}
self._sequence_number = 0
Expand Down Expand Up @@ -1590,33 +1591,40 @@ def process_event(
else:
raise TypeError("Unexpected sub-orchestration task type")
elif event.HasField("eventRaised"):
# event names are case-insensitive
event_name = event.eventRaised.name.casefold()
if not ctx.is_replaying:
self._logger.info(f"{ctx.instance_id} Event raised: {event_name}")
task_list = ctx._pending_events.get(event_name, None)
decoded_result: Optional[Any] = None
if task_list:
event_task = task_list.pop(0)
if not ph.is_empty(event.eventRaised.input):
decoded_result = shared.from_json(event.eventRaised.input.value)
event_task.complete(decoded_result)
if not task_list:
del ctx._pending_events[event_name]
ctx.resume()
if event.eventRaised.name in ctx._entity_task_id_map:
entity_id, task_id = ctx._entity_task_id_map.get(event.eventRaised.name, (None, None))
self._handle_entity_event_raised(ctx, event, entity_id, task_id, False)
elif event.eventRaised.name in ctx._entity_lock_task_id_map:
entity_id, task_id = ctx._entity_lock_task_id_map.get(event.eventRaised.name, (None, None))
self._handle_entity_event_raised(ctx, event, entity_id, task_id, True)
else:
# buffer the event
event_list = ctx._received_events.get(event_name, None)
if not event_list:
event_list = []
ctx._received_events[event_name] = event_list
if not ph.is_empty(event.eventRaised.input):
decoded_result = shared.from_json(event.eventRaised.input.value)
event_list.append(decoded_result)
# event names are case-insensitive
event_name = event.eventRaised.name.casefold()
if not ctx.is_replaying:
self._logger.info(
f"{ctx.instance_id}: Event '{event_name}' has been buffered as there are no tasks waiting for it."
)
self._logger.info(f"{ctx.instance_id} Event raised: {event_name}")
task_list = ctx._pending_events.get(event_name, None)
decoded_result: Optional[Any] = None
if task_list:
event_task = task_list.pop(0)
if not ph.is_empty(event.eventRaised.input):
decoded_result = shared.from_json(event.eventRaised.input.value)
event_task.complete(decoded_result)
if not task_list:
del ctx._pending_events[event_name]
ctx.resume()
else:
# buffer the event
event_list = ctx._received_events.get(event_name, None)
if not event_list:
event_list = []
ctx._received_events[event_name] = event_list
if not ph.is_empty(event.eventRaised.input):
decoded_result = shared.from_json(event.eventRaised.input.value)
event_list.append(decoded_result)
if not ctx.is_replaying:
self._logger.info(
f"{ctx.instance_id}: Event '{event_name}' has been buffered as there are no tasks waiting for it."
)
elif event.HasField("executionSuspended"):
if not self._is_suspended and not ctx.is_replaying:
self._logger.info(f"{ctx.instance_id}: Execution suspended.")
Expand Down Expand Up @@ -1743,6 +1751,21 @@ def process_event(
self._logger.info(f"{ctx.instance_id}: Entity operation failed.")
self._logger.info(f"Data: {json.dumps(event.entityOperationFailed)}")
pass
elif event.HasField("orchestratorCompleted"):
# Added in Functions only (for some reason) and does not affect orchestrator flow
pass
elif event.HasField("eventSent"):
# Check if this eventSent corresponds to an entity operation call after being translated to the old
# entity protocol by the Durable WebJobs extension. If so, treat this message similarly to
# entityOperationCalled and remove the pending action. Also store the entity id and event id for later
action = ctx._pending_actions.pop(event.eventId, None)
if action and action.HasField("sendEntityMessage"):
if action.sendEntityMessage.HasField("entityOperationCalled"):
entity_id, event_id = self._parse_entity_event_sent_input(event)
ctx._entity_task_id_map[event_id] = (entity_id, event.eventId)
elif action.sendEntityMessage.HasField("entityLockRequested"):
entity_id, event_id = self._parse_entity_event_sent_input(event)
ctx._entity_lock_task_id_map[event_id] = (entity_id, event.eventId)
else:
eventType = event.WhichOneof("eventType")
raise task.OrchestrationStateError(
Expand All @@ -1752,6 +1775,44 @@ def process_event(
# The orchestrator generator function completed
ctx.set_complete(generatorStopped.value, pb.ORCHESTRATION_STATUS_COMPLETED)

def _parse_entity_event_sent_input(self, event: pb.HistoryEvent) -> Tuple[EntityInstanceId, str]:
try:
entity_id = EntityInstanceId.parse(event.eventSent.instanceId)
except ValueError:
raise RuntimeError(f"Could not parse entity ID from instanceId '{event.eventSent.instanceId}'")
try:
event_id = json.loads(event.eventSent.input.value)["id"]
except (json.JSONDecodeError, KeyError, TypeError) as ex:
raise RuntimeError(f"Could not parse event ID from eventSent input '{event.eventSent.input.value}'") from ex
return entity_id, event_id

def _handle_entity_event_raised(self,
ctx: _RuntimeOrchestrationContext,
event: pb.HistoryEvent,
entity_id: Optional[EntityInstanceId],
task_id: Optional[int],
is_lock_event: bool):
# This eventRaised represents the result of an entity operation after being translated to the old
# entity protocol by the Durable WebJobs extension
if entity_id is None:
raise RuntimeError(f"Could not retrieve entity ID for entity-related eventRaised with ID '{event.eventId}'")
if task_id is None:
raise RuntimeError(f"Could not retrieve task ID for entity-related eventRaised with ID '{event.eventId}'")
entity_task = ctx._pending_tasks.pop(task_id, None)
if not entity_task:
raise RuntimeError(f"Could not retrieve entity task for entity-related eventRaised with ID '{event.eventId}'")
result = None
if not ph.is_empty(event.eventRaised.input):
# TODO: Investigate why the event result is wrapped in a dict with "result" key
result = shared.from_json(event.eventRaised.input.value)["result"]
if is_lock_event:
ctx._entity_context.complete_acquire(event.eventRaised.name)
entity_task.complete(EntityLock(ctx))
else:
ctx._entity_context.recover_lock_after_call(entity_id)
entity_task.complete(result)
ctx.resume()

def evaluate_orchestration_versioning(self, versioning: Optional[VersioningOptions], orchestration_version: Optional[str]) -> Optional[pb.TaskFailureDetails]:
if versioning is None:
return None
Expand Down
Loading