Skip to content

Commit 2d12059

Browse files
committed
Add orchestration versioning support
- Known gap: VersionFailureStrategy.REJECT currently fails the orchestration (same as FAIL) because there is no abandon strategy yet
1 parent 45292b1 commit 2d12059

File tree

11 files changed

+475
-21
lines changed

11 files changed

+475
-21
lines changed

.vscode/settings.json

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,7 @@
1414
],
1515
"python.analysis.typeCheckingMode": "basic",
1616
"python.testing.pytestArgs": [
17-
"-v",
18-
"--cov=durabletask/",
19-
"--cov-report=lcov",
20-
"tests/"
17+
"tests"
2118
],
2219
"python.testing.unittestEnabled": false,
2320
"python.testing.pytestEnabled": true,
@@ -30,5 +27,6 @@
3027
"jacoco.xml",
3128
"coverage.cobertura.xml"
3229
],
33-
"makefile.configureOnOpen": false
30+
"makefile.configureOnOpen": false,
31+
"debugpy.debugJustMyCode": false
3432
}

durabletask-azuremanaged/durabletask/azuremanaged/client.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ def __init__(self, *,
1717
host_address: str,
1818
taskhub: str,
1919
token_credential: Optional[TokenCredential],
20-
secure_channel: bool = True):
20+
secure_channel: bool = True,
21+
default_version: Optional[str] = None):
2122

2223
if not taskhub:
2324
raise ValueError("Taskhub value cannot be empty. Please provide a value for your taskhub")
@@ -30,4 +31,5 @@ def __init__(self, *,
3031
host_address=host_address,
3132
secure_channel=secure_channel,
3233
metadata=None,
33-
interceptors=interceptors)
34+
interceptors=interceptors,
35+
default_version=default_version)

durabletask/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33

44
"""Durable Task SDK for Python"""
55

6-
from durabletask.worker import ConcurrencyOptions
6+
from durabletask.worker import ConcurrencyOptions, VersioningOptions
77

8-
__all__ = ["ConcurrencyOptions"]
8+
__all__ = ["ConcurrencyOptions", "VersioningOptions"]
99

1010
PACKAGE_NAME = "durabletask"

durabletask/client.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,8 @@ def __init__(self, *,
9898
log_handler: Optional[logging.Handler] = None,
9999
log_formatter: Optional[logging.Formatter] = None,
100100
secure_channel: bool = False,
101-
interceptors: Optional[Sequence[shared.ClientInterceptor]] = None):
101+
interceptors: Optional[Sequence[shared.ClientInterceptor]] = None,
102+
default_version: Optional[str] = None):
102103

103104
# If the caller provided metadata, we need to create a new interceptor for it and
104105
# add it to the list of interceptors.
@@ -118,13 +119,15 @@ def __init__(self, *,
118119
)
119120
self._stub = stubs.TaskHubSidecarServiceStub(channel)
120121
self._logger = shared.get_logger("client", log_handler, log_formatter)
122+
self.default_version = default_version
121123

122124
def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *,
123125
input: Optional[TInput] = None,
124126
instance_id: Optional[str] = None,
125127
start_at: Optional[datetime] = None,
126128
reuse_id_policy: Optional[pb.OrchestrationIdReusePolicy] = None,
127-
tags: Optional[dict[str, str]] = None) -> str:
129+
tags: Optional[dict[str, str]] = None,
130+
version: Optional[str] = None) -> str:
128131

129132
name = orchestrator if isinstance(orchestrator, str) else task.get_name(orchestrator)
130133

@@ -133,9 +136,9 @@ def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInpu
133136
instanceId=instance_id if instance_id else uuid.uuid4().hex,
134137
input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input is not None else None,
135138
scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None,
136-
version=wrappers_pb2.StringValue(value=""),
139+
version=helpers.get_string_value(version if version else self.default_version),
137140
orchestrationIdReusePolicy=reuse_id_policy,
138-
tags=tags,
141+
tags=tags
139142
)
140143

141144
self._logger.info(f"Starting new '{name}' instance with ID = '{req.instanceId}'.")

durabletask/internal/helpers.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,11 +199,13 @@ def new_create_sub_orchestration_action(
199199
id: int,
200200
name: str,
201201
instance_id: Optional[str],
202-
encoded_input: Optional[str]) -> pb.OrchestratorAction:
202+
encoded_input: Optional[str],
203+
version: Optional[str]) -> pb.OrchestratorAction:
203204
return pb.OrchestratorAction(id=id, createSubOrchestration=pb.CreateSubOrchestrationAction(
204205
name=name,
205206
instanceId=instance_id,
206-
input=get_string_value(encoded_input)
207+
input=get_string_value(encoded_input),
208+
version=get_string_value(version)
207209
))
208210

209211

durabletask/task.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,8 @@ def call_activity(self, activity: Union[Activity[TInput, TOutput], str], *,
126126
def call_sub_orchestrator(self, orchestrator: Orchestrator[TInput, TOutput], *,
127127
input: Optional[TInput] = None,
128128
instance_id: Optional[str] = None,
129-
retry_policy: Optional[RetryPolicy] = None) -> Task[TOutput]:
129+
retry_policy: Optional[RetryPolicy] = None,
130+
version: Optional[str] = None) -> Task[TOutput]:
130131
"""Schedule sub-orchestrator function for execution.
131132
132133
Parameters

durabletask/worker.py

Lines changed: 133 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
from datetime import datetime, timedelta
1111
from threading import Event, Thread
1212
from types import GeneratorType
13+
from enum import Enum
1314
from typing import Any, Generator, Optional, Sequence, TypeVar, Union
15+
from packaging.version import InvalidVersion, parse
1416

1517
import grpc
1618
from google.protobuf import empty_pb2
@@ -72,9 +74,60 @@ def __init__(
7274
)
7375

7476

77+
class VersionMatchStrategy(Enum):
    """Enumeration for version matching strategies.

    Selects how the worker compares the version recorded on an orchestration
    with the worker's own configured version.
    """

    # No version check is performed; the orchestration is always accepted.
    NONE = 1
    # The orchestration version must compare equal to the worker version.
    STRICT = 2
    # The orchestration version must not be greater than the worker version.
    CURRENT_OR_OLDER = 3
83+
84+
85+
class VersionFailureStrategy(Enum):
    """Enumeration for version failure strategies.

    Selects what the worker does with an orchestration whose version did not
    satisfy the configured match strategy.
    """

    # NOTE(review): presumably meant to abandon the work item; the executor
    # has no abandon support yet, so it currently fails the orchestration
    # exactly like FAIL.
    REJECT = 1
    # The orchestration is failed with the version-mismatch failure details.
    FAIL = 2
90+
91+
92+
class VersionFailureException(Exception):
    """Signals that an orchestration did not pass the worker's version check."""
94+
95+
96+
class VersioningOptions:
    """Configuration options for orchestration versioning on a worker.

    An instance is handed to the worker through ``use_versioning`` and is then
    consulted each time an orchestration is executed.

    Attributes:
        version: The worker's own version. The version recorded on an
            orchestration is compared against this value.
        default_version: Stored as given. It is not read by the version check
            itself; presumably it is the version applied to work started
            without an explicit version (TODO confirm against callers).
        match_strategy: How the orchestration version is compared with
            ``version``. Note that leaving this unset (``None``) is *not* the
            same as ``VersionMatchStrategy.NONE``: an unset strategy is
            treated as unknown and the version check fails.
        failure_strategy: What to do with an orchestration that fails the
            version check.
    """

    # Class-level defaults; every instance overwrites them in __init__.
    version: Optional[str] = None
    default_version: Optional[str] = None
    match_strategy: Optional[VersionMatchStrategy] = None
    failure_strategy: Optional[VersionFailureStrategy] = None

    def __init__(self, version: Optional[str] = None,
                 default_version: Optional[str] = None,
                 match_strategy: Optional[VersionMatchStrategy] = None,
                 failure_strategy: Optional[VersionFailureStrategy] = None
                 ):
        """Initialize versioning options.

        Args:
            version: The specific version to use for orchestrators and activities.
            default_version: The default version to use if no specific version is provided.
            match_strategy: The strategy to use for matching versions.
            failure_strategy: The strategy to use if versioning fails.
        """
        self.version = version
        self.default_version = default_version
        self.match_strategy = match_strategy
        self.failure_strategy = failure_strategy

    def __repr__(self) -> str:
        """Return a constructor-style representation for logs and debugging."""
        return (
            f"{type(self).__name__}("
            f"version={self.version!r}, "
            f"default_version={self.default_version!r}, "
            f"match_strategy={self.match_strategy!r}, "
            f"failure_strategy={self.failure_strategy!r})"
        )
125+
126+
75127
class _Registry:
76128
orchestrators: dict[str, task.Orchestrator]
77129
activities: dict[str, task.Activity]
130+
versioning: Optional[VersioningOptions] = None
78131

79132
def __init__(self):
80133
self.orchestrators = {}
@@ -279,6 +332,12 @@ def add_activity(self, fn: task.Activity) -> str:
279332
)
280333
return self._registry.add_activity(fn)
281334

335+
def use_versioning(self, version: VersioningOptions) -> None:
    """Install the versioning options this worker applies to orchestrations.

    Args:
        version: The versioning options to store on the worker's registry.
            Despite the parameter name, this is the whole options object
            (worker version, match strategy, failure strategy), not a
            single version string.

    Raises:
        RuntimeError: If the worker is already running.
    """
    if self._is_running:
        raise RuntimeError("Cannot set versioning options while the worker is running.")
    self._registry.versioning = version
340+
282341
def start(self):
283342
"""Starts the worker on a background thread and begins listening for work items."""
284343
if self._is_running:
@@ -646,7 +705,7 @@ def set_complete(
646705
)
647706
self._pending_actions[action.id] = action
648707

649-
def set_failed(self, ex: Exception):
708+
def set_failed(self, ex: Union[Exception, pb.TaskFailureDetails]):
650709
if self._is_complete:
651710
return
652711

@@ -658,7 +717,7 @@ def set_failed(self, ex: Exception):
658717
self.next_sequence_number(),
659718
pb.ORCHESTRATION_STATUS_FAILED,
660719
None,
661-
ph.new_failure_details(ex),
720+
ph.new_failure_details(ex) if isinstance(ex, Exception) else ex,
662721
)
663722
self._pending_actions[action.id] = action
664723

@@ -768,6 +827,7 @@ def call_sub_orchestrator(
768827
input: Optional[TInput] = None,
769828
instance_id: Optional[str] = None,
770829
retry_policy: Optional[task.RetryPolicy] = None,
830+
version: Optional[str] = None,
771831
) -> task.Task[TOutput]:
772832
id = self.next_sequence_number()
773833
orchestrator_name = task.get_name(orchestrator)
@@ -778,6 +838,7 @@ def call_sub_orchestrator(
778838
retry_policy=retry_policy,
779839
is_sub_orch=True,
780840
instance_id=instance_id,
841+
version=version,
781842
)
782843
return self._pending_tasks.get(id, task.CompletableTask())
783844

@@ -792,6 +853,7 @@ def call_activity_function_helper(
792853
is_sub_orch: bool = False,
793854
instance_id: Optional[str] = None,
794855
fn_task: Optional[task.CompletableTask[TOutput]] = None,
856+
version: Optional[str] = None,
795857
):
796858
if id is None:
797859
id = self.next_sequence_number()
@@ -816,7 +878,7 @@ def call_activity_function_helper(
816878
if not isinstance(activity_function, str):
817879
raise ValueError("Orchestrator function name must be a string")
818880
action = ph.new_create_sub_orchestration_action(
819-
id, activity_function, instance_id, encoded_input
881+
id, activity_function, instance_id, encoded_input, version
820882
)
821883
self._pending_actions[id] = action
822884

@@ -893,7 +955,27 @@ def execute(
893955
)
894956

895957
ctx = _RuntimeOrchestrationContext(instance_id)
958+
version_failure = None
896959
try:
960+
execution_started_events = [e.executionStarted for e in old_events if e.HasField("executionStarted")]
961+
if self._registry.versioning and len(execution_started_events) > 0:
962+
execution_started_event = execution_started_events[-1]
963+
version_failure = self.evaluate_orchestration_versioning(
964+
self._registry.versioning,
965+
execution_started_event.version.value if execution_started_event.version else None,
966+
)
967+
if version_failure:
968+
self._logger.warning(
969+
f"Orchestration version did not meet worker versioning requirements. "
970+
f"Error action = '{self._registry.versioning.failure_strategy}'. "
971+
f"Version error = '{version_failure}'"
972+
)
973+
if self._registry.versioning.failure_strategy == VersionFailureStrategy.FAIL:
974+
raise VersionFailureException
975+
elif self._registry.versioning.failure_strategy == VersionFailureStrategy.REJECT:
976+
# TODO: We don't have abandoned orchestrations yet, so we just fail
977+
raise VersionFailureException
978+
897979
# Rebuild local state by replaying old history into the orchestrator function
898980
self._logger.debug(
899981
f"{instance_id}: Rebuilding local state with {len(old_events)} history event..."
@@ -912,6 +994,12 @@ def execute(
912994
for new_event in new_events:
913995
self.process_event(ctx, new_event)
914996

997+
except VersionFailureException as ex:
998+
if version_failure:
999+
ctx.set_failed(version_failure)
1000+
else:
1001+
ctx.set_failed(ex)
1002+
9151003
except Exception as ex:
9161004
# Unhandled exceptions fail the orchestration
9171005
ctx.set_failed(ex)
@@ -1223,6 +1311,48 @@ def process_event(
12231311
# The orchestrator generator function completed
12241312
ctx.set_complete(generatorStopped.value, pb.ORCHESTRATION_STATUS_COMPLETED)
12251313

1314+
def evaluate_orchestration_versioning(self, versioning: Optional[VersioningOptions], orchestration_version: Optional[str]) -> Optional[pb.TaskFailureDetails]:
    """Check an orchestration's version against the worker's versioning options.

    Args:
        versioning: The worker's versioning options, or None when versioning
            is not configured.
        orchestration_version: The version recorded on the orchestration, or
            None/empty when it has none.

    Returns:
        None when the orchestration is accepted; otherwise a non-retriable
        ``TaskFailureDetails`` with error type "VersionMismatch" describing
        why it was refused.
    """
    if versioning is None:
        return None

    strategy = versioning.match_strategy

    # NONE means "do not check", so skip the version comparison entirely.
    if strategy == VersionMatchStrategy.NONE:
        return None

    if strategy == VersionMatchStrategy.STRICT:
        if self.compare_versions(orchestration_version, versioning.version) != 0:
            return pb.TaskFailureDetails(
                errorType="VersionMismatch",
                errorMessage=f"The orchestration version '{orchestration_version}' does not match the worker version '{versioning.version}'.",
                isNonRetriable=True,
            )
        return None

    if strategy == VersionMatchStrategy.CURRENT_OR_OLDER:
        if self.compare_versions(orchestration_version, versioning.version) > 0:
            return pb.TaskFailureDetails(
                errorType="VersionMismatch",
                errorMessage=f"The orchestration version '{orchestration_version}' is greater than the worker version '{versioning.version}'.",
                isNonRetriable=True,
            )
        return None

    # If there is a type of versioning we don't understand, it is better to treat it as a versioning failure.
    # An unset (None) match_strategy also ends up here.
    return pb.TaskFailureDetails(
        errorType="VersionMismatch",
        errorMessage=f"The version match strategy '{versioning.match_strategy}' is unknown.",
        isNonRetriable=True,
    )
1341+
1342+
def compare_versions(self, source_version: Optional[str], default_version: Optional[str]) -> int:
    """Three-way compare two version strings.

    A missing version (None or empty) orders below any present version, and
    two missing versions are equal. When both are present they are compared
    as parsed versions; if parsing raises ``InvalidVersion`` the raw strings
    are compared instead.

    Args:
        source_version: The left-hand version.
        default_version: The right-hand version.

    Returns:
        -1 if ``source_version`` orders before ``default_version``, 0 if they
        are equal, 1 if it orders after.
    """
    # At least one side is missing: presence alone decides the ordering.
    if not source_version or not default_version:
        return int(bool(source_version)) - int(bool(default_version))

    try:
        parsed_source = parse(source_version)
        parsed_default = parse(default_version)
    except InvalidVersion:
        # Not parseable as versions; fall back to plain string ordering.
        if source_version == default_version:
            return 0
        return 1 if source_version > default_version else -1

    if parsed_source == parsed_default:
        return 0
    return 1 if parsed_source > parsed_default else -1
1355+
12261356

12271357
class _ActivityExecutor:
12281358
def __init__(self, registry: _Registry, logger: logging.Logger):

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ readme = "README.md"
2727
dependencies = [
2828
"grpcio",
2929
"protobuf",
30-
"asyncio"
30+
"asyncio",
31+
"packaging"
3132
]
3233

3334
[project.urls]

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@ protobuf
44
pytest
55
pytest-cov
66
azure-identity
7-
asyncio
7+
asyncio
8+
packaging

0 commit comments

Comments
 (0)