Skip to content

Commit 1aac4bc

Browse files
authored
Add timer interface to Tasks (#323)
1 parent 67f0a87 commit 1aac4bc

File tree

4 files changed

+96
-7
lines changed

4 files changed

+96
-7
lines changed

azure/durable_functions/models/DurableOrchestrationContext.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from collections import defaultdict
22
from azure.durable_functions.models.actions.SignalEntityAction import SignalEntityAction
33
from azure.durable_functions.models.actions.CallEntityAction import CallEntityAction
4-
from azure.durable_functions.models.Task import TaskBase
4+
from azure.durable_functions.models.Task import TaskBase, TimerTask
55
from azure.durable_functions.models.actions.CallHttpAction import CallHttpAction
66
from azure.durable_functions.models.DurableHttpRequest import DurableHttpRequest
77
from azure.durable_functions.models.actions.CallSubOrchestratorWithRetryAction import \
@@ -100,7 +100,8 @@ def from_json(cls, json_string: str):
100100
def _generate_task(self, action: Action,
101101
retry_options: Optional[RetryOptions] = None,
102102
id_: Optional[Union[int, str]] = None,
103-
parent: Optional[TaskBase] = None) -> Union[AtomicTask, RetryAbleTask]:
103+
parent: Optional[TaskBase] = None,
104+
task_constructor=AtomicTask) -> Union[AtomicTask, RetryAbleTask, TimerTask]:
104105
"""Generate an atomic or retryable Task based on an input.
105106
106107
Parameters
@@ -124,7 +125,7 @@ def _generate_task(self, action: Action,
124125
action_payload = [action]
125126
else:
126127
action_payload = action
127-
task = AtomicTask(id_, action_payload)
128+
task = task_constructor(id_, action_payload)
128129
task.parent = parent
129130

130131
# if task is retryable, provide the retryable wrapper class
@@ -517,7 +518,7 @@ def create_timer(self, fire_at: datetime.datetime) -> TaskBase:
517518
A Durable Timer Task that schedules the timer to wake up the activity
518519
"""
519520
action = CreateTimerAction(fire_at)
520-
task = self._generate_task(action)
521+
task = self._generate_task(action, task_constructor=TimerTask)
521522
return task
522523

523524
def wait_for_external_event(self, name: str) -> TaskBase:

azure/durable_functions/models/Task.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from azure.durable_functions.models.actions.Action import Action
66
from azure.durable_functions.models.actions.WhenAnyAction import WhenAnyAction
77
from azure.durable_functions.models.actions.WhenAllAction import WhenAllAction
8+
from azure.durable_functions.models.actions.CreateTimerAction import CreateTimerAction
89

910
import enum
1011
from typing import Any, List, Optional, Set, Type, Union
@@ -56,6 +57,14 @@ def __init__(self, id_: Union[int, str], actions: Union[List[Action], Action]):
5657
self.action_repr: Union[List[Action], Action] = actions
5758
self.is_played = False
5859

60+
@property
61+
def is_completed(self) -> bool:
62+
"""Get indicator of whether the task completed.
63+
64+
Note that completion is not equivalent to success.
65+
"""
66+
return not(self.state is TaskState.RUNNING)
67+
5968
def set_is_played(self, is_played: bool):
6069
"""Set the is_played flag for the Task.
6170
@@ -208,7 +217,47 @@ def try_set_value(self, child: TaskBase):
208217
class AtomicTask(TaskBase):
209218
"""A Task with no subtasks."""
210219

211-
pass
220+
def _get_action(self) -> Action:
221+
action: Action
222+
if isinstance(self.action_repr, list):
223+
action = self.action_repr[0]
224+
else:
225+
action = self.action_repr
226+
return action
227+
228+
229+
class TimerTask(AtomicTask):
230+
"""A Timer Task."""
231+
232+
def __init__(self, id_: Union[int, str], action: CreateTimerAction):
233+
super().__init__(id_, action)
234+
self.action_repr: Union[List[CreateTimerAction], CreateTimerAction]
235+
236+
@property
237+
def is_cancelled(self) -> bool:
238+
"""Check if the Timer is cancelled.
239+
240+
Returns
241+
-------
242+
bool
243+
Returns whether a timer has been cancelled or not
244+
"""
245+
action: CreateTimerAction = self._get_action()
246+
return action.is_cancelled
247+
248+
def cancel(self):
249+
"""Cancel a timer.
250+
251+
Raises
252+
------
253+
ValueError
254+
Raises an error if the task is already completed and an attempt is made to cancel it
255+
"""
256+
if not self.is_completed:
257+
action: CreateTimerAction = self._get_action()
258+
action.is_cancelled = True
259+
else:
260+
raise ValueError("Cannot cancel a completed task.")
212261

213262

214263
class WhenAllTask(CompoundTask):

tests/orchestrator/test_create_timer.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,20 @@ def generator_function(context):
2828
yield context.create_timer(fire_at)
2929
return "Done!"
3030

31+
def generator_function_timer_can_be_cancelled(context):
32+
time_limit1 = context.current_utc_datetime + timedelta(minutes=5)
33+
timer_task1 = context.create_timer(time_limit1)
34+
35+
time_limit2 = context.current_utc_datetime + timedelta(minutes=10)
36+
timer_task2 = context.create_timer(time_limit2)
37+
38+
winner = yield context.task_any([timer_task1, timer_task2])
39+
if winner == timer_task1:
40+
timer_task2.cancel()
41+
return "Done!"
42+
else:
43+
raise Exception("timer task 1 should complete before timer task 2")
44+
3145
def add_timer_action(state: OrchestratorState, fire_at: datetime):
3246
action = CreateTimerAction(fire_at=fire_at)
3347
state._actions.append([action])
@@ -64,4 +78,25 @@ def test_timers_comparison_with_relaxed_precision():
6478
#assert_valid_schema(result)
6579
# TODO: getting the following error when validating the schema
6680
# "Additional properties are not allowed ('fireAt', 'isCanceled' were unexpected)">
67-
assert_orchestration_state_equals(expected, result)
81+
assert_orchestration_state_equals(expected, result)
82+
83+
def test_timers_can_be_cancelled():
84+
85+
context_builder = ContextBuilder("test_timers_can_be_cancelled")
86+
fire_at1 = context_builder.current_datetime + timedelta(minutes=5)
87+
fire_at2 = context_builder.current_datetime + timedelta(minutes=10)
88+
add_timer_fired_events(context_builder, 0, str(fire_at1))
89+
add_timer_fired_events(context_builder, 1, str(fire_at2))
90+
91+
result = get_orchestration_state_result(
92+
context_builder, generator_function_timer_can_be_cancelled)
93+
94+
expected_state = base_expected_state(output='Done!')
95+
expected_state._actions.append(
96+
[CreateTimerAction(fire_at=fire_at1), CreateTimerAction(fire_at=fire_at2, is_cancelled=True)])
97+
98+
expected_state._is_done = True
99+
expected = expected_state.to_json()
100+
101+
assert_orchestration_state_equals(expected, result)
102+
assert result["actions"][0][1]["isCanceled"]

tests/orchestrator/test_sequential_orchestrator.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,10 +136,14 @@ def generator_function_new_guid(context):
136136
outputs.append(str(output3))
137137
return outputs
138138

139-
140139
def base_expected_state(output=None, replay_schema: ReplaySchema = ReplaySchema.V1) -> OrchestratorState:
141140
return OrchestratorState(is_done=False, actions=[], output=output, replay_schema=replay_schema)
142141

142+
def add_timer_fired_events(context_builder: ContextBuilder, id_: int, timestamp: str):
143+
fire_at: str = context_builder.add_timer_created_event(id_, timestamp)
144+
context_builder.add_orchestrator_completed_event()
145+
context_builder.add_orchestrator_started_event()
146+
context_builder.add_timer_fired_event(id_=id_, fire_at=fire_at)
143147

144148
def add_hello_action(state: OrchestratorState, input_: str):
145149
action = CallActivityAction(function_name='Hello', input_=input_)

0 commit comments

Comments
 (0)