Skip to content

Commit 7dc8b85

Browse files
committed
Improvements
- Add rejection - cleanup comments - suborch default versioning
1 parent e03608f commit 7dc8b85

File tree

3 files changed

+102
-27
lines changed

3 files changed

+102
-27
lines changed

durabletask/internal/exceptions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
class AbandonOrchestrationError(Exception):
2+
def __init__(self, *args: object) -> None:
3+
super().__init__(*args)

durabletask/worker.py

Lines changed: 42 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from google.protobuf import empty_pb2
1919

2020
import durabletask.internal.helpers as ph
21+
import durabletask.internal.exceptions as pe
2122
import durabletask.internal.orchestrator_service_pb2 as pb
2223
import durabletask.internal.orchestrator_service_pb2_grpc as stubs
2324
import durabletask.internal.shared as shared
@@ -113,10 +114,10 @@ def __init__(self, version: Optional[str] = None,
113114
"""Initialize versioning options.
114115
115116
Args:
116-
version: The specific version to use for orchestrators and activities.
117-
default_version: The default version to use if no specific version is provided.
118-
match_strategy: The strategy to use for matching versions.
119-
failure_strategy: The strategy to use if versioning fails.
117+
version: The version of orchestrations that the worker can work on.
118+
default_version: The default version that will be used for starting new orchestrations.
119+
match_strategy: The versioning strategy for the Durable Task worker.
120+
failure_strategy: The versioning failure strategy for the Durable Task worker.
120121
"""
121122
self.version = version
122123
self.default_version = default_version
@@ -333,7 +334,7 @@ def add_activity(self, fn: task.Activity) -> str:
333334
return self._registry.add_activity(fn)
334335

335336
def use_versioning(self, version: VersioningOptions) -> None:
336-
"""Sets the default version for orchestrators and activities."""
337+
"""Initializes versioning options for sub-orchestrators and activities."""
337338
if self._is_running:
338339
raise RuntimeError("Cannot set default version while the worker is running.")
339340
self._registry.versioning = version
@@ -564,14 +565,24 @@ def _execute_orchestrator(
564565
completionToken,
565566
):
566567
try:
567-
executor = _OrchestrationExecutor(self._registry, self._logger)
568+
executor = _OrchestrationExecutor(self._registry, self._logger, stub)
568569
result = executor.execute(req.instanceId, req.pastEvents, req.newEvents)
569570
res = pb.OrchestratorResponse(
570571
instanceId=req.instanceId,
571572
actions=result.actions,
572573
customStatus=ph.get_string_value(result.encoded_custom_status),
573574
completionToken=completionToken,
574575
)
576+
except pe.AbandonOrchestrationError as ex:
577+
self._logger.info(
578+
f"Abandoning orchestration. InstanceId = '{req.instanceId}'. Completion token = '{completionToken}'"
579+
)
580+
stub.AbandonTaskOrchestratorWorkItem(
581+
pb.AbandonOrchestrationTaskRequest(
582+
completionToken=completionToken
583+
)
584+
)
585+
return
575586
except Exception as ex:
576587
self._logger.exception(
577588
f"An error occurred while trying to execute instance '{req.instanceId}': {ex}"
@@ -633,7 +644,7 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
633644
_generator: Optional[Generator[task.Task, Any, Any]]
634645
_previous_task: Optional[task.Task]
635646

636-
def __init__(self, instance_id: str):
647+
def __init__(self, instance_id: str, registry: _Registry):
637648
self._generator = None
638649
self._is_replaying = True
639650
self._is_complete = False
@@ -643,6 +654,7 @@ def __init__(self, instance_id: str):
643654
self._sequence_number = 0
644655
self._current_utc_datetime = datetime(1000, 1, 1)
645656
self._instance_id = instance_id
657+
self._registry = registry
646658
self._completion_status: Optional[pb.OrchestrationStatus] = None
647659
self._received_events: dict[str, list[Any]] = {}
648660
self._pending_events: dict[str, list[task.CompletableTask]] = {}
@@ -831,14 +843,16 @@ def call_sub_orchestrator(
831843
) -> task.Task[TOutput]:
832844
id = self.next_sequence_number()
833845
orchestrator_name = task.get_name(orchestrator)
846+
default_version = self._registry.versioning.default_version if self._registry.versioning else None
847+
orchestrator_version = version if version else default_version
834848
self.call_activity_function_helper(
835849
id,
836850
orchestrator_name,
837851
input=input,
838852
retry_policy=retry_policy,
839853
is_sub_orch=True,
840854
instance_id=instance_id,
841-
version=version,
855+
version=orchestrator_version
842856
)
843857
return self._pending_tasks.get(id, task.CompletableTask())
844858

@@ -937,11 +951,12 @@ def __init__(
937951
class _OrchestrationExecutor:
938952
_generator: Optional[task.Orchestrator] = None
939953

940-
def __init__(self, registry: _Registry, logger: logging.Logger):
954+
def __init__(self, registry: _Registry, logger: logging.Logger, stub: stubs.TaskHubSidecarServiceStub):
941955
self._registry = registry
942956
self._logger = logger
943957
self._is_suspended = False
944958
self._suspended_events: list[pb.HistoryEvent] = []
959+
self._stub = stub
945960

946961
def execute(
947962
self,
@@ -954,9 +969,18 @@ def execute(
954969
"The new history event list must have at least one event in it."
955970
)
956971

957-
ctx = _RuntimeOrchestrationContext(instance_id)
972+
ctx = _RuntimeOrchestrationContext(instance_id, self._registry)
958973
version_failure = None
959974
try:
975+
# Rebuild local state by replaying old history into the orchestrator function
976+
self._logger.debug(
977+
f"{instance_id}: Rebuilding local state with {len(old_events)} history event..."
978+
)
979+
ctx._is_replaying = True
980+
for old_event in old_events:
981+
self.process_event(ctx, old_event)
982+
983+
# Process versioning if applicable
960984
execution_started_events = [e.executionStarted for e in old_events if e.HasField("executionStarted")]
961985
if self._registry.versioning and len(execution_started_events) > 0:
962986
execution_started_event = execution_started_events[-1]
@@ -970,19 +994,7 @@ def execute(
970994
f"Error action = '{self._registry.versioning.failure_strategy}'. "
971995
f"Version error = '{version_failure}'"
972996
)
973-
if self._registry.versioning.failure_strategy == VersionFailureStrategy.FAIL:
974-
raise VersionFailureException
975-
elif self._registry.versioning.failure_strategy == VersionFailureStrategy.REJECT:
976-
# TODO: We don't have abandoned orchestrations yet, so we just fail
977-
raise VersionFailureException
978-
979-
# Rebuild local state by replaying old history into the orchestrator function
980-
self._logger.debug(
981-
f"{instance_id}: Rebuilding local state with {len(old_events)} history event..."
982-
)
983-
ctx._is_replaying = True
984-
for old_event in old_events:
985-
self.process_event(ctx, old_event)
997+
raise VersionFailureException
986998

987999
# Get new actions by executing newly received events into the orchestrator function
9881000
if self._logger.level <= logging.DEBUG:
@@ -995,10 +1007,13 @@ def execute(
9951007
self.process_event(ctx, new_event)
9961008

9971009
except VersionFailureException as ex:
998-
if version_failure:
999-
ctx.set_failed(version_failure)
1000-
else:
1001-
ctx.set_failed(ex)
1010+
if self._registry.versioning and self._registry.versioning.failure_strategy == VersionFailureStrategy.FAIL:
1011+
if version_failure:
1012+
ctx.set_failed(version_failure)
1013+
else:
1014+
ctx.set_failed(ex)
1015+
elif self._registry.versioning and self._registry.versioning.failure_strategy == VersionFailureStrategy.REJECT:
1016+
raise pe.AbandonOrchestrationError
10021017

10031018
except Exception as ex:
10041019
# Unhandled exceptions fail the orchestration

tests/durabletask-azuremanaged/test_dts_orchestration_versioning_e2e.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,60 @@ def test_upper_version_worker_succeeds():
153153

154154

155155
def test_upper_version_worker_strict_fails():
156+
# Start a worker, which will connect to the sidecar in a background thread
157+
instance_id: str = ''
158+
thrown = False
159+
try:
160+
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
161+
taskhub=taskhub_name, token_credential=None) as w:
162+
w.add_orchestrator(single_activity)
163+
w.add_activity(plus_one)
164+
w.use_versioning(worker.VersioningOptions(
165+
version="1.0.0",
166+
default_version="1.1.0",
167+
match_strategy=worker.VersionMatchStrategy.STRICT,
168+
failure_strategy=worker.VersionFailureStrategy.REJECT
169+
))
170+
w.start()
171+
172+
task_hub_client = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
173+
taskhub=taskhub_name, token_credential=None,
174+
default_version="1.1.0")
175+
instance_id = task_hub_client.schedule_new_orchestration(single_activity, input=1)
176+
state = task_hub_client.wait_for_orchestration_completion(
177+
instance_id, timeout=5)
178+
except TimeoutError as e:
179+
thrown = True
180+
assert str(e).find("Timed-out waiting for the orchestration to complete") >= 0
181+
182+
assert thrown is True
183+
184+
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
185+
taskhub=taskhub_name, token_credential=None) as w:
186+
w.add_orchestrator(single_activity)
187+
w.add_activity(plus_one)
188+
w.use_versioning(worker.VersioningOptions(
189+
version="1.1.0",
190+
default_version="1.1.0",
191+
match_strategy=worker.VersionMatchStrategy.STRICT,
192+
failure_strategy=worker.VersionFailureStrategy.REJECT
193+
))
194+
w.start()
195+
196+
task_hub_client = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
197+
taskhub=taskhub_name, token_credential=None,
198+
default_version="1.1.0")
199+
state = task_hub_client.wait_for_orchestration_completion(
200+
instance_id, timeout=5)
201+
202+
assert state is not None
203+
assert state.name == task.get_name(single_activity)
204+
assert state.instance_id == instance_id
205+
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
206+
assert state.failure_details is None
207+
208+
209+
def test_reject_abandons_and_reprocess():
156210
# Start a worker, which will connect to the sidecar in a background thread
157211
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
158212
taskhub=taskhub_name, token_credential=None) as w:
@@ -181,6 +235,9 @@ def test_upper_version_worker_strict_fails():
181235
assert state.failure_details.message.find("The orchestration version '1.0.0' does not match the worker version '1.1.0'.") >= 0
182236

183237

238+
# Sub-orchestration tests
239+
240+
184241
def sequence_suborchestator(ctx: task.OrchestrationContext, start_val: int):
185242
numbers = []
186243
for current in range(start_val, start_val + 5):

0 commit comments

Comments
 (0)