Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .github/workflows/durabletask-azuremanaged.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,15 @@ jobs:
python -m pip install --upgrade pip
pip install -r requirements.txt

- name: Install durabletask-azuremanaged locally
working-directory: durabletask-azuremanaged
run: |
pip install . --no-deps --force-reinstall

- name: Install durabletask locally
run: |
pip install . --no-deps --force-reinstall

- name: Run the tests
working-directory: tests/durabletask-azuremanaged
run: |
Expand Down
3 changes: 2 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,6 @@
"jacoco.xml",
"coverage.cobertura.xml"
],
"makefile.configureOnOpen": false
"makefile.configureOnOpen": false,
"debugpy.debugJustMyCode": false
}
6 changes: 4 additions & 2 deletions durabletask-azuremanaged/durabletask/azuremanaged/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ def __init__(self, *,
host_address: str,
taskhub: str,
token_credential: Optional[TokenCredential],
secure_channel: bool = True):
secure_channel: bool = True,
default_version: Optional[str] = None):

if not taskhub:
raise ValueError("Taskhub value cannot be empty. Please provide a value for your taskhub")
Expand All @@ -30,4 +31,5 @@ def __init__(self, *,
host_address=host_address,
secure_channel=secure_channel,
metadata=None,
interceptors=interceptors)
interceptors=interceptors,
default_version=default_version)
4 changes: 2 additions & 2 deletions durabletask/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

"""Durable Task SDK for Python"""

from durabletask.worker import ConcurrencyOptions
from durabletask.worker import ConcurrencyOptions, VersioningOptions

__all__ = ["ConcurrencyOptions"]
__all__ = ["ConcurrencyOptions", "VersioningOptions"]

PACKAGE_NAME = "durabletask"
11 changes: 7 additions & 4 deletions durabletask/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ def __init__(self, *,
log_handler: Optional[logging.Handler] = None,
log_formatter: Optional[logging.Formatter] = None,
secure_channel: bool = False,
interceptors: Optional[Sequence[shared.ClientInterceptor]] = None):
interceptors: Optional[Sequence[shared.ClientInterceptor]] = None,
default_version: Optional[str] = None):

# If the caller provided metadata, we need to create a new interceptor for it and
# add it to the list of interceptors.
Expand All @@ -118,13 +119,15 @@ def __init__(self, *,
)
self._stub = stubs.TaskHubSidecarServiceStub(channel)
self._logger = shared.get_logger("client", log_handler, log_formatter)
self.default_version = default_version

def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *,
input: Optional[TInput] = None,
instance_id: Optional[str] = None,
start_at: Optional[datetime] = None,
reuse_id_policy: Optional[pb.OrchestrationIdReusePolicy] = None,
tags: Optional[dict[str, str]] = None) -> str:
tags: Optional[dict[str, str]] = None,
version: Optional[str] = None) -> str:

name = orchestrator if isinstance(orchestrator, str) else task.get_name(orchestrator)

Expand All @@ -133,9 +136,9 @@ def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInpu
instanceId=instance_id if instance_id else uuid.uuid4().hex,
input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input is not None else None,
scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None,
version=wrappers_pb2.StringValue(value=""),
version=helpers.get_string_value(version if version else self.default_version),
orchestrationIdReusePolicy=reuse_id_policy,
tags=tags,
tags=tags
)

self._logger.info(f"Starting new '{name}' instance with ID = '{req.instanceId}'.")
Expand Down
7 changes: 7 additions & 0 deletions durabletask/internal/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
class VersionFailureException(Exception):
    """Raised when an orchestration's version fails the worker's versioning check."""


class AbandonOrchestrationError(Exception):
    """Signal that the current orchestration work item should be abandoned.

    The orchestration executor raises this when a version check fails and the
    configured failure strategy is ``REJECT``; the worker catches it and
    abandons the work item instead of completing it.

    No custom ``__init__`` is defined: ``Exception`` already accepts and
    stores arbitrary positional arguments, so a pass-through override adds
    nothing.
    """
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not super happy with this small file, or with propagating information via exceptions the way I'm doing here. I'm open to better solutions.

6 changes: 4 additions & 2 deletions durabletask/internal/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,11 +199,13 @@ def new_create_sub_orchestration_action(
id: int,
name: str,
instance_id: Optional[str],
encoded_input: Optional[str]) -> pb.OrchestratorAction:
encoded_input: Optional[str],
version: Optional[str]) -> pb.OrchestratorAction:
return pb.OrchestratorAction(id=id, createSubOrchestration=pb.CreateSubOrchestrationAction(
name=name,
instanceId=instance_id,
input=get_string_value(encoded_input)
input=get_string_value(encoded_input),
version=get_string_value(version)
))


Expand Down
3 changes: 2 additions & 1 deletion durabletask/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ def call_activity(self, activity: Union[Activity[TInput, TOutput], str], *,
def call_sub_orchestrator(self, orchestrator: Orchestrator[TInput, TOutput], *,
input: Optional[TInput] = None,
instance_id: Optional[str] = None,
retry_policy: Optional[RetryPolicy] = None) -> Task[TOutput]:
retry_policy: Optional[RetryPolicy] = None,
version: Optional[str] = None) -> Task[TOutput]:
"""Schedule sub-orchestrator function for execution.

Parameters
Expand Down
150 changes: 145 additions & 5 deletions durabletask/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,15 @@
from datetime import datetime, timedelta
from threading import Event, Thread
from types import GeneratorType
from enum import Enum
from typing import Any, Generator, Optional, Sequence, TypeVar, Union
from packaging.version import InvalidVersion, parse

import grpc
from google.protobuf import empty_pb2

import durabletask.internal.helpers as ph
import durabletask.internal.exceptions as pe
import durabletask.internal.orchestrator_service_pb2 as pb
import durabletask.internal.orchestrator_service_pb2_grpc as stubs
import durabletask.internal.shared as shared
Expand Down Expand Up @@ -72,9 +75,56 @@ def __init__(
)


class VersionMatchStrategy(Enum):
    """How the worker checks an orchestration's version against its own."""

    # No version check is performed; every orchestration is accepted.
    NONE = 1
    # The orchestration version must compare equal to the worker version.
    STRICT = 2
    # The orchestration version must not compare greater than the worker version.
    CURRENT_OR_OLDER = 3


class VersionFailureStrategy(Enum):
    """What the worker does with an orchestration that fails the version check."""

    # Abandon the work item (the executor raises AbandonOrchestrationError).
    REJECT = 1
    # Mark the orchestration as failed, using the version failure details.
    FAIL = 2


class VersioningOptions:
    """Versioning settings that a Durable Task worker applies to orchestrations.

    Groups the worker's own version, the default version given to
    sub-orchestrations that are started without an explicit version, and the
    strategies that decide how a version mismatch is detected and handled.
    """

    # Class-level defaults; ``__init__`` always overrides them per instance.
    version: Optional[str] = None
    default_version: Optional[str] = None
    match_strategy: Optional[VersionMatchStrategy] = None
    failure_strategy: Optional[VersionFailureStrategy] = None

    def __init__(
        self,
        version: Optional[str] = None,
        default_version: Optional[str] = None,
        match_strategy: Optional[VersionMatchStrategy] = None,
        failure_strategy: Optional[VersionFailureStrategy] = None,
    ):
        """Create a new set of versioning options.

        Args:
            version: Version of the orchestrations this worker is able to process.
            default_version: Version used for orchestrations that are started
                without an explicit version.
            match_strategy: How an orchestration's version is compared with
                ``version``.
            failure_strategy: What happens to an orchestration whose version
                fails that comparison.
        """
        self.version, self.default_version = version, default_version
        self.match_strategy, self.failure_strategy = match_strategy, failure_strategy


class _Registry:
orchestrators: dict[str, task.Orchestrator]
activities: dict[str, task.Activity]
versioning: Optional[VersioningOptions] = None

def __init__(self):
self.orchestrators = {}
Expand Down Expand Up @@ -279,6 +329,12 @@ def add_activity(self, fn: task.Activity) -> str:
)
return self._registry.add_activity(fn)

def use_versioning(self, version: VersioningOptions) -> None:
    """Set the versioning options this worker applies to orchestrations.

    The options are stored on the worker's registry, where the orchestration
    executor reads them to check incoming orchestration versions and to pick
    the default version for sub-orchestrations.

    Args:
        version: The versioning options to use.

    Raises:
        RuntimeError: If the worker is already running.
    """
    if self._is_running:
        # The whole options object is being replaced, not only the default version.
        raise RuntimeError("Cannot set versioning options while the worker is running.")
    self._registry.versioning = version

def start(self):
"""Starts the worker on a background thread and begins listening for work items."""
if self._is_running:
Expand Down Expand Up @@ -513,6 +569,16 @@ def _execute_orchestrator(
customStatus=ph.get_string_value(result.encoded_custom_status),
completionToken=completionToken,
)
except pe.AbandonOrchestrationError:
self._logger.info(
f"Abandoning orchestration. InstanceId = '{req.instanceId}'. Completion token = '{completionToken}'"
)
stub.AbandonTaskOrchestratorWorkItem(
pb.AbandonOrchestrationTaskRequest(
completionToken=completionToken
)
)
return
except Exception as ex:
self._logger.exception(
f"An error occurred while trying to execute instance '{req.instanceId}': {ex}"
Expand Down Expand Up @@ -574,7 +640,7 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
_generator: Optional[Generator[task.Task, Any, Any]]
_previous_task: Optional[task.Task]

def __init__(self, instance_id: str):
def __init__(self, instance_id: str, registry: _Registry):
self._generator = None
self._is_replaying = True
self._is_complete = False
Expand All @@ -584,6 +650,7 @@ def __init__(self, instance_id: str):
self._sequence_number = 0
self._current_utc_datetime = datetime(1000, 1, 1)
self._instance_id = instance_id
self._registry = registry
self._completion_status: Optional[pb.OrchestrationStatus] = None
self._received_events: dict[str, list[Any]] = {}
self._pending_events: dict[str, list[task.CompletableTask]] = {}
Expand Down Expand Up @@ -646,7 +713,7 @@ def set_complete(
)
self._pending_actions[action.id] = action

def set_failed(self, ex: Exception):
def set_failed(self, ex: Union[Exception, pb.TaskFailureDetails]):
if self._is_complete:
return

Expand All @@ -658,7 +725,7 @@ def set_failed(self, ex: Exception):
self.next_sequence_number(),
pb.ORCHESTRATION_STATUS_FAILED,
None,
ph.new_failure_details(ex),
ph.new_failure_details(ex) if isinstance(ex, Exception) else ex,
)
self._pending_actions[action.id] = action

Expand Down Expand Up @@ -768,16 +835,20 @@ def call_sub_orchestrator(
input: Optional[TInput] = None,
instance_id: Optional[str] = None,
retry_policy: Optional[task.RetryPolicy] = None,
version: Optional[str] = None,
) -> task.Task[TOutput]:
id = self.next_sequence_number()
orchestrator_name = task.get_name(orchestrator)
default_version = self._registry.versioning.default_version if self._registry.versioning else None
orchestrator_version = version if version else default_version
self.call_activity_function_helper(
id,
orchestrator_name,
input=input,
retry_policy=retry_policy,
is_sub_orch=True,
instance_id=instance_id,
version=orchestrator_version
)
return self._pending_tasks.get(id, task.CompletableTask())

Expand All @@ -792,6 +863,7 @@ def call_activity_function_helper(
is_sub_orch: bool = False,
instance_id: Optional[str] = None,
fn_task: Optional[task.CompletableTask[TOutput]] = None,
version: Optional[str] = None,
):
if id is None:
id = self.next_sequence_number()
Expand All @@ -816,7 +888,7 @@ def call_activity_function_helper(
if not isinstance(activity_function, str):
raise ValueError("Orchestrator function name must be a string")
action = ph.new_create_sub_orchestration_action(
id, activity_function, instance_id, encoded_input
id, activity_function, instance_id, encoded_input, version
)
self._pending_actions[id] = action

Expand Down Expand Up @@ -892,7 +964,8 @@ def execute(
"The new history event list must have at least one event in it."
)

ctx = _RuntimeOrchestrationContext(instance_id)
ctx = _RuntimeOrchestrationContext(instance_id, self._registry)
version_failure = None
try:
# Rebuild local state by replaying old history into the orchestrator function
self._logger.debug(
Expand All @@ -902,6 +975,22 @@ def execute(
for old_event in old_events:
self.process_event(ctx, old_event)

# Process versioning if applicable
execution_started_events = [e.executionStarted for e in old_events if e.HasField("executionStarted")]
if self._registry.versioning and len(execution_started_events) > 0:
execution_started_event = execution_started_events[-1]
version_failure = self.evaluate_orchestration_versioning(
self._registry.versioning,
execution_started_event.version.value if execution_started_event.version else None,
)
if version_failure:
self._logger.warning(
f"Orchestration version did not meet worker versioning requirements. "
f"Error action = '{self._registry.versioning.failure_strategy}'. "
f"Version error = '{version_failure}'"
)
raise pe.VersionFailureException

# Get new actions by executing newly received events into the orchestrator function
if self._logger.level <= logging.DEBUG:
summary = _get_new_event_summary(new_events)
Expand All @@ -912,6 +1001,15 @@ def execute(
for new_event in new_events:
self.process_event(ctx, new_event)

except pe.VersionFailureException as ex:
if self._registry.versioning and self._registry.versioning.failure_strategy == VersionFailureStrategy.FAIL:
if version_failure:
ctx.set_failed(version_failure)
else:
ctx.set_failed(ex)
elif self._registry.versioning and self._registry.versioning.failure_strategy == VersionFailureStrategy.REJECT:
raise pe.AbandonOrchestrationError

except Exception as ex:
# Unhandled exceptions fail the orchestration
ctx.set_failed(ex)
Expand Down Expand Up @@ -1223,6 +1321,48 @@ def process_event(
# The orchestrator generator function completed
ctx.set_complete(generatorStopped.value, pb.ORCHESTRATION_STATUS_COMPLETED)

def evaluate_orchestration_versioning(
    self,
    versioning: Optional[VersioningOptions],
    orchestration_version: Optional[str],
) -> Optional[pb.TaskFailureDetails]:
    """Check an orchestration's version against the worker's versioning options.

    Args:
        versioning: The worker's versioning options, or ``None`` when
            versioning is not configured.
        orchestration_version: The version recorded for the orchestration.

    Returns:
        ``None`` when the orchestration is acceptable, otherwise a
        non-retriable ``TaskFailureDetails`` with error type
        ``"VersionMismatch"`` describing why it was refused.
    """
    if versioning is None:
        return None

    strategy = versioning.match_strategy
    if strategy == VersionMatchStrategy.NONE:
        # Version checking is disabled, so there is nothing to compare.
        return None

    if strategy == VersionMatchStrategy.STRICT:
        if self.compare_versions(orchestration_version, versioning.version) == 0:
            return None
        message = f"The orchestration version '{orchestration_version}' does not match the worker version '{versioning.version}'."
    elif strategy == VersionMatchStrategy.CURRENT_OR_OLDER:
        if self.compare_versions(orchestration_version, versioning.version) <= 0:
            return None
        message = f"The orchestration version '{orchestration_version}' is greater than the worker version '{versioning.version}'."
    else:
        # If there is a type of versioning we don't understand, it is better
        # to treat it as a versioning failure.
        # NOTE(review): an unset strategy (``None``) also lands here -- confirm
        # that this is intended rather than behaving like ``NONE``.
        message = f"The version match strategy '{strategy}' is unknown."

    return pb.TaskFailureDetails(
        errorType="VersionMismatch",
        errorMessage=message,
        isNonRetriable=True,
    )

def compare_versions(self, source_version: Optional[str], default_version: Optional[str]) -> int:
    """Three-way compare two version strings.

    A missing version (``None`` or empty) ranks below any non-empty version,
    and two missing versions are equal. Otherwise both strings are parsed
    with ``packaging.version.parse``; if that raises ``InvalidVersion`` the
    raw strings are compared instead.

    Args:
        source_version: The version being checked.
        default_version: The version to compare against.

    Returns:
        ``-1`` if ``source_version`` is lower, ``0`` if the two are equal,
        ``1`` if ``source_version`` is higher.
    """
    if not source_version:
        return -1 if default_version else 0
    if not default_version:
        return 1

    try:
        lhs = parse(source_version)
        rhs = parse(default_version)
    except InvalidVersion:
        # Not a parseable version: fall back to plain string ordering.
        if source_version == default_version:
            return 0
        return 1 if source_version > default_version else -1

    if lhs == rhs:
        return 0
    return 1 if lhs > rhs else -1


class _ActivityExecutor:
def __init__(self, registry: _Registry, logger: logging.Logger):
Expand Down
Loading
Loading