Skip to content

Commit 44f6a95

Browse files
authored
Allow try-catching Entity exceptions in orchestrators (#324)
1 parent 1aac4bc commit 44f6a95

File tree

5 files changed

+77
-11
lines changed

5 files changed

+77
-11
lines changed

azure/durable_functions/models/TaskOrchestrationExecutor.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,12 @@ def parse_history_event(directive_result):
180180
# retrieve result
181181
new_value = parse_history_event(event)
182182
if task._api_name == "CallEntityAction":
183-
new_value = ResponseMessage.from_dict(new_value)
184-
new_value = json.loads(new_value.result)
183+
event_payload = ResponseMessage.from_dict(new_value)
184+
new_value = json.loads(event_payload.result)
185+
186+
if event_payload.is_exception:
187+
new_value = Exception(new_value)
188+
is_success = False
185189
else:
186190
# generate exception
187191
new_value = Exception(f"{event.Reason} \n {event.Details}")

azure/durable_functions/models/entities/ResponseMessage.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ class ResponseMessage:
77
Specifies the response of an entity, as processed by the durable-extension.
88
"""
99

10-
def __init__(self, result: str):
10+
def __init__(self, result: str, is_exception: bool = False):
1111
"""Instantiate a ResponseMessage.
1212
1313
Specifies the response of an entity, as processed by the durable-extension.
@@ -18,6 +18,7 @@ def __init__(self, result: str):
1818
The result provided by the entity
1919
"""
2020
self.result = result
21+
self.is_exception = is_exception
2122
# TODO: JS has an additional exceptionType field, but does not use it
2223

2324
@classmethod
@@ -34,5 +35,6 @@ def from_dict(cls, d: Dict[str, Any]) -> 'ResponseMessage':
3435
ResponseMessage:
3536
The ResponseMessage built from the provided dictionary
3637
"""
37-
result = cls(d["result"])
38+
is_error = "exceptionType" in d.keys()
39+
result = cls(d["result"], is_error)
3840
return result

tests/orchestrator/orchestrator_test_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,9 @@ def assert_entity_state_equals(expected, result):
3030
assert_attribute_equal(expected, result, "signals")
3131

3232
def assert_results_are_equal(expected: Dict[str, Any], result: Dict[str, Any]) -> bool:
33-
assert_attribute_equal(expected, result, "result")
34-
assert_attribute_equal(expected, result, "isError")
33+
for (payload_expected, payload_result) in zip(expected, result):
34+
assert_attribute_equal(payload_expected, payload_result, "result")
35+
assert_attribute_equal(payload_expected, payload_result, "isError")
3536

3637
def assert_attribute_equal(expected, result, attribute):
3738
if attribute in expected:

tests/orchestrator/test_entity.py

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from azure.durable_functions.models.ReplaySchema import ReplaySchema
22
from .orchestrator_test_utils \
3-
import assert_orchestration_state_equals, get_orchestration_state_result, assert_valid_schema, \
3+
import assert_orchestration_state_equals, assert_results_are_equal, get_orchestration_state_result, assert_valid_schema, \
44
get_entity_state_result, assert_entity_state_equals
55
from tests.test_utils.ContextBuilder import ContextBuilder
66
from tests.test_utils.EntityContextBuilder import EntityContextBuilder
@@ -23,6 +23,14 @@ def generator_function_call_entity(context):
2323
outputs.append(x)
2424
return outputs
2525

26+
def generator_function_catch_entity_exception(context):
27+
entityId = df.EntityId("Counter", "myCounter")
28+
try:
29+
yield context.call_entity(entityId, "add", 3)
30+
return "No exception thrown"
31+
except:
32+
return "Exception thrown"
33+
2634
def generator_function_signal_entity(context):
2735
outputs = []
2836
entityId = df.EntityId("Counter", "myCounter")
@@ -53,6 +61,29 @@ def counter_entity_function(context):
5361
context.set_state(current_value)
5462
context.set_result(result)
5563

64+
def counter_entity_function_raises_exception(context):
65+
raise Exception("boom!")
66+
67+
def test_entity_raises_exception():
68+
# Create input batch
69+
batch = []
70+
add_to_batch(batch, name="get")
71+
context_builder = EntityContextBuilder(batch=batch)
72+
73+
# Run the entity, get observed result
74+
result = get_entity_state_result(
75+
context_builder,
76+
counter_entity_function_raises_exception,
77+
)
78+
79+
# Construct expected result
80+
expected_state = entity_base_expected_state()
81+
apply_operation(expected_state, result="boom!", state=None, is_error=True)
82+
expected = expected_state.to_json()
83+
84+
# Ensure expectation matches observed behavior
85+
#assert_valid_schema(result)
86+
assert_entity_state_equals(expected, result)
5687

5788
def test_entity_signal_then_call():
5889
"""Tests that a simple counter entity outputs the correct value
@@ -161,11 +192,11 @@ def add_signal_entity_action(state: OrchestratorState, id_: df.EntityId, op: str
161192
state.actions.append([action])
162193

163194
def add_call_entity_completed_events(
164-
context_builder: ContextBuilder, op: str, instance_id=str, input_=None, event_id=0):
195+
context_builder: ContextBuilder, op: str, instance_id=str, input_=None, event_id=0, is_error=False):
165196
context_builder.add_event_sent_event(instance_id, event_id)
166197
context_builder.add_orchestrator_completed_event()
167198
context_builder.add_orchestrator_started_event()
168-
context_builder.add_event_raised_event(name="0000", id_=0, input_=input_, is_entity=True)
199+
context_builder.add_event_raised_event(name="0000", id_=0, input_=input_, is_entity=True, is_error=is_error)
169200

170201
def test_call_entity_sent():
171202
context_builder = ContextBuilder('test_simple_function')
@@ -233,4 +264,29 @@ def test_call_entity_raised():
233264

234265
#assert_valid_schema(result)
235266

267+
assert_orchestration_state_equals(expected, result)
268+
269+
def test_call_entity_catch_exception():
270+
entityId = df.EntityId("Counter", "myCounter")
271+
context_builder = ContextBuilder('catch exceptions')
272+
add_call_entity_completed_events(
273+
context_builder,
274+
"add",
275+
df.EntityId.get_scheduler_id(entityId),
276+
input_="I am an error!",
277+
event_id=0,
278+
is_error=True
279+
)
280+
281+
result = get_orchestration_state_result(
282+
context_builder, generator_function_catch_entity_exception)
283+
284+
expected_state = base_expected_state(
285+
"Exception thrown"
286+
)
287+
288+
add_call_entity_action(expected_state, entityId, "add", 3)
289+
expected_state._is_done = True
290+
expected = expected_state.to_json()
291+
236292
assert_orchestration_state_equals(expected, result)

tests/test_utils/ContextBuilder.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,11 +125,14 @@ def add_execution_started_event(
125125
event.Input = input_
126126
self.history_events.append(event)
127127

128-
def add_event_raised_event(self, name:str, id_: int, input_=None, timestamp=None, is_entity=False):
128+
def add_event_raised_event(self, name:str, id_: int, input_=None, timestamp=None, is_entity=False, is_error = False):
129129
event = self.get_base_event(HistoryEventType.EVENT_RAISED, id_=id_, timestamp=timestamp)
130130
event.Name = name
131131
if is_entity:
132-
event.Input = json.dumps({ "result": json.dumps(input_) })
132+
if is_error:
133+
event.Input = json.dumps({ "result": json.dumps(input_), "exceptionType": "True" })
134+
else:
135+
event.Input = json.dumps({ "result": json.dumps(input_) })
133136
else:
134137
event.Input = input_
135138
# event.timestamp = timestamp

0 commit comments

Comments
 (0)