Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .github/workflows/durabletask-azuremanaged.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,15 @@ jobs:
python -m pip install --upgrade pip
pip install -r requirements.txt

- name: Install durabletask-azuremanaged locally
working-directory: durabletask-azuremanaged
run: |
pip install . --no-deps --force-reinstall

- name: Install durabletask locally
run: |
pip install . --no-deps --force-reinstall

- name: Run the tests
working-directory: tests/durabletask-azuremanaged
run: |
Expand Down
8 changes: 3 additions & 5 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,7 @@
],
"python.analysis.typeCheckingMode": "basic",
"python.testing.pytestArgs": [
"-v",
"--cov=durabletask/",
"--cov-report=lcov",
"tests/"
"tests"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
Expand All @@ -30,5 +27,6 @@
"jacoco.xml",
"coverage.cobertura.xml"
],
"makefile.configureOnOpen": false
"makefile.configureOnOpen": false,
"debugpy.debugJustMyCode": false
}
6 changes: 4 additions & 2 deletions durabletask-azuremanaged/durabletask/azuremanaged/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ def __init__(self, *,
host_address: str,
taskhub: str,
token_credential: Optional[TokenCredential],
secure_channel: bool = True):
secure_channel: bool = True,
default_version: Optional[str] = None):

if not taskhub:
raise ValueError("Taskhub value cannot be empty. Please provide a value for your taskhub")
Expand All @@ -30,4 +31,5 @@ def __init__(self, *,
host_address=host_address,
secure_channel=secure_channel,
metadata=None,
interceptors=interceptors)
interceptors=interceptors,
default_version=default_version)
4 changes: 2 additions & 2 deletions durabletask/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

"""Durable Task SDK for Python"""

from durabletask.worker import ConcurrencyOptions
from durabletask.worker import ConcurrencyOptions, VersioningOptions

__all__ = ["ConcurrencyOptions"]
__all__ = ["ConcurrencyOptions", "VersioningOptions"]

PACKAGE_NAME = "durabletask"
11 changes: 7 additions & 4 deletions durabletask/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ def __init__(self, *,
log_handler: Optional[logging.Handler] = None,
log_formatter: Optional[logging.Formatter] = None,
secure_channel: bool = False,
interceptors: Optional[Sequence[shared.ClientInterceptor]] = None):
interceptors: Optional[Sequence[shared.ClientInterceptor]] = None,
default_version: Optional[str] = None):

# If the caller provided metadata, we need to create a new interceptor for it and
# add it to the list of interceptors.
Expand All @@ -118,13 +119,15 @@ def __init__(self, *,
)
self._stub = stubs.TaskHubSidecarServiceStub(channel)
self._logger = shared.get_logger("client", log_handler, log_formatter)
self.default_version = default_version

def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *,
input: Optional[TInput] = None,
instance_id: Optional[str] = None,
start_at: Optional[datetime] = None,
reuse_id_policy: Optional[pb.OrchestrationIdReusePolicy] = None,
tags: Optional[dict[str, str]] = None) -> str:
tags: Optional[dict[str, str]] = None,
version: Optional[str] = None) -> str:

name = orchestrator if isinstance(orchestrator, str) else task.get_name(orchestrator)

Expand All @@ -133,9 +136,9 @@ def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInpu
instanceId=instance_id if instance_id else uuid.uuid4().hex,
input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input is not None else None,
scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None,
version=wrappers_pb2.StringValue(value=""),
version=helpers.get_string_value(version if version else self.default_version),
orchestrationIdReusePolicy=reuse_id_policy,
tags=tags,
tags=tags
)

self._logger.info(f"Starting new '{name}' instance with ID = '{req.instanceId}'.")
Expand Down
6 changes: 4 additions & 2 deletions durabletask/internal/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,11 +199,13 @@ def new_create_sub_orchestration_action(
id: int,
name: str,
instance_id: Optional[str],
encoded_input: Optional[str]) -> pb.OrchestratorAction:
encoded_input: Optional[str],
version: Optional[str]) -> pb.OrchestratorAction:
return pb.OrchestratorAction(id=id, createSubOrchestration=pb.CreateSubOrchestrationAction(
name=name,
instanceId=instance_id,
input=get_string_value(encoded_input)
input=get_string_value(encoded_input),
version=get_string_value(version)
))


Expand Down
3 changes: 2 additions & 1 deletion durabletask/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ def call_activity(self, activity: Union[Activity[TInput, TOutput], str], *,
def call_sub_orchestrator(self, orchestrator: Orchestrator[TInput, TOutput], *,
input: Optional[TInput] = None,
instance_id: Optional[str] = None,
retry_policy: Optional[RetryPolicy] = None) -> Task[TOutput]:
retry_policy: Optional[RetryPolicy] = None,
version: Optional[str] = None) -> Task[TOutput]:
"""Schedule sub-orchestrator function for execution.

Parameters
Expand Down
136 changes: 133 additions & 3 deletions durabletask/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
from datetime import datetime, timedelta
from threading import Event, Thread
from types import GeneratorType
from enum import Enum
from typing import Any, Generator, Optional, Sequence, TypeVar, Union
from packaging.version import InvalidVersion, parse

import grpc
from google.protobuf import empty_pb2
Expand Down Expand Up @@ -72,9 +74,60 @@ def __init__(
)


class VersionMatchStrategy(Enum):
"""Enumeration for version matching strategies."""

NONE = 1
STRICT = 2
CURRENT_OR_OLDER = 3


class VersionFailureStrategy(Enum):
"""Enumeration for version failure strategies."""

REJECT = 1
FAIL = 2


class VersionFailureException(Exception):
pass


class VersioningOptions:
"""Configuration options for orchestrator and activity versioning.

This class provides options to control how versioning is handled for orchestrators
and activities, including whether to use the default version and how to compare versions.
"""

version: Optional[str] = None
default_version: Optional[str] = None
match_strategy: Optional[VersionMatchStrategy] = None
failure_strategy: Optional[VersionFailureStrategy] = None

def __init__(self, version: Optional[str] = None,
default_version: Optional[str] = None,
match_strategy: Optional[VersionMatchStrategy] = None,
failure_strategy: Optional[VersionFailureStrategy] = None
):
"""Initialize versioning options.

Args:
version: The specific version to use for orchestrators and activities.
default_version: The default version to use if no specific version is provided.
match_strategy: The strategy to use for matching versions.
failure_strategy: The strategy to use if versioning fails.
"""
self.version = version
self.default_version = default_version
self.match_strategy = match_strategy
self.failure_strategy = failure_strategy


class _Registry:
orchestrators: dict[str, task.Orchestrator]
activities: dict[str, task.Activity]
versioning: Optional[VersioningOptions] = None

def __init__(self):
self.orchestrators = {}
Expand Down Expand Up @@ -279,6 +332,12 @@ def add_activity(self, fn: task.Activity) -> str:
)
return self._registry.add_activity(fn)

def use_versioning(self, version: VersioningOptions) -> None:
"""Sets the default version for orchestrators and activities."""
if self._is_running:
raise RuntimeError("Cannot set default version while the worker is running.")
self._registry.versioning = version

def start(self):
"""Starts the worker on a background thread and begins listening for work items."""
if self._is_running:
Expand Down Expand Up @@ -646,7 +705,7 @@ def set_complete(
)
self._pending_actions[action.id] = action

def set_failed(self, ex: Exception):
def set_failed(self, ex: Union[Exception, pb.TaskFailureDetails]):
if self._is_complete:
return

Expand All @@ -658,7 +717,7 @@ def set_failed(self, ex: Exception):
self.next_sequence_number(),
pb.ORCHESTRATION_STATUS_FAILED,
None,
ph.new_failure_details(ex),
ph.new_failure_details(ex) if isinstance(ex, Exception) else ex,
)
self._pending_actions[action.id] = action

Expand Down Expand Up @@ -768,6 +827,7 @@ def call_sub_orchestrator(
input: Optional[TInput] = None,
instance_id: Optional[str] = None,
retry_policy: Optional[task.RetryPolicy] = None,
version: Optional[str] = None,
) -> task.Task[TOutput]:
id = self.next_sequence_number()
orchestrator_name = task.get_name(orchestrator)
Expand All @@ -778,6 +838,7 @@ def call_sub_orchestrator(
retry_policy=retry_policy,
is_sub_orch=True,
instance_id=instance_id,
version=version,
)
return self._pending_tasks.get(id, task.CompletableTask())

Expand All @@ -792,6 +853,7 @@ def call_activity_function_helper(
is_sub_orch: bool = False,
instance_id: Optional[str] = None,
fn_task: Optional[task.CompletableTask[TOutput]] = None,
version: Optional[str] = None,
):
if id is None:
id = self.next_sequence_number()
Expand All @@ -816,7 +878,7 @@ def call_activity_function_helper(
if not isinstance(activity_function, str):
raise ValueError("Orchestrator function name must be a string")
action = ph.new_create_sub_orchestration_action(
id, activity_function, instance_id, encoded_input
id, activity_function, instance_id, encoded_input, version
)
self._pending_actions[id] = action

Expand Down Expand Up @@ -893,7 +955,27 @@ def execute(
)

ctx = _RuntimeOrchestrationContext(instance_id)
version_failure = None
try:
execution_started_events = [e.executionStarted for e in old_events if e.HasField("executionStarted")]
if self._registry.versioning and len(execution_started_events) > 0:
execution_started_event = execution_started_events[-1]
version_failure = self.evaluate_orchestration_versioning(
self._registry.versioning,
execution_started_event.version.value if execution_started_event.version else None,
)
if version_failure:
self._logger.warning(
f"Orchestration version did not meet worker versioning requirements. "
f"Error action = '{self._registry.versioning.failure_strategy}'. "
f"Version error = '{version_failure}'"
)
if self._registry.versioning.failure_strategy == VersionFailureStrategy.FAIL:
raise VersionFailureException
elif self._registry.versioning.failure_strategy == VersionFailureStrategy.REJECT:
# TODO: We don't have abandoned orchestrations yet, so we just fail
raise VersionFailureException

# Rebuild local state by replaying old history into the orchestrator function
self._logger.debug(
f"{instance_id}: Rebuilding local state with {len(old_events)} history event..."
Expand All @@ -912,6 +994,12 @@ def execute(
for new_event in new_events:
self.process_event(ctx, new_event)

except VersionFailureException as ex:
if version_failure:
ctx.set_failed(version_failure)
else:
ctx.set_failed(ex)

except Exception as ex:
# Unhandled exceptions fail the orchestration
ctx.set_failed(ex)
Expand Down Expand Up @@ -1223,6 +1311,48 @@ def process_event(
# The orchestrator generator function completed
ctx.set_complete(generatorStopped.value, pb.ORCHESTRATION_STATUS_COMPLETED)

def evaluate_orchestration_versioning(self, versioning: Optional[VersioningOptions], orchestration_version: Optional[str]) -> Optional[pb.TaskFailureDetails]:
if versioning is None:
return None
version_comparison = self.compare_versions(orchestration_version, versioning.version)
if versioning.match_strategy == VersionMatchStrategy.NONE:
return None
elif versioning.match_strategy == VersionMatchStrategy.STRICT:
if version_comparison != 0:
return pb.TaskFailureDetails(
errorType="VersionMismatch",
errorMessage=f"The orchestration version '{orchestration_version}' does not match the worker version '{versioning.version}'.",
isNonRetriable=True,
)
elif versioning.match_strategy == VersionMatchStrategy.CURRENT_OR_OLDER:
if version_comparison > 0:
return pb.TaskFailureDetails(
errorType="VersionMismatch",
errorMessage=f"The orchestration version '{orchestration_version}' is greater than the worker version '{versioning.version}'.",
isNonRetriable=True,
)
else:
# If there is a type of versioning we don't understand, it is better to treat it as a versioning failure.
return pb.TaskFailureDetails(
errorType="VersionMismatch",
errorMessage=f"The version match strategy '{versioning.match_strategy}' is unknown.",
isNonRetriable=True,
)

def compare_versions(self, source_version: Optional[str], default_version: Optional[str]) -> int:
if not source_version and not default_version:
return 0
if not source_version:
return -1
if not default_version:
return 1
try:
source_version_parsed = parse(source_version)
default_version_parsed = parse(default_version)
return (source_version_parsed > default_version_parsed) - (source_version_parsed < default_version_parsed)
except InvalidVersion:
return (source_version > default_version) - (source_version < default_version)


class _ActivityExecutor:
def __init__(self, registry: _Registry, logger: logging.Logger):
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ readme = "README.md"
dependencies = [
"grpcio",
"protobuf",
"asyncio"
"asyncio",
"packaging"
]

[project.urls]
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ protobuf
pytest
pytest-cov
azure-identity
asyncio
asyncio
packaging
Loading
Loading