Skip to content

Commit 9965ba4

Browse files
committed
Finish entity support
1 parent 18145f8 commit 9965ba4

File tree

4 files changed

+54
-31
lines changed

4 files changed

+54
-31
lines changed

durabletask-azurefunctions/durabletask/azurefunctions/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def _get_client_response_links(self, request: func.HttpRequest, instance_id: str
8383
@staticmethod
8484
def _get_instance_status_url(request: func.HttpRequest, instance_id: str) -> str:
8585
request_url = urlparse(request.url)
86-
location_url = f"{request_url.scheme}://{request_url.netloc}{request_url.path}"
86+
location_url = f"{request_url.scheme}://{request_url.netloc}"
8787
encoded_instance_id = quote(instance_id)
8888
location_url = location_url + "/runtime/webhooks/durabletask/instances/" + encoded_instance_id
8989
return location_url

durabletask/entities/entity_instance_id.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def __lt__(self, other):
2020
return str(self) < str(other)
2121

2222
@staticmethod
23-
def parse(entity_id: str) -> Optional["EntityInstanceId"]:
23+
def parse(entity_id: str) -> "EntityInstanceId":
2424
"""Parse a string representation of an entity ID into an EntityInstanceId object.
2525
2626
Parameters

durabletask/worker.py

Lines changed: 51 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1593,33 +1593,52 @@ def process_event(
15931593
else:
15941594
raise TypeError("Unexpected sub-orchestration task type")
15951595
elif event.HasField("eventRaised"):
1596-
# event names are case-insensitive
1597-
event_name = event.eventRaised.name.casefold()
1598-
if not ctx.is_replaying:
1599-
self._logger.info(f"{ctx.instance_id} Event raised: {event_name}")
1600-
task_list = ctx._pending_events.get(event_name, None)
1601-
decoded_result: Optional[Any] = None
1602-
if task_list:
1603-
event_task = task_list.pop(0)
1596+
if event.eventRaised.name in ctx._entity_task_id_map:
1597+
# This eventRaised represents the result of an entity operation after being translated to the old
1598+
# entity protocol by the Durable WebJobs extension
1599+
entity_id, task_id = ctx._entity_task_id_map.get(event.eventRaised.name, (None, None))
1600+
if entity_id is None:
1601+
raise RuntimeError(f"Could not retrieve entity ID for entity-related eventRaised with ID '{event.eventId}'")
1602+
if task_id is None:
1603+
raise RuntimeError(f"Could not retrieve task ID for entity-related eventRaised with ID '{event.eventId}'")
1604+
entity_task = ctx._pending_tasks.pop(task_id, None)
1605+
if not entity_task:
1606+
raise RuntimeError(f"Could not retrieve entity task for entity-related eventRaised with ID '{event.eventId}'")
1607+
result = None
16041608
if not ph.is_empty(event.eventRaised.input):
1605-
decoded_result = shared.from_json(event.eventRaised.input.value)
1606-
event_task.complete(decoded_result)
1607-
if not task_list:
1608-
del ctx._pending_events[event_name]
1609+
# TODO: Investigate why the event result is wrapped in a dict with "result" key
1610+
result = shared.from_json(event.eventRaised.input.value)["result"]
1611+
ctx._entity_context.recover_lock_after_call(entity_id)
1612+
entity_task.complete(result)
16091613
ctx.resume()
16101614
else:
1611-
# buffer the event
1612-
event_list = ctx._received_events.get(event_name, None)
1613-
if not event_list:
1614-
event_list = []
1615-
ctx._received_events[event_name] = event_list
1616-
if not ph.is_empty(event.eventRaised.input):
1617-
decoded_result = shared.from_json(event.eventRaised.input.value)
1618-
event_list.append(decoded_result)
1615+
# event names are case-insensitive
1616+
event_name = event.eventRaised.name.casefold()
16191617
if not ctx.is_replaying:
1620-
self._logger.info(
1621-
f"{ctx.instance_id}: Event '{event_name}' has been buffered as there are no tasks waiting for it."
1622-
)
1618+
self._logger.info(f"{ctx.instance_id} Event raised: {event_name}")
1619+
task_list = ctx._pending_events.get(event_name, None)
1620+
decoded_result: Optional[Any] = None
1621+
if task_list:
1622+
event_task = task_list.pop(0)
1623+
if not ph.is_empty(event.eventRaised.input):
1624+
decoded_result = shared.from_json(event.eventRaised.input.value)
1625+
event_task.complete(decoded_result)
1626+
if not task_list:
1627+
del ctx._pending_events[event_name]
1628+
ctx.resume()
1629+
else:
1630+
# buffer the event
1631+
event_list = ctx._received_events.get(event_name, None)
1632+
if not event_list:
1633+
event_list = []
1634+
ctx._received_events[event_name] = event_list
1635+
if not ph.is_empty(event.eventRaised.input):
1636+
decoded_result = shared.from_json(event.eventRaised.input.value)
1637+
event_list.append(decoded_result)
1638+
if not ctx.is_replaying:
1639+
self._logger.info(
1640+
f"{ctx.instance_id}: Event '{event_name}' has been buffered as there are no tasks waiting for it."
1641+
)
16231642
elif event.HasField("executionSuspended"):
16241643
if not self._is_suspended and not ctx.is_replaying:
16251644
self._logger.info(f"{ctx.instance_id}: Execution suspended.")
@@ -1750,11 +1769,15 @@ def process_event(
17501769
# Added in Functions only (for some reason) and does not affect orchestrator flow
17511770
pass
17521771
elif event.HasField("eventSent"):
1753-
# Added in Functions only (for some reason) and does not affect orchestrator flow
1754-
pass
1755-
elif event.HasField("eventRaised"):
1756-
# Added in Functions only (for some reason) and does not affect orchestrator flow
1757-
pass
1772+
# Check if this eventSent corresponds to an entity operation call after being translated to the old
1773+
# entity protocol by the Durable WebJobs extension. If so, treat this message similarly to
1774+
# entityOperationCalled and remove the pending action. Also store the entity id and event id for later
1775+
action = ctx._pending_actions.pop(event.eventId, None)
1776+
if action and action.HasField("sendEntityMessage") and action.sendEntityMessage.HasField("entityOperationCalled"):
1777+
entity_id = EntityInstanceId.parse(event.eventSent.instanceId)
1778+
event_id = json.loads(event.eventSent.input.value)["id"]
1779+
ctx._entity_task_id_map[event_id] = (entity_id, event.eventId)
1780+
return
17581781
else:
17591782
eventType = event.WhichOneof("eventType")
17601783
raise task.OrchestrationStateError(

examples/entities/function_based_entity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
def counter(ctx: entities.EntityContext, input: int) -> Optional[int]:
1414
if ctx.operation == "set":
1515
ctx.set_state(input)
16-
if ctx.operation == "add":
16+
elif ctx.operation == "add":
1717
current_state = ctx.get_state(int, 0)
1818
new_state = current_state + (input or 1)
1919
ctx.set_state(new_state)

0 commit comments

Comments
 (0)