From 03d27bda4e9360efee0edfdcd2f0f56dc6adb6c8 Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Thu, 29 May 2025 16:04:31 -0700 Subject: [PATCH 01/18] Reconnect upon connection error --- CHANGELOG.md | 14 +- durabletask-azuremanaged/pyproject.toml | 6 +- durabletask/worker.py | 163 ++++++++++++++++++++---- pyproject.toml | 2 +- 4 files changed, 151 insertions(+), 34 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 13b0e69..89aaf79 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,13 +5,23 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## v0.2.0 (Unreleased) +## v0.3.0 + +### New + +- Added configurable worker concurrency with `max_workers` parameter in `TaskHubGrpcWorker` constructor - allows customization of ThreadPoolExecutor size (default: 16 workers) + +### Fixed + +- Fixed an issue where a worker could not recover after its connection was interrupted or severed + +## v0.2.1 ### New - Added `set_custom_status` orchestrator API ([#31](https://github.com/microsoft/durabletask-python/pull/31)) - contributed by [@famarting](https://github.com/famarting) - Added `purge_orchestration` client API ([#34](https://github.com/microsoft/durabletask-python/pull/34)) - contributed by [@famarting](https://github.com/famarting) -- Added new `durabletask-azuremanaged` package for use with the [Durable Task Scheduler](https://techcommunity.microsoft.com/blog/appsonazureblog/announcing-limited-early-access-of-the-durable-task-scheduler-for-azure-durable-/4286526) - by [@RyanLettieri](https://github.com/RyanLettieri) +- Added new `durabletask-azuremanaged` package for use with the [Durable Task Scheduler](https://learn.microsoft.com/azure/azure-functions/durable/durable-task-scheduler/durable-task-scheduler) - by [@RyanLettieri](https://github.com/RyanLettieri) ### Changes diff 
--git a/durabletask-azuremanaged/pyproject.toml b/durabletask-azuremanaged/pyproject.toml index 5962285..3de2b53 100644 --- a/durabletask-azuremanaged/pyproject.toml +++ b/durabletask-azuremanaged/pyproject.toml @@ -9,8 +9,8 @@ build-backend = "setuptools.build_meta" [project] name = "durabletask.azuremanaged" -version = "0.1.4" -description = "Extensions for the Durable Task Python SDK for integrating with the Durable Task Scheduler in Azure" +version = "0.1.5" +description = "Durable Task Python SDK provider implementation for the Azure Durable Task Scheduler" keywords = [ "durable", "task", @@ -26,7 +26,7 @@ requires-python = ">=3.9" license = {file = "LICENSE"} readme = "README.md" dependencies = [ - "durabletask>=0.2.1", + "durabletask>=0.3.0", "azure-identity>=1.19.0" ] diff --git a/durabletask/worker.py b/durabletask/worker.py index 2c31e52..af6385e 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -3,6 +3,7 @@ import concurrent.futures import logging +import random from datetime import datetime, timedelta from threading import Event, Thread from types import GeneratorType @@ -11,8 +12,8 @@ import grpc from google.protobuf import empty_pb2 -import durabletask.internal.helpers as ph import durabletask.internal.helpers as pbh +import durabletask.internal.helpers as ph import durabletask.internal.orchestrator_service_pb2 as pb import durabletask.internal.orchestrator_service_pb2_grpc as stubs import durabletask.internal.shared as shared @@ -91,13 +92,15 @@ def __init__(self, *, log_handler=None, log_formatter: Optional[logging.Formatter] = None, secure_channel: bool = False, - interceptors: Optional[Sequence[shared.ClientInterceptor]] = None): + interceptors: Optional[Sequence[shared.ClientInterceptor]] = None, + max_workers: Optional[int] = None): self._registry = _Registry() self._host_address = host_address if host_address else shared.get_default_host_address() self._logger = shared.get_logger("worker", log_handler, log_formatter) 
self._shutdown = Event() self._is_running = False self._secure_channel = secure_channel + self._max_workers = max_workers if max_workers is not None else 16 # Determine the interceptors to use if interceptors is not None: @@ -129,31 +132,117 @@ def add_activity(self, fn: task.Activity) -> str: def start(self): """Starts the worker on a background thread and begins listening for work items.""" - channel = shared.get_grpc_channel(self._host_address, self._secure_channel, self._interceptors) - stub = stubs.TaskHubSidecarServiceStub(channel) - if self._is_running: raise RuntimeError('The worker is already running.') def run_loop(): + """Enhanced run loop with better connection management and retry logic.""" + + # Connection state management for retry fix + current_channel: Optional[grpc.Channel] = None + current_stub: Optional[stubs.TaskHubSidecarServiceStub] = None + conn_retry_count = 0 + conn_max_retry_delay = 60 + + def create_fresh_connection() -> None: + """Create a new gRPC channel and stub, invalidating any existing ones. + + Raises: + Exception: If connection creation or testing fails. 
+ """ + nonlocal current_channel, current_stub, conn_retry_count + + # Close existing connection if any + if current_channel: + try: + current_channel.close() + except Exception: + pass + + current_channel = None + current_stub = None + + try: + # Create new connection + current_channel = shared.get_grpc_channel(self._host_address, self._secure_channel, self._interceptors) + current_stub = stubs.TaskHubSidecarServiceStub(current_channel) + + # Test the connection + current_stub.Hello(empty_pb2.Empty()) + conn_retry_count = 0 # Reset on successful connection + self._logger.debug(f"Created fresh connection to {self._host_address}") + + except Exception as e: + self._logger.debug(f"Failed to create connection: {e}") + current_channel = None + current_stub = None + raise # Re-raise the original exception + + def invalidate_connection() -> None: + """Mark current connection as invalid.""" + nonlocal current_channel, current_stub + if current_channel: + try: + current_channel.close() + except Exception: + pass + current_channel = None + current_stub = None + + def should_invalidate_connection(rpc_error: grpc.RpcError) -> bool: + """Determine if a gRPC error should trigger connection invalidation. + + Connection-level errors (network, authentication, server unavailable) + should invalidate the connection, while application-level errors + (bad requests, not found, etc.) should not. 
+ """ + error_code = rpc_error.code() # type: ignore + + # Connection-level errors that warrant invalidation + connection_level_errors = { + grpc.StatusCode.UNAVAILABLE, # Server down/unreachable + grpc.StatusCode.DEADLINE_EXCEEDED, # Timeout, likely network issue + grpc.StatusCode.CANCELLED, # Connection cancelled + grpc.StatusCode.UNAUTHENTICATED, # Auth failed, may need new connection + grpc.StatusCode.ABORTED, # Transaction aborted, connection may be bad + } + + return error_code in connection_level_errors + # TODO: Investigate whether asyncio could be used to enable greater concurrency for async activity # functions. We'd need to know ahead of time whether a function is async or not. - # TODO: Max concurrency configuration settings - with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor: + with concurrent.futures.ThreadPoolExecutor(max_workers=self._max_workers, thread_name_prefix="DurableTask") as executor: while not self._shutdown.is_set(): - try: - # send a "Hello" message to the sidecar to ensure that it's listening - stub.Hello(empty_pb2.Empty()) + # Ensure we have a valid connection before attempting work + if current_stub is None: + try: + create_fresh_connection() + except Exception: + # Connection failed, implement exponential backoff + conn_retry_count += 1 + delay = min(conn_max_retry_delay, (2 ** min(conn_retry_count, 6)) + random.uniform(0, 1)) + self._logger.warning(f'Connection failed, retrying in {delay:.2f} seconds (attempt {conn_retry_count})') + if self._shutdown.wait(delay): + break # Shutdown requested during wait + continue - # stream work items + try: + # Stream work items with the current connection + # Type assertion since we know current_stub is not None at this point + assert current_stub is not None, "current_stub should not be None at this point" + stub = current_stub # Local reference for type safety self._response_stream = stub.GetWorkItems(pb.GetWorkItemsRequest()) self._logger.info(f'Successfully connected to 
{self._host_address}. Waiting for work items...') - # The stream blocks until either a work item is received or the stream is canceled - # by another thread (see the stop() method). + # Process work items concurrently as they arrive for work_item in self._response_stream: # type: ignore + if self._shutdown.is_set(): + break + request_type = work_item.WhichOneof('request') self._logger.debug(f'Received "{request_type}" work item') + + # Submit work items to thread pool for concurrent processing if work_item.HasField('orchestratorRequest'): executor.submit(self._execute_orchestrator, work_item.orchestratorRequest, stub, work_item.completionToken) elif work_item.HasField('activityRequest'): @@ -163,21 +252,39 @@ def run_loop(): else: self._logger.warning(f'Unexpected work item type: {request_type}') + # Stream ended normally (shouldn't happen unless server closes) + self._logger.info("Work item stream ended normally") + except grpc.RpcError as rpc_error: - if rpc_error.code() == grpc.StatusCode.CANCELLED: # type: ignore + # Intelligently decide whether to invalidate connection based on error type + should_invalidate = should_invalidate_connection(rpc_error) + if should_invalidate: + invalidate_connection() + + error_code = rpc_error.code() # type: ignore + if error_code == grpc.StatusCode.CANCELLED: self._logger.info(f'Disconnected from {self._host_address}') - elif rpc_error.code() == grpc.StatusCode.UNAVAILABLE: # type: ignore - self._logger.warning( - f'The sidecar at address {self._host_address} is unavailable - will continue retrying') + break # Likely shutdown + elif error_code == grpc.StatusCode.UNAVAILABLE: + self._logger.warning(f'The sidecar at address {self._host_address} is unavailable - will continue retrying') + elif should_invalidate: + self._logger.warning(f'Connection-level gRPC error ({error_code}): {rpc_error} - invalidating connection') else: - self._logger.warning(f'Unexpected error: {rpc_error}') + self._logger.warning(f'Application-level gRPC 
error ({error_code}): {rpc_error} - keeping connection') + + # Brief pause before retry + self._shutdown.wait(1) + except Exception as ex: + # Unexpected error, invalidate connection and retry + invalidate_connection() self._logger.warning(f'Unexpected error: {ex}') + self._shutdown.wait(1) - # CONSIDER: exponential backoff - self._shutdown.wait(5) - self._logger.info("No longer listening for work items") - return + # Final cleanup + invalidate_connection() + + self._logger.info("No longer listening for work items") self._logger.info(f"Starting gRPC worker that connects to {self._host_address}") self._runLoop = Thread(target=run_loop) @@ -367,14 +474,14 @@ def instance_id(self) -> str: def current_utc_datetime(self) -> datetime: return self._current_utc_datetime - @property - def is_replaying(self) -> bool: - return self._is_replaying - @current_utc_datetime.setter def current_utc_datetime(self, value: datetime): self._current_utc_datetime = value + @property + def is_replaying(self) -> bool: + return self._is_replaying + def set_custom_status(self, custom_status: Any) -> None: self._encoded_custom_status = shared.to_json(custom_status) if custom_status is not None else None @@ -389,7 +496,7 @@ def create_timer_internal(self, fire_at: Union[datetime, timedelta], action = ph.new_create_timer_action(id, fire_at) self._pending_actions[id] = action - timer_task = task.TimerTask() + timer_task: task.TimerTask = task.TimerTask() if retryable_task is not None: timer_task.set_retryable_parent(retryable_task) self._pending_tasks[id] = timer_task @@ -457,7 +564,7 @@ def wait_for_external_event(self, name: str) -> task.Task: # event with the given name so that we can resume the generator when it # arrives. If there are multiple events with the same name, we return # them in the order they were received. 
- external_event_task = task.CompletableTask() + external_event_task: task.CompletableTask = task.CompletableTask() event_name = name.casefold() event_list = self._received_events.get(event_name, None) if event_list: diff --git a/pyproject.toml b/pyproject.toml index 60a9d37..1491988 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta" [project] name = "durabletask" -version = "0.2.1" +version = "0.3.0" description = "A Durable Task Client SDK for Python" keywords = [ "durable", From e620020d9948c818cc155beb80322ffa18bf7293 Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Fri, 30 May 2025 12:15:58 -0700 Subject: [PATCH 02/18] concurrency --- CHANGELOG.md | 13 +++++++- durabletask/__init__.py | 3 ++ durabletask/worker.py | 69 ++++++++++++++++++++++++++++++++++++++--- examples/README.md | 2 +- 4 files changed, 81 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 89aaf79..21108d9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,11 +5,22 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 
+## v0.4.0 + +### New + +- Added `ConcurrencyOptions` class for fine-grained concurrency control equivalent to .NET DurableTask SDK +- Enhanced `TaskHubGrpcWorker` with `concurrency_options` parameter supporting separate limits for activities, orchestrations, and entities (100 * processor_count default each) + +### Breaking Changes + +- Removed `max_workers` parameter from `TaskHubGrpcWorker` constructor in favor of the more flexible `ConcurrencyOptions` approach + ## v0.3.0 ### New -- Added configurable worker concurrency with `max_workers` parameter in `TaskHubGrpcWorker` constructor - allows customization of ThreadPoolExecutor size (default: 16 workers) +- Added configurable worker concurrency with `max_workers` parameter in `TaskHubGrpcWorker` constructor - allows customization of ThreadPoolExecutor size (default: 16 workers) (NOTE: This parameter was later removed in v0.4.0 in favor of `ConcurrencyOptions`) ### Fixed diff --git a/durabletask/__init__.py b/durabletask/__init__.py index a37823c..88af82b 100644 --- a/durabletask/__init__.py +++ b/durabletask/__init__.py @@ -3,5 +3,8 @@ """Durable Task SDK for Python""" +from durabletask.worker import ConcurrencyOptions + +__all__ = ["ConcurrencyOptions"] PACKAGE_NAME = "durabletask" diff --git a/durabletask/worker.py b/durabletask/worker.py index af6385e..c57695e 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -3,6 +3,7 @@ import concurrent.futures import logging +import os import random from datetime import datetime, timedelta from threading import Event, Thread @@ -24,6 +25,53 @@ TOutput = TypeVar('TOutput') +class ConcurrencyOptions: + """Configuration options for controlling concurrency of different work item types. + + This class mirrors the .NET DurableTask SDK's ConcurrencyOptions class, + providing fine-grained control over concurrent processing limits for + activities, orchestrations, and entities. 
+ """ + + def __init__(self, + maximum_concurrent_activity_work_items: Optional[int] = None, + maximum_concurrent_orchestration_work_items: Optional[int] = None): + """Initialize concurrency options. + + Args: + maximum_concurrent_activity_work_items: Maximum number of activity work items + that can be processed concurrently. Defaults to 100 * processor_count. + maximum_concurrent_orchestration_work_items: Maximum number of orchestration work items + that can be processed concurrently. Defaults to 100 * processor_count. + """ + processor_count = os.cpu_count() or 1 + default_concurrency = 100 * processor_count + + self.maximum_concurrent_activity_work_items = ( + maximum_concurrent_activity_work_items + if maximum_concurrent_activity_work_items is not None + else default_concurrency + ) + + self.maximum_concurrent_orchestration_work_items = ( + maximum_concurrent_orchestration_work_items + if maximum_concurrent_orchestration_work_items is not None + else default_concurrency + ) + + @property + def max_total_workers(self) -> int: + """Calculate the maximum total workers needed for the thread pool. + + Since Python's ThreadPoolExecutor doesn't differentiate between work item types, + we use the maximum of all concurrency limits to ensure we have enough workers. 
+ """ + return max( + self.maximum_concurrent_activity_work_items, + self.maximum_concurrent_orchestration_work_items, + ) + + class _Registry: orchestrators: dict[str, task.Orchestrator] @@ -93,14 +141,16 @@ def __init__(self, *, log_formatter: Optional[logging.Formatter] = None, secure_channel: bool = False, interceptors: Optional[Sequence[shared.ClientInterceptor]] = None, - max_workers: Optional[int] = None): + concurrency_options: Optional[ConcurrencyOptions] = None): self._registry = _Registry() self._host_address = host_address if host_address else shared.get_default_host_address() self._logger = shared.get_logger("worker", log_handler, log_formatter) self._shutdown = Event() self._is_running = False self._secure_channel = secure_channel - self._max_workers = max_workers if max_workers is not None else 16 + + # Use provided concurrency options or create default ones + self._concurrency_options = concurrency_options if concurrency_options is not None else ConcurrencyOptions() # Determine the interceptors to use if interceptors is not None: @@ -112,6 +162,11 @@ def __init__(self, *, else: self._interceptors = None + @property + def concurrency_options(self) -> ConcurrencyOptions: + """Get the current concurrency options for this worker.""" + return self._concurrency_options + def __enter__(self): return self @@ -211,7 +266,7 @@ def should_invalidate_connection(rpc_error: grpc.RpcError) -> bool: # TODO: Investigate whether asyncio could be used to enable greater concurrency for async activity # functions. We'd need to know ahead of time whether a function is async or not. 
- with concurrent.futures.ThreadPoolExecutor(max_workers=self._max_workers, thread_name_prefix="DurableTask") as executor: + with concurrent.futures.ThreadPoolExecutor(max_workers=self._concurrency_options.max_total_workers, thread_name_prefix="DurableTask") as executor: while not self._shutdown.is_set(): # Ensure we have a valid connection before attempting work if current_stub is None: @@ -231,7 +286,13 @@ def should_invalidate_connection(rpc_error: grpc.RpcError) -> bool: # Type assertion since we know current_stub is not None at this point assert current_stub is not None, "current_stub should not be None at this point" stub = current_stub # Local reference for type safety - self._response_stream = stub.GetWorkItems(pb.GetWorkItemsRequest()) + + # Create GetWorkItemsRequest with concurrency limits + get_work_items_request = pb.GetWorkItemsRequest( + maxConcurrentOrchestrationWorkItems=self._concurrency_options.maximum_concurrent_orchestration_work_items, + maxConcurrentActivityWorkItems=self._concurrency_options.maximum_concurrent_activity_work_items + ) + self._response_stream = stub.GetWorkItems(get_work_items_request) self._logger.info(f'Successfully connected to {self._host_address}. Waiting for work items...') # Process work items concurrently as they arrive diff --git a/examples/README.md b/examples/README.md index 7cfbc7a..404b127 100644 --- a/examples/README.md +++ b/examples/README.md @@ -24,4 +24,4 @@ In some cases, the sample may require command-line parameters or user inputs. In - [Activity sequence](./activity_sequence.py): Orchestration that schedules three activity calls in a sequence. - [Fan-out/fan-in](./fanout_fanin.py): Orchestration that schedules a dynamic number of activity calls in parallel, waits for all of them to complete, and then performs an aggregation on the results. -- [Human interaction](./human_interaction.py): Orchestration that waits for a human to approve an order before continuing. 
\ No newline at end of file +- [Human interaction](./human_interaction.py): Orchestration that waits for a human to approve an order before continuing. From 8d0fe6f6c239cd48f72b6d2b8ebefcfa35aca4b9 Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Fri, 30 May 2025 13:11:20 -0700 Subject: [PATCH 03/18] Test updates --- tests/durabletask/test_client.py | 7 +- tests/durabletask/test_concurrency_options.py | 127 ++++++++++++++++++ 2 files changed, 131 insertions(+), 3 deletions(-) create mode 100644 tests/durabletask/test_concurrency_options.py diff --git a/tests/durabletask/test_client.py b/tests/durabletask/test_client.py index 64bbec8..e750134 100644 --- a/tests/durabletask/test_client.py +++ b/tests/durabletask/test_client.py @@ -1,13 +1,14 @@ -from unittest.mock import patch, ANY +from unittest.mock import ANY, patch +from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl from durabletask.internal.shared import (get_default_host_address, get_grpc_channel) -from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl HOST_ADDRESS = 'localhost:50051' METADATA = [('key1', 'value1'), ('key2', 'value2')] INTERCEPTORS = [DefaultClientInterceptorImpl(METADATA)] + def test_get_grpc_channel_insecure(): with patch('grpc.insecure_channel') as mock_channel: get_grpc_channel(HOST_ADDRESS, False, interceptors=INTERCEPTORS) @@ -85,4 +86,4 @@ def test_grpc_channel_with_host_name_protocol_stripping(): prefix = "" get_grpc_channel(prefix + host_name, True, interceptors=INTERCEPTORS) - mock_secure_channel.assert_called_with(host_name, ANY) \ No newline at end of file + mock_secure_channel.assert_called_with(host_name, ANY) diff --git a/tests/durabletask/test_concurrency_options.py b/tests/durabletask/test_concurrency_options.py new file mode 100644 index 0000000..d963d92 --- /dev/null +++ b/tests/durabletask/test_concurrency_options.py @@ -0,0 +1,127 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import os + +from durabletask import ConcurrencyOptions +from durabletask.worker import TaskHubGrpcWorker + + +def test_default_concurrency_options(): + """Test that default concurrency options work correctly.""" + options = ConcurrencyOptions() + processor_count = os.cpu_count() or 1 + expected_default = 100 * processor_count + + assert options.maximum_concurrent_activity_work_items == expected_default + assert options.maximum_concurrent_orchestration_work_items == expected_default + assert options.max_total_workers == expected_default + + +def test_custom_concurrency_options(): + """Test that custom concurrency options work correctly.""" + options = ConcurrencyOptions( + maximum_concurrent_activity_work_items=50, + maximum_concurrent_orchestration_work_items=25, + ) + + assert options.maximum_concurrent_activity_work_items == 50 + assert options.maximum_concurrent_orchestration_work_items == 25 + assert options.max_total_workers == 50 # Max of both values + + +def test_partial_custom_options(): + """Test that partially specified options use defaults for unspecified values.""" + processor_count = os.cpu_count() or 1 + expected_default = 100 * processor_count + + options = ConcurrencyOptions( + maximum_concurrent_activity_work_items=30 + # Leave other options as default + ) + + assert options.maximum_concurrent_activity_work_items == 30 + assert options.maximum_concurrent_orchestration_work_items == expected_default + assert ( + options.max_total_workers == expected_default + ) # Should be the default since it's larger + + +def test_max_total_workers_calculation(): + """Test that max_total_workers returns the maximum of all concurrency limits.""" + # Case 1: Activity is highest + options1 = ConcurrencyOptions( + maximum_concurrent_activity_work_items=100, + maximum_concurrent_orchestration_work_items=50, + ) + assert options1.max_total_workers == 100 + + # Case 2: Orchestration is highest + options2 = ConcurrencyOptions( + 
maximum_concurrent_activity_work_items=25, + maximum_concurrent_orchestration_work_items=100, + ) + assert options2.max_total_workers == 100 + + +def test_worker_with_concurrency_options(): + """Test that TaskHubGrpcWorker accepts concurrency options.""" + options = ConcurrencyOptions( + maximum_concurrent_activity_work_items=10, + maximum_concurrent_orchestration_work_items=20, + ) + + worker = TaskHubGrpcWorker(concurrency_options=options) + + assert worker.concurrency_options == options + + +def test_worker_default_options(): + """Test that TaskHubGrpcWorker uses default options when no parameters are provided.""" + worker = TaskHubGrpcWorker() + + processor_count = os.cpu_count() or 1 + expected_default = 100 * processor_count + + assert ( + worker.concurrency_options.maximum_concurrent_activity_work_items == expected_default + ) + assert ( + worker.concurrency_options.maximum_concurrent_orchestration_work_items == expected_default + ) + + +def test_concurrency_options_property_access(): + """Test that the concurrency_options property works correctly.""" + options = ConcurrencyOptions( + maximum_concurrent_activity_work_items=15, + maximum_concurrent_orchestration_work_items=25, + ) + + worker = TaskHubGrpcWorker(concurrency_options=options) + retrieved_options = worker.concurrency_options + + # Should be the same object + assert retrieved_options is options + + # Should have correct values + assert retrieved_options.maximum_concurrent_activity_work_items == 15 + assert retrieved_options.maximum_concurrent_orchestration_work_items == 25 + + +def test_edge_cases(): + """Test edge cases like zero or very large values.""" + # Test with zeros (should still work) + options_zero = ConcurrencyOptions( + maximum_concurrent_activity_work_items=0, + maximum_concurrent_orchestration_work_items=0, + ) + assert options_zero.max_total_workers == 0 + + # Test with very large values + options_large = ConcurrencyOptions( + maximum_concurrent_activity_work_items=999999, + 
maximum_concurrent_orchestration_work_items=1, + ) + assert options_large.max_total_workers == 999999 + assert options_large.max_total_workers == 999999 From 194b24ed68da51acb0294aaf097974c0a999119a Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Fri, 30 May 2025 15:30:02 -0700 Subject: [PATCH 04/18] More updates --- CHANGELOG.md | 13 +- durabletask/worker.py | 838 ++++++++++++------ tests/durabletask/test_concurrency_options.py | 50 +- .../test_worker_concurrency_loop.py | 135 +++ .../test_worker_concurrency_loop_async.py | 79 ++ 5 files changed, 774 insertions(+), 341 deletions(-) create mode 100644 tests/durabletask/test_worker_concurrency_loop.py create mode 100644 tests/durabletask/test_worker_concurrency_loop_async.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 21108d9..6921faa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,22 +5,11 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## v0.4.0 - -### New - -- Added `ConcurrencyOptions` class for fine-grained concurrency control equivalent to .NET DurableTask SDK -- Enhanced `TaskHubGrpcWorker` with `concurrency_options` parameter supporting separate limits for activities, orchestrations, and entities (100 * processor_count default each) - -### Breaking Changes - -- Removed `max_workers` parameter from `TaskHubGrpcWorker` constructor in favor of the more flexible `ConcurrencyOptions` approach - ## v0.3.0 ### New -- Added configurable worker concurrency with `max_workers` parameter in `TaskHubGrpcWorker` constructor - allows customization of ThreadPoolExecutor size (default: 16 workers) (NOTE: This parameter was later removed in v0.4.0 in favor of `ConcurrencyOptions`) +- Added `ConcurrencyOptions` class for fine-grained concurrency control with separate limits for activities and orchestrations. 
The thread pool worker count can also be configured. ### Fixed diff --git a/durabletask/worker.py b/durabletask/worker.py index c57695e..8c0abc1 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -1,10 +1,12 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import concurrent.futures +import asyncio +import inspect import logging import os import random +from concurrent.futures import ThreadPoolExecutor from datetime import datetime, timedelta from threading import Event, Thread from types import GeneratorType @@ -21,8 +23,8 @@ from durabletask import task from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl -TInput = TypeVar('TInput') -TOutput = TypeVar('TOutput') +TInput = TypeVar("TInput") +TOutput = TypeVar("TOutput") class ConcurrencyOptions: @@ -33,9 +35,12 @@ class ConcurrencyOptions: activities, orchestrations, and entities. """ - def __init__(self, - maximum_concurrent_activity_work_items: Optional[int] = None, - maximum_concurrent_orchestration_work_items: Optional[int] = None): + def __init__( + self, + maximum_concurrent_activity_work_items: Optional[int] = None, + maximum_concurrent_orchestration_work_items: Optional[int] = None, + maximum_thread_pool_workers: Optional[int] = None, + ): """Initialize concurrency options. Args: @@ -43,9 +48,12 @@ def __init__(self, that can be processed concurrently. Defaults to 100 * processor_count. maximum_concurrent_orchestration_work_items: Maximum number of orchestration work items that can be processed concurrently. Defaults to 100 * processor_count. + maximum_thread_pool_workers: Maximum number of thread pool workers to use. 
""" processor_count = os.cpu_count() or 1 default_concurrency = 100 * processor_count + # see https://docs.python.org/3/library/concurrent.futures.html + default_max_workers = processor_count + 4 self.maximum_concurrent_activity_work_items = ( maximum_concurrent_activity_work_items @@ -59,21 +67,14 @@ def __init__(self, else default_concurrency ) - @property - def max_total_workers(self) -> int: - """Calculate the maximum total workers needed for the thread pool. - - Since Python's ThreadPoolExecutor doesn't differentiate between work item types, - we use the maximum of all concurrency limits to ensure we have enough workers. - """ - return max( - self.maximum_concurrent_activity_work_items, - self.maximum_concurrent_orchestration_work_items, + self.maximum_thread_pool_workers = ( + maximum_thread_pool_workers + if maximum_thread_pool_workers is not None + else default_max_workers ) class _Registry: - orchestrators: dict[str, task.Orchestrator] activities: dict[str, task.Activity] @@ -83,7 +84,7 @@ def __init__(self): def add_orchestrator(self, fn: task.Orchestrator) -> str: if fn is None: - raise ValueError('An orchestrator function argument is required.') + raise ValueError("An orchestrator function argument is required.") name = task.get_name(fn) self.add_named_orchestrator(name, fn) @@ -91,7 +92,7 @@ def add_orchestrator(self, fn: task.Orchestrator) -> str: def add_named_orchestrator(self, name: str, fn: task.Orchestrator) -> None: if not name: - raise ValueError('A non-empty orchestrator name is required.') + raise ValueError("A non-empty orchestrator name is required.") if name in self.orchestrators: raise ValueError(f"A '{name}' orchestrator already exists.") @@ -102,7 +103,7 @@ def get_orchestrator(self, name: str) -> Optional[task.Orchestrator]: def add_activity(self, fn: task.Activity) -> str: if fn is None: - raise ValueError('An activity function argument is required.') + raise ValueError("An activity function argument is required.") name = 
task.get_name(fn) self.add_named_activity(name, fn) @@ -110,7 +111,7 @@ def add_activity(self, fn: task.Activity) -> str: def add_named_activity(self, name: str, fn: task.Activity) -> None: if not name: - raise ValueError('A non-empty activity name is required.') + raise ValueError("A non-empty activity name is required.") if name in self.activities: raise ValueError(f"A '{name}' activity already exists.") @@ -122,11 +123,13 @@ def get_activity(self, name: str) -> Optional[task.Activity]: class OrchestratorNotRegisteredError(ValueError): """Raised when attempting to start an orchestration that is not registered""" + pass class ActivityNotRegisteredError(ValueError): """Raised when attempting to call an activity that is not registered""" + pass @@ -134,23 +137,32 @@ class TaskHubGrpcWorker: _response_stream: Optional[grpc.Future] = None _interceptors: Optional[list[shared.ClientInterceptor]] = None - def __init__(self, *, - host_address: Optional[str] = None, - metadata: Optional[list[tuple[str, str]]] = None, - log_handler=None, - log_formatter: Optional[logging.Formatter] = None, - secure_channel: bool = False, - interceptors: Optional[Sequence[shared.ClientInterceptor]] = None, - concurrency_options: Optional[ConcurrencyOptions] = None): + def __init__( + self, + *, + host_address: Optional[str] = None, + metadata: Optional[list[tuple[str, str]]] = None, + log_handler=None, + log_formatter: Optional[logging.Formatter] = None, + secure_channel: bool = False, + interceptors: Optional[Sequence[shared.ClientInterceptor]] = None, + concurrency_options: Optional[ConcurrencyOptions] = None, + ): self._registry = _Registry() - self._host_address = host_address if host_address else shared.get_default_host_address() + self._host_address = ( + host_address if host_address else shared.get_default_host_address() + ) self._logger = shared.get_logger("worker", log_handler, log_formatter) self._shutdown = Event() self._is_running = False self._secure_channel = secure_channel # 
Use provided concurrency options or create default ones - self._concurrency_options = concurrency_options if concurrency_options is not None else ConcurrencyOptions() + self._concurrency_options = ( + concurrency_options + if concurrency_options is not None + else ConcurrencyOptions() + ) # Determine the interceptors to use if interceptors is not None: @@ -162,6 +174,8 @@ def __init__(self, *, else: self._interceptors = None + self._async_worker_manager = _AsyncWorkerManager(self._concurrency_options) + @property def concurrency_options(self) -> ConcurrencyOptions: """Get the current concurrency options for this worker.""" @@ -176,181 +190,195 @@ def __exit__(self, type, value, traceback): def add_orchestrator(self, fn: task.Orchestrator) -> str: """Registers an orchestrator function with the worker.""" if self._is_running: - raise RuntimeError('Orchestrators cannot be added while the worker is running.') + raise RuntimeError( + "Orchestrators cannot be added while the worker is running." + ) return self._registry.add_orchestrator(fn) def add_activity(self, fn: task.Activity) -> str: """Registers an activity function with the worker.""" if self._is_running: - raise RuntimeError('Activities cannot be added while the worker is running.') + raise RuntimeError( + "Activities cannot be added while the worker is running." 
+ ) return self._registry.add_activity(fn) def start(self): """Starts the worker on a background thread and begins listening for work items.""" if self._is_running: - raise RuntimeError('The worker is already running.') + raise RuntimeError("The worker is already running.") def run_loop(): - """Enhanced run loop with better connection management and retry logic.""" - - # Connection state management for retry fix - current_channel: Optional[grpc.Channel] = None - current_stub: Optional[stubs.TaskHubSidecarServiceStub] = None - conn_retry_count = 0 - conn_max_retry_delay = 60 + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(self._async_run_loop()) - def create_fresh_connection() -> None: - """Create a new gRPC channel and stub, invalidating any existing ones. - - Raises: - Exception: If connection creation or testing fails. - """ - nonlocal current_channel, current_stub, conn_retry_count - - # Close existing connection if any - if current_channel: - try: - current_channel.close() - except Exception: - pass + self._logger.info(f"Starting gRPC worker that connects to {self._host_address}") + self._runLoop = Thread(target=run_loop) + self._runLoop.start() + self._is_running = True + async def _async_run_loop(self): + self._async_worker_manager = _AsyncWorkerManager(self._concurrency_options) + worker_task = asyncio.create_task(self._async_worker_manager.run()) + # Connection state management for retry fix + current_channel = None + current_stub = None + conn_retry_count = 0 + conn_max_retry_delay = 60 + + def create_fresh_connection(): + nonlocal current_channel, current_stub, conn_retry_count + if current_channel: + try: + current_channel.close() + except Exception: + pass + current_channel = None + current_stub = None + try: + current_channel = shared.get_grpc_channel( + self._host_address, self._secure_channel, self._interceptors + ) + current_stub = stubs.TaskHubSidecarServiceStub(current_channel) + 
current_stub.Hello(empty_pb2.Empty()) + conn_retry_count = 0 + self._logger.debug(f"Created fresh connection to {self._host_address}") + except Exception as e: + self._logger.debug(f"Failed to create connection: {e}") current_channel = None current_stub = None + raise + def invalidate_connection(): + nonlocal current_channel, current_stub + if current_channel: + try: + current_channel.close() + except Exception: + pass + current_channel = None + current_stub = None + + def should_invalidate_connection(rpc_error): + error_code = rpc_error.code() # type: ignore + connection_level_errors = { + grpc.StatusCode.UNAVAILABLE, + grpc.StatusCode.DEADLINE_EXCEEDED, + grpc.StatusCode.CANCELLED, + grpc.StatusCode.UNAUTHENTICATED, + grpc.StatusCode.ABORTED, + } + return error_code in connection_level_errors + + while not self._shutdown.is_set(): + if current_stub is None: try: - # Create new connection - current_channel = shared.get_grpc_channel(self._host_address, self._secure_channel, self._interceptors) - current_stub = stubs.TaskHubSidecarServiceStub(current_channel) - - # Test the connection - current_stub.Hello(empty_pb2.Empty()) - conn_retry_count = 0 # Reset on successful connection - self._logger.debug(f"Created fresh connection to {self._host_address}") - - except Exception as e: - self._logger.debug(f"Failed to create connection: {e}") - current_channel = None - current_stub = None - raise # Re-raise the original exception - - def invalidate_connection() -> None: - """Mark current connection as invalid.""" - nonlocal current_channel, current_stub - if current_channel: + create_fresh_connection() + except Exception: + conn_retry_count += 1 + delay = min( + conn_max_retry_delay, + (2 ** min(conn_retry_count, 6)) + random.uniform(0, 1), + ) + self._logger.warning( + f"Connection failed, retrying in {delay:.2f} seconds (attempt {conn_retry_count})" + ) + if self._shutdown.wait(delay): + break + continue + try: + assert current_stub is not None + stub = current_stub + 
get_work_items_request = pb.GetWorkItemsRequest( + maxConcurrentOrchestrationWorkItems=self._concurrency_options.maximum_concurrent_orchestration_work_items, + maxConcurrentActivityWorkItems=self._concurrency_options.maximum_concurrent_activity_work_items, + ) + self._response_stream = stub.GetWorkItems(get_work_items_request) + self._logger.info( + f"Successfully connected to {self._host_address}. Waiting for work items..." + ) + + # Use a thread to read from the blocking gRPC stream and forward to asyncio + import queue + + work_item_queue = queue.Queue() + + def stream_reader(): try: - current_channel.close() - except Exception: - pass - current_channel = None - current_stub = None - - def should_invalidate_connection(rpc_error: grpc.RpcError) -> bool: - """Determine if a gRPC error should trigger connection invalidation. - - Connection-level errors (network, authentication, server unavailable) - should invalidate the connection, while application-level errors - (bad requests, not found, etc.) should not. - """ - error_code = rpc_error.code() # type: ignore + for work_item in self._response_stream: + work_item_queue.put(work_item) + except Exception as e: + work_item_queue.put(e) - # Connection-level errors that warrant invalidation - connection_level_errors = { - grpc.StatusCode.UNAVAILABLE, # Server down/unreachable - grpc.StatusCode.DEADLINE_EXCEEDED, # Timeout, likely network issue - grpc.StatusCode.CANCELLED, # Connection cancelled - grpc.StatusCode.UNAUTHENTICATED, # Auth failed, may need new connection - grpc.StatusCode.ABORTED, # Transaction aborted, connection may be bad - } + import threading - return error_code in connection_level_errors - - # TODO: Investigate whether asyncio could be used to enable greater concurrency for async activity - # functions. We'd need to know ahead of time whether a function is async or not. 
- with concurrent.futures.ThreadPoolExecutor(max_workers=self._concurrency_options.max_total_workers, thread_name_prefix="DurableTask") as executor: + reader_thread = threading.Thread(target=stream_reader, daemon=True) + reader_thread.start() + loop = asyncio.get_running_loop() while not self._shutdown.is_set(): - # Ensure we have a valid connection before attempting work - if current_stub is None: - try: - create_fresh_connection() - except Exception: - # Connection failed, implement exponential backoff - conn_retry_count += 1 - delay = min(conn_max_retry_delay, (2 ** min(conn_retry_count, 6)) + random.uniform(0, 1)) - self._logger.warning(f'Connection failed, retrying in {delay:.2f} seconds (attempt {conn_retry_count})') - if self._shutdown.wait(delay): - break # Shutdown requested during wait - continue - try: - # Stream work items with the current connection - # Type assertion since we know current_stub is not None at this point - assert current_stub is not None, "current_stub should not be None at this point" - stub = current_stub # Local reference for type safety - - # Create GetWorkItemsRequest with concurrency limits - get_work_items_request = pb.GetWorkItemsRequest( - maxConcurrentOrchestrationWorkItems=self._concurrency_options.maximum_concurrent_orchestration_work_items, - maxConcurrentActivityWorkItems=self._concurrency_options.maximum_concurrent_activity_work_items + work_item = await loop.run_in_executor( + None, work_item_queue.get ) - self._response_stream = stub.GetWorkItems(get_work_items_request) - self._logger.info(f'Successfully connected to {self._host_address}. 
Waiting for work items...') - - # Process work items concurrently as they arrive - for work_item in self._response_stream: # type: ignore - if self._shutdown.is_set(): - break - - request_type = work_item.WhichOneof('request') - self._logger.debug(f'Received "{request_type}" work item') - - # Submit work items to thread pool for concurrent processing - if work_item.HasField('orchestratorRequest'): - executor.submit(self._execute_orchestrator, work_item.orchestratorRequest, stub, work_item.completionToken) - elif work_item.HasField('activityRequest'): - executor.submit(self._execute_activity, work_item.activityRequest, stub, work_item.completionToken) - elif work_item.HasField('healthPing'): - pass # no-op - else: - self._logger.warning(f'Unexpected work item type: {request_type}') - - # Stream ended normally (shouldn't happen unless server closes) - self._logger.info("Work item stream ended normally") - - except grpc.RpcError as rpc_error: - # Intelligently decide whether to invalidate connection based on error type - should_invalidate = should_invalidate_connection(rpc_error) - if should_invalidate: - invalidate_connection() - - error_code = rpc_error.code() # type: ignore - if error_code == grpc.StatusCode.CANCELLED: - self._logger.info(f'Disconnected from {self._host_address}') - break # Likely shutdown - elif error_code == grpc.StatusCode.UNAVAILABLE: - self._logger.warning(f'The sidecar at address {self._host_address} is unavailable - will continue retrying') - elif should_invalidate: - self._logger.warning(f'Connection-level gRPC error ({error_code}): {rpc_error} - invalidating connection') + if isinstance(work_item, Exception): + raise work_item + request_type = work_item.WhichOneof("request") + self._logger.debug(f'Received "{request_type}" work item') + if work_item.HasField("orchestratorRequest"): + self._async_worker_manager.submit_orchestration( + self._execute_orchestrator, + work_item.orchestratorRequest, + stub, + work_item.completionToken, + ) + 
elif work_item.HasField("activityRequest"): + self._async_worker_manager.submit_activity( + self._execute_activity, + work_item.activityRequest, + stub, + work_item.completionToken, + ) + elif work_item.HasField("healthPing"): + pass else: - self._logger.warning(f'Application-level gRPC error ({error_code}): {rpc_error} - keeping connection') - - # Brief pause before retry - self._shutdown.wait(1) - - except Exception as ex: - # Unexpected error, invalidate connection and retry - invalidate_connection() - self._logger.warning(f'Unexpected error: {ex}') - self._shutdown.wait(1) - - # Final cleanup + self._logger.warning( + f"Unexpected work item type: {request_type}" + ) + except Exception as e: + self._logger.warning(f"Error in work item stream: {e}") + break + reader_thread.join(timeout=1) + self._logger.info("Work item stream ended normally") + except grpc.RpcError as rpc_error: + should_invalidate = should_invalidate_connection(rpc_error) + if should_invalidate: + invalidate_connection() + error_code = rpc_error.code() # type: ignore + if error_code == grpc.StatusCode.CANCELLED: + self._logger.info(f"Disconnected from {self._host_address}") + break + elif error_code == grpc.StatusCode.UNAVAILABLE: + self._logger.warning( + f"The sidecar at address {self._host_address} is unavailable - will continue retrying" + ) + elif should_invalidate: + self._logger.warning( + f"Connection-level gRPC error ({error_code}): {rpc_error} - resetting connection" + ) + else: + self._logger.warning( + f"Application-level gRPC error ({error_code}): {rpc_error}" + ) + self._shutdown.wait(1) + except Exception as ex: invalidate_connection() - - self._logger.info("No longer listening for work items") - - self._logger.info(f"Starting gRPC worker that connects to {self._host_address}") - self._runLoop = Thread(target=run_loop) - self._runLoop.start() - self._is_running = True + self._logger.warning(f"Unexpected error: {ex}") + self._shutdown.wait(1) + invalidate_connection() + 
self._logger.info("No longer listening for work items") + self._async_worker_manager.shutdown() + await worker_task def stop(self): """Stops the worker and waits for any pending work items to complete.""" @@ -363,10 +391,16 @@ def stop(self): self._response_stream.cancel() if self._runLoop is not None: self._runLoop.join(timeout=30) + self._async_worker_manager.shutdown() self._logger.info("Worker shutdown completed") self._is_running = False - def _execute_orchestrator(self, req: pb.OrchestratorRequest, stub: stubs.TaskHubSidecarServiceStub, completionToken): + def _execute_orchestrator( + self, + req: pb.OrchestratorRequest, + stub: stubs.TaskHubSidecarServiceStub, + completionToken, + ): try: executor = _OrchestrationExecutor(self._registry, self._logger) result = executor.execute(req.instanceId, req.pastEvents, req.newEvents) @@ -374,40 +408,63 @@ def _execute_orchestrator(self, req: pb.OrchestratorRequest, stub: stubs.TaskHub instanceId=req.instanceId, actions=result.actions, customStatus=pbh.get_string_value(result.encoded_custom_status), - completionToken=completionToken) + completionToken=completionToken, + ) except Exception as ex: - self._logger.exception(f"An error occurred while trying to execute instance '{req.instanceId}': {ex}") + self._logger.exception( + f"An error occurred while trying to execute instance '{req.instanceId}': {ex}" + ) failure_details = pbh.new_failure_details(ex) - actions = [pbh.new_complete_orchestration_action(-1, pb.ORCHESTRATION_STATUS_FAILED, "", failure_details)] - res = pb.OrchestratorResponse(instanceId=req.instanceId, actions=actions, completionToken=completionToken) + actions = [ + pbh.new_complete_orchestration_action( + -1, pb.ORCHESTRATION_STATUS_FAILED, "", failure_details + ) + ] + res = pb.OrchestratorResponse( + instanceId=req.instanceId, + actions=actions, + completionToken=completionToken, + ) try: stub.CompleteOrchestratorTask(res) except Exception as ex: - self._logger.exception(f"Failed to deliver 
orchestrator response for '{req.instanceId}' to sidecar: {ex}") - - def _execute_activity(self, req: pb.ActivityRequest, stub: stubs.TaskHubSidecarServiceStub, completionToken): + self._logger.exception( + f"Failed to deliver orchestrator response for '{req.instanceId}' to sidecar: {ex}" + ) + + def _execute_activity( + self, + req: pb.ActivityRequest, + stub: stubs.TaskHubSidecarServiceStub, + completionToken, + ): instance_id = req.orchestrationInstance.instanceId try: executor = _ActivityExecutor(self._registry, self._logger) - result = executor.execute(instance_id, req.name, req.taskId, req.input.value) + result = executor.execute( + instance_id, req.name, req.taskId, req.input.value + ) res = pb.ActivityResponse( instanceId=instance_id, taskId=req.taskId, result=pbh.get_string_value(result), - completionToken=completionToken) + completionToken=completionToken, + ) except Exception as ex: res = pb.ActivityResponse( instanceId=instance_id, taskId=req.taskId, failureDetails=pbh.new_failure_details(ex), - completionToken=completionToken) + completionToken=completionToken, + ) try: stub.CompleteActivityTask(res) except Exception as ex: self._logger.exception( - f"Failed to deliver activity response for '{req.name}#{req.taskId}' of orchestration ID '{instance_id}' to sidecar: {ex}") + f"Failed to deliver activity response for '{req.name}#{req.taskId}' of orchestration ID '{instance_id}' to sidecar: {ex}" + ) class _RuntimeOrchestrationContext(task.OrchestrationContext): @@ -441,7 +498,9 @@ def run(self, generator: Generator[task.Task, Any, Any]): def resume(self): if self._generator is None: # This is never expected unless maybe there's an issue with the history - raise TypeError("The orchestrator generator is not initialized! Was the orchestration history corrupted?") + raise TypeError( + "The orchestrator generator is not initialized! Was the orchestration history corrupted?" 
+ ) # We can resume the generator only if the previously yielded task # has reached a completed state. The only time this won't be the @@ -462,7 +521,12 @@ def resume(self): raise TypeError("The orchestrator generator yielded a non-Task object") self._previous_task = next_task - def set_complete(self, result: Any, status: pb.OrchestrationStatus, is_result_encoded: bool = False): + def set_complete( + self, + result: Any, + status: pb.OrchestrationStatus, + is_result_encoded: bool = False, + ): if self._is_complete: return @@ -475,7 +539,8 @@ def set_complete(self, result: Any, status: pb.OrchestrationStatus, is_result_en if result is not None: result_json = result if is_result_encoded else shared.to_json(result) action = ph.new_complete_orchestration_action( - self.next_sequence_number(), status, result_json) + self.next_sequence_number(), status, result_json + ) self._pending_actions[action.id] = action def set_failed(self, ex: Exception): @@ -487,7 +552,10 @@ def set_failed(self, ex: Exception): self._completion_status = pb.ORCHESTRATION_STATUS_FAILED action = ph.new_complete_orchestration_action( - self.next_sequence_number(), pb.ORCHESTRATION_STATUS_FAILED, None, ph.new_failure_details(ex) + self.next_sequence_number(), + pb.ORCHESTRATION_STATUS_FAILED, + None, + ph.new_failure_details(ex), ) self._pending_actions[action.id] = action @@ -511,14 +579,21 @@ def get_actions(self) -> list[pb.OrchestratorAction]: # replayed when the new instance starts. 
for event_name, values in self._received_events.items(): for event_value in values: - encoded_value = shared.to_json(event_value) if event_value else None - carryover_events.append(ph.new_event_raised_event(event_name, encoded_value)) + encoded_value = ( + shared.to_json(event_value) if event_value else None + ) + carryover_events.append( + ph.new_event_raised_event(event_name, encoded_value) + ) action = ph.new_complete_orchestration_action( self.next_sequence_number(), pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW, - result=shared.to_json(self._new_input) if self._new_input is not None else None, + result=shared.to_json(self._new_input) + if self._new_input is not None + else None, failure_details=None, - carryover_events=carryover_events) + carryover_events=carryover_events, + ) return [action] else: return list(self._pending_actions.values()) @@ -544,13 +619,18 @@ def is_replaying(self) -> bool: return self._is_replaying def set_custom_status(self, custom_status: Any) -> None: - self._encoded_custom_status = shared.to_json(custom_status) if custom_status is not None else None + self._encoded_custom_status = ( + shared.to_json(custom_status) if custom_status is not None else None + ) def create_timer(self, fire_at: Union[datetime, timedelta]) -> task.Task: return self.create_timer_internal(fire_at) - def create_timer_internal(self, fire_at: Union[datetime, timedelta], - retryable_task: Optional[task.RetryableTask] = None) -> task.Task: + def create_timer_internal( + self, + fire_at: Union[datetime, timedelta], + retryable_task: Optional[task.RetryableTask] = None, + ) -> task.Task: id = self.next_sequence_number() if isinstance(fire_at, timedelta): fire_at = self.current_utc_datetime + fire_at @@ -563,32 +643,51 @@ def create_timer_internal(self, fire_at: Union[datetime, timedelta], self._pending_tasks[id] = timer_task return timer_task - def call_activity(self, activity: Union[task.Activity[TInput, TOutput], str], *, - input: Optional[TInput] = None, - retry_policy: 
Optional[task.RetryPolicy] = None) -> task.Task[TOutput]: + def call_activity( + self, + activity: Union[task.Activity[TInput, TOutput], str], + *, + input: Optional[TInput] = None, + retry_policy: Optional[task.RetryPolicy] = None, + ) -> task.Task[TOutput]: id = self.next_sequence_number() - self.call_activity_function_helper(id, activity, input=input, retry_policy=retry_policy, - is_sub_orch=False) + self.call_activity_function_helper( + id, activity, input=input, retry_policy=retry_policy, is_sub_orch=False + ) return self._pending_tasks.get(id, task.CompletableTask()) - def call_sub_orchestrator(self, orchestrator: task.Orchestrator[TInput, TOutput], *, - input: Optional[TInput] = None, - instance_id: Optional[str] = None, - retry_policy: Optional[task.RetryPolicy] = None) -> task.Task[TOutput]: + def call_sub_orchestrator( + self, + orchestrator: task.Orchestrator[TInput, TOutput], + *, + input: Optional[TInput] = None, + instance_id: Optional[str] = None, + retry_policy: Optional[task.RetryPolicy] = None, + ) -> task.Task[TOutput]: id = self.next_sequence_number() orchestrator_name = task.get_name(orchestrator) - self.call_activity_function_helper(id, orchestrator_name, input=input, retry_policy=retry_policy, - is_sub_orch=True, instance_id=instance_id) + self.call_activity_function_helper( + id, + orchestrator_name, + input=input, + retry_policy=retry_policy, + is_sub_orch=True, + instance_id=instance_id, + ) return self._pending_tasks.get(id, task.CompletableTask()) - def call_activity_function_helper(self, id: Optional[int], - activity_function: Union[task.Activity[TInput, TOutput], str], *, - input: Optional[TInput] = None, - retry_policy: Optional[task.RetryPolicy] = None, - is_sub_orch: bool = False, - instance_id: Optional[str] = None, - fn_task: Optional[task.CompletableTask[TOutput]] = None): + def call_activity_function_helper( + self, + id: Optional[int], + activity_function: Union[task.Activity[TInput, TOutput], str], + *, + input: 
Optional[TInput] = None, + retry_policy: Optional[task.RetryPolicy] = None, + is_sub_orch: bool = False, + instance_id: Optional[str] = None, + fn_task: Optional[task.CompletableTask[TOutput]] = None, + ): if id is None: id = self.next_sequence_number() @@ -599,7 +698,11 @@ def call_activity_function_helper(self, id: Optional[int], # We just need to take string representation of it. encoded_input = str(input) if not is_sub_orch: - name = activity_function if isinstance(activity_function, str) else task.get_name(activity_function) + name = ( + activity_function + if isinstance(activity_function, str) + else task.get_name(activity_function) + ) action = ph.new_schedule_task_action(id, name, encoded_input) else: if instance_id is None: @@ -607,16 +710,21 @@ def call_activity_function_helper(self, id: Optional[int], instance_id = f"{self.instance_id}:{id:04x}" if not isinstance(activity_function, str): raise ValueError("Orchestrator function name must be a string") - action = ph.new_create_sub_orchestration_action(id, activity_function, instance_id, encoded_input) + action = ph.new_create_sub_orchestration_action( + id, activity_function, instance_id, encoded_input + ) self._pending_actions[id] = action if fn_task is None: if retry_policy is None: fn_task = task.CompletableTask[TOutput]() else: - fn_task = task.RetryableTask[TOutput](retry_policy=retry_policy, action=action, - start_time=self.current_utc_datetime, - is_sub_orch=is_sub_orch) + fn_task = task.RetryableTask[TOutput]( + retry_policy=retry_policy, + action=action, + start_time=self.current_utc_datetime, + is_sub_orch=is_sub_orch, + ) self._pending_tasks[id] = fn_task def wait_for_external_event(self, name: str) -> task.Task: @@ -652,7 +760,9 @@ class ExecutionResults: actions: list[pb.OrchestratorAction] encoded_custom_status: Optional[str] - def __init__(self, actions: list[pb.OrchestratorAction], encoded_custom_status: Optional[str]): + def __init__( + self, actions: list[pb.OrchestratorAction], 
encoded_custom_status: Optional[str] + ): self.actions = actions self.encoded_custom_status = encoded_custom_status @@ -666,14 +776,23 @@ def __init__(self, registry: _Registry, logger: logging.Logger): self._is_suspended = False self._suspended_events: list[pb.HistoryEvent] = [] - def execute(self, instance_id: str, old_events: Sequence[pb.HistoryEvent], new_events: Sequence[pb.HistoryEvent]) -> ExecutionResults: + def execute( + self, + instance_id: str, + old_events: Sequence[pb.HistoryEvent], + new_events: Sequence[pb.HistoryEvent], + ) -> ExecutionResults: if not new_events: - raise task.OrchestrationStateError("The new history event list must have at least one event in it.") + raise task.OrchestrationStateError( + "The new history event list must have at least one event in it." + ) ctx = _RuntimeOrchestrationContext(instance_id) try: # Rebuild local state by replaying old history into the orchestrator function - self._logger.debug(f"{instance_id}: Rebuilding local state with {len(old_events)} history event...") + self._logger.debug( + f"{instance_id}: Rebuilding local state with {len(old_events)} history event..." 
+ ) ctx._is_replaying = True for old_event in old_events: self.process_event(ctx, old_event) @@ -681,7 +800,9 @@ def execute(self, instance_id: str, old_events: Sequence[pb.HistoryEvent], new_e # Get new actions by executing newly received events into the orchestrator function if self._logger.level <= logging.DEBUG: summary = _get_new_event_summary(new_events) - self._logger.debug(f"{instance_id}: Processing {len(new_events)} new event(s): {summary}") + self._logger.debug( + f"{instance_id}: Processing {len(new_events)} new event(s): {summary}" + ) ctx._is_replaying = False for new_event in new_events: self.process_event(ctx, new_event) @@ -693,17 +814,31 @@ def execute(self, instance_id: str, old_events: Sequence[pb.HistoryEvent], new_e if not ctx._is_complete: task_count = len(ctx._pending_tasks) event_count = len(ctx._pending_events) - self._logger.info(f"{instance_id}: Orchestrator yielded with {task_count} task(s) and {event_count} event(s) outstanding.") - elif ctx._completion_status and ctx._completion_status is not pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW: - completion_status_str = pbh.get_orchestration_status_str(ctx._completion_status) - self._logger.info(f"{instance_id}: Orchestration completed with status: {completion_status_str}") + self._logger.info( + f"{instance_id}: Orchestrator yielded with {task_count} task(s) and {event_count} event(s) outstanding." 
+ ) + elif ( + ctx._completion_status and ctx._completion_status is not pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW + ): + completion_status_str = pbh.get_orchestration_status_str( + ctx._completion_status + ) + self._logger.info( + f"{instance_id}: Orchestration completed with status: {completion_status_str}" + ) actions = ctx.get_actions() if self._logger.level <= logging.DEBUG: - self._logger.debug(f"{instance_id}: Returning {len(actions)} action(s): {_get_action_summary(actions)}") - return ExecutionResults(actions=actions, encoded_custom_status=ctx._encoded_custom_status) + self._logger.debug( + f"{instance_id}: Returning {len(actions)} action(s): {_get_action_summary(actions)}" + ) + return ExecutionResults( + actions=actions, encoded_custom_status=ctx._encoded_custom_status + ) - def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEvent) -> None: + def process_event( + self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEvent + ) -> None: if self._is_suspended and _is_suspendable(event): # We are suspended, so we need to buffer this event until we are resumed self._suspended_events.append(event) @@ -718,14 +853,19 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven fn = self._registry.get_orchestrator(event.executionStarted.name) if fn is None: raise OrchestratorNotRegisteredError( - f"A '{event.executionStarted.name}' orchestrator was not registered.") + f"A '{event.executionStarted.name}' orchestrator was not registered." 
+ ) # deserialize the input, if any input = None - if event.executionStarted.input is not None and event.executionStarted.input.value != "": + if ( + event.executionStarted.input is not None and event.executionStarted.input.value != "" + ): input = shared.from_json(event.executionStarted.input.value) - result = fn(ctx, input) # this does not execute the generator, only creates it + result = fn( + ctx, input + ) # this does not execute the generator, only creates it if isinstance(result, GeneratorType): # Start the orchestrator's generator function ctx.run(result) @@ -738,10 +878,14 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven timer_id = event.eventId action = ctx._pending_actions.pop(timer_id, None) if not action: - raise _get_non_determinism_error(timer_id, task.get_name(ctx.create_timer)) + raise _get_non_determinism_error( + timer_id, task.get_name(ctx.create_timer) + ) elif not action.HasField("createTimer"): expected_method_name = task.get_name(ctx.create_timer) - raise _get_wrong_action_type_error(timer_id, expected_method_name, action) + raise _get_wrong_action_type_error( + timer_id, expected_method_name, action + ) elif event.HasField("timerFired"): timer_id = event.timerFired.timerId timer_task = ctx._pending_tasks.pop(timer_id, None) @@ -749,7 +893,8 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven # TODO: Should this be an error? When would it ever happen? if not ctx._is_replaying: self._logger.warning( - f"{ctx.instance_id}: Ignoring unexpected timerFired event with ID = {timer_id}.") + f"{ctx.instance_id}: Ignoring unexpected timerFired event with ID = {timer_id}." 
+ ) return timer_task.complete(None) if timer_task._retryable_parent is not None: @@ -761,12 +906,15 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven else: cur_task = activity_action.createSubOrchestration instance_id = cur_task.instanceId - ctx.call_activity_function_helper(id=activity_action.id, activity_function=cur_task.name, - input=cur_task.input.value, - retry_policy=timer_task._retryable_parent._retry_policy, - is_sub_orch=timer_task._retryable_parent._is_sub_orch, - instance_id=instance_id, - fn_task=timer_task._retryable_parent) + ctx.call_activity_function_helper( + id=activity_action.id, + activity_function=cur_task.name, + input=cur_task.input.value, + retry_policy=timer_task._retryable_parent._retry_policy, + is_sub_orch=timer_task._retryable_parent._is_sub_orch, + instance_id=instance_id, + fn_task=timer_task._retryable_parent, + ) else: ctx.resume() elif event.HasField("taskScheduled"): @@ -776,16 +924,21 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven action = ctx._pending_actions.pop(task_id, None) activity_task = ctx._pending_tasks.get(task_id, None) if not action: - raise _get_non_determinism_error(task_id, task.get_name(ctx.call_activity)) + raise _get_non_determinism_error( + task_id, task.get_name(ctx.call_activity) + ) elif not action.HasField("scheduleTask"): expected_method_name = task.get_name(ctx.call_activity) - raise _get_wrong_action_type_error(task_id, expected_method_name, action) + raise _get_wrong_action_type_error( + task_id, expected_method_name, action + ) elif action.scheduleTask.name != event.taskScheduled.name: raise _get_wrong_action_name_error( task_id, method_name=task.get_name(ctx.call_activity), expected_task_name=event.taskScheduled.name, - actual_task_name=action.scheduleTask.name) + actual_task_name=action.scheduleTask.name, + ) elif event.HasField("taskCompleted"): # This history event contains the result of a completed activity task. 
task_id = event.taskCompleted.taskScheduledId @@ -794,7 +947,8 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven # TODO: Should this be an error? When would it ever happen? if not ctx.is_replaying: self._logger.warning( - f"{ctx.instance_id}: Ignoring unexpected taskCompleted event with ID = {task_id}.") + f"{ctx.instance_id}: Ignoring unexpected taskCompleted event with ID = {task_id}." + ) return result = None if not ph.is_empty(event.taskCompleted.result): @@ -808,7 +962,8 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven # TODO: Should this be an error? When would it ever happen? if not ctx.is_replaying: self._logger.warning( - f"{ctx.instance_id}: Ignoring unexpected taskFailed event with ID = {task_id}.") + f"{ctx.instance_id}: Ignoring unexpected taskFailed event with ID = {task_id}." + ) return if isinstance(activity_task, task.RetryableTask): @@ -817,7 +972,8 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven if next_delay is None: activity_task.fail( f"{ctx.instance_id}: Activity task #{task_id} failed: {event.taskFailed.failureDetails.errorMessage}", - event.taskFailed.failureDetails) + event.taskFailed.failureDetails, + ) ctx.resume() else: activity_task.increment_attempt_count() @@ -825,7 +981,8 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven elif isinstance(activity_task, task.CompletableTask): activity_task.fail( f"{ctx.instance_id}: Activity task #{task_id} failed: {event.taskFailed.failureDetails.errorMessage}", - event.taskFailed.failureDetails) + event.taskFailed.failureDetails, + ) ctx.resume() else: raise TypeError("Unexpected task type") @@ -835,16 +992,23 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven task_id = event.eventId action = ctx._pending_actions.pop(task_id, None) if not action: - raise _get_non_determinism_error(task_id, task.get_name(ctx.call_sub_orchestrator)) 
+ raise _get_non_determinism_error( + task_id, task.get_name(ctx.call_sub_orchestrator) + ) elif not action.HasField("createSubOrchestration"): expected_method_name = task.get_name(ctx.call_sub_orchestrator) - raise _get_wrong_action_type_error(task_id, expected_method_name, action) - elif action.createSubOrchestration.name != event.subOrchestrationInstanceCreated.name: + raise _get_wrong_action_type_error( + task_id, expected_method_name, action + ) + elif ( + action.createSubOrchestration.name != event.subOrchestrationInstanceCreated.name + ): raise _get_wrong_action_name_error( task_id, method_name=task.get_name(ctx.call_sub_orchestrator), expected_task_name=event.subOrchestrationInstanceCreated.name, - actual_task_name=action.createSubOrchestration.name) + actual_task_name=action.createSubOrchestration.name, + ) elif event.HasField("subOrchestrationInstanceCompleted"): task_id = event.subOrchestrationInstanceCompleted.taskScheduledId sub_orch_task = ctx._pending_tasks.pop(task_id, None) @@ -852,11 +1016,14 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven # TODO: Should this be an error? When would it ever happen? if not ctx.is_replaying: self._logger.warning( - f"{ctx.instance_id}: Ignoring unexpected subOrchestrationInstanceCompleted event with ID = {task_id}.") + f"{ctx.instance_id}: Ignoring unexpected subOrchestrationInstanceCompleted event with ID = {task_id}." + ) return result = None if not ph.is_empty(event.subOrchestrationInstanceCompleted.result): - result = shared.from_json(event.subOrchestrationInstanceCompleted.result.value) + result = shared.from_json( + event.subOrchestrationInstanceCompleted.result.value + ) sub_orch_task.complete(result) ctx.resume() elif event.HasField("subOrchestrationInstanceFailed"): @@ -867,7 +1034,8 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven # TODO: Should this be an error? When would it ever happen? 
if not ctx.is_replaying: self._logger.warning( - f"{ctx.instance_id}: Ignoring unexpected subOrchestrationInstanceFailed event with ID = {task_id}.") + f"{ctx.instance_id}: Ignoring unexpected subOrchestrationInstanceFailed event with ID = {task_id}." + ) return if isinstance(sub_orch_task, task.RetryableTask): if sub_orch_task._retry_policy is not None: @@ -875,7 +1043,8 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven if next_delay is None: sub_orch_task.fail( f"Sub-orchestration task #{task_id} failed: {failedEvent.failureDetails.errorMessage}", - failedEvent.failureDetails) + failedEvent.failureDetails, + ) ctx.resume() else: sub_orch_task.increment_attempt_count() @@ -883,7 +1052,8 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven elif isinstance(sub_orch_task, task.CompletableTask): sub_orch_task.fail( f"Sub-orchestration task #{task_id} failed: {failedEvent.failureDetails.errorMessage}", - failedEvent.failureDetails) + failedEvent.failureDetails, + ) ctx.resume() else: raise TypeError("Unexpected sub-orchestration task type") @@ -912,7 +1082,9 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven decoded_result = shared.from_json(event.eventRaised.input.value) event_list.append(decoded_result) if not ctx.is_replaying: - self._logger.info(f"{ctx.instance_id}: Event '{event_name}' has been buffered as there are no tasks waiting for it.") + self._logger.info( + f"{ctx.instance_id}: Event '{event_name}' has been buffered as there are no tasks waiting for it." 
+ ) elif event.HasField("executionSuspended"): if not self._is_suspended and not ctx.is_replaying: self._logger.info(f"{ctx.instance_id}: Execution suspended.") @@ -927,11 +1099,21 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven elif event.HasField("executionTerminated"): if not ctx.is_replaying: self._logger.info(f"{ctx.instance_id}: Execution terminating.") - encoded_output = event.executionTerminated.input.value if not ph.is_empty(event.executionTerminated.input) else None - ctx.set_complete(encoded_output, pb.ORCHESTRATION_STATUS_TERMINATED, is_result_encoded=True) + encoded_output = ( + event.executionTerminated.input.value + if not ph.is_empty(event.executionTerminated.input) + else None + ) + ctx.set_complete( + encoded_output, + pb.ORCHESTRATION_STATUS_TERMINATED, + is_result_encoded=True, + ) else: eventType = event.WhichOneof("eventType") - raise task.OrchestrationStateError(f"Don't know how to handle event of type '{eventType}'") + raise task.OrchestrationStateError( + f"Don't know how to handle event of type '{eventType}'" + ) except StopIteration as generatorStopped: # The orchestrator generator function completed ctx.set_complete(generatorStopped.value, pb.ORCHESTRATION_STATUS_COMPLETED) @@ -942,12 +1124,22 @@ def __init__(self, registry: _Registry, logger: logging.Logger): self._registry = registry self._logger = logger - def execute(self, orchestration_id: str, name: str, task_id: int, encoded_input: Optional[str]) -> Optional[str]: + def execute( + self, + orchestration_id: str, + name: str, + task_id: int, + encoded_input: Optional[str], + ) -> Optional[str]: """Executes an activity function and returns the serialized result, if any.""" - self._logger.debug(f"{orchestration_id}/{task_id}: Executing activity '{name}'...") + self._logger.debug( + f"{orchestration_id}/{task_id}: Executing activity '{name}'..." 
+ ) fn = self._registry.get_activity(name) if not fn: - raise ActivityNotRegisteredError(f"Activity function named '{name}' was not registered!") + raise ActivityNotRegisteredError( + f"Activity function named '{name}' was not registered!" + ) activity_input = shared.from_json(encoded_input) if encoded_input else None ctx = task.ActivityContext(orchestration_id, task_id) @@ -955,49 +1147,54 @@ def execute(self, orchestration_id: str, name: str, task_id: int, encoded_input: # Execute the activity function activity_output = fn(ctx, activity_input) - encoded_output = shared.to_json(activity_output) if activity_output is not None else None + encoded_output = ( + shared.to_json(activity_output) if activity_output is not None else None + ) chars = len(encoded_output) if encoded_output else 0 self._logger.debug( - f"{orchestration_id}/{task_id}: Activity '{name}' completed successfully with {chars} char(s) of encoded output.") + f"{orchestration_id}/{task_id}: Activity '{name}' completed successfully with {chars} char(s) of encoded output." + ) return encoded_output -def _get_non_determinism_error(task_id: int, action_name: str) -> task.NonDeterminismError: +def _get_non_determinism_error( + task_id: int, action_name: str +) -> task.NonDeterminismError: return task.NonDeterminismError( f"A previous execution called {action_name} with ID={task_id}, but the current " f"execution doesn't have this action with this ID. This problem occurs when either " f"the orchestration has non-deterministic logic or if the code was changed after an " - f"instance of this orchestration already started running.") + f"instance of this orchestration already started running." 
+ ) def _get_wrong_action_type_error( - task_id: int, - expected_method_name: str, - action: pb.OrchestratorAction) -> task.NonDeterminismError: + task_id: int, expected_method_name: str, action: pb.OrchestratorAction +) -> task.NonDeterminismError: unexpected_method_name = _get_method_name_for_action(action) return task.NonDeterminismError( f"Failed to restore orchestration state due to a history mismatch: A previous execution called " f"{expected_method_name} with ID={task_id}, but the current execution is instead trying to call " f"{unexpected_method_name} as part of rebuilding it's history. This kind of mismatch can happen if an " f"orchestration has non-deterministic logic or if the code was changed after an instance of this " - f"orchestration already started running.") + f"orchestration already started running." + ) def _get_wrong_action_name_error( - task_id: int, - method_name: str, - expected_task_name: str, - actual_task_name: str) -> task.NonDeterminismError: + task_id: int, method_name: str, expected_task_name: str, actual_task_name: str +) -> task.NonDeterminismError: return task.NonDeterminismError( f"Failed to restore orchestration state due to a history mismatch: A previous execution called " f"{method_name} with name='{expected_task_name}' and sequence number {task_id}, but the current " f"execution is instead trying to call {actual_task_name} as part of rebuilding it's history. " f"This kind of mismatch can happen if an orchestration has non-deterministic logic or if the code " - f"was changed after an instance of this orchestration already started running.") + f"was changed after an instance of this orchestration already started running." 
+ ) def _get_method_name_for_action(action: pb.OrchestratorAction) -> str: - action_type = action.WhichOneof('orchestratorActionType') + action_type = action.WhichOneof("orchestratorActionType") if action_type == "scheduleTask": return task.get_name(task.OrchestrationContext.call_activity) elif action_type == "createTimer": @@ -1019,7 +1216,7 @@ def _get_new_event_summary(new_events: Sequence[pb.HistoryEvent]) -> str: else: counts: dict[str, int] = {} for event in new_events: - event_type = event.WhichOneof('eventType') + event_type = event.WhichOneof("eventType") counts[event_type] = counts.get(event_type, 0) + 1 return f"[{', '.join(f'{name}={count}' for name, count in counts.items())}]" @@ -1033,11 +1230,72 @@ def _get_action_summary(new_actions: Sequence[pb.OrchestratorAction]) -> str: else: counts: dict[str, int] = {} for action in new_actions: - action_type = action.WhichOneof('orchestratorActionType') + action_type = action.WhichOneof("orchestratorActionType") counts[action_type] = counts.get(action_type, 0) + 1 return f"[{', '.join(f'{name}={count}' for name, count in counts.items())}]" def _is_suspendable(event: pb.HistoryEvent) -> bool: """Returns true if the event is one that can be suspended and resumed.""" - return event.WhichOneof("eventType") not in ["executionResumed", "executionTerminated"] + return event.WhichOneof("eventType") not in [ + "executionResumed", + "executionTerminated", + ] + + +class _AsyncWorkerManager: + def __init__(self, concurrency_options: ConcurrencyOptions): + self.activity_semaphore = asyncio.Semaphore( + concurrency_options.maximum_concurrent_activity_work_items + ) + self.orchestration_semaphore = asyncio.Semaphore( + concurrency_options.maximum_concurrent_orchestration_work_items + ) + self.activity_queue: asyncio.Queue = asyncio.Queue() + self.orchestration_queue: asyncio.Queue = asyncio.Queue() + self.thread_pool = ThreadPoolExecutor( + max_workers=concurrency_options.maximum_thread_pool_workers, + 
thread_name_prefix="DurableTask", + ) + self._shutdown = False + + async def run(self): + # Start background consumers for each work type + await asyncio.gather( + self._consume_queue(self.activity_queue, self.activity_semaphore), + self._consume_queue(self.orchestration_queue, self.orchestration_semaphore), + ) + + async def _consume_queue(self, queue: asyncio.Queue, semaphore: asyncio.Semaphore): + while True: + # Exit if shutdown is set and the queue is empty + if self._shutdown and queue.empty(): + break + try: + work = await asyncio.wait_for(queue.get(), timeout=1.0) + except asyncio.TimeoutError: + continue + func, args, kwargs = work + async with semaphore: + await self._run_func(func, *args, **kwargs) + queue.task_done() + + async def _run_func(self, func, *args, **kwargs): + if inspect.iscoroutinefunction(func): + return await func(*args, **kwargs) + else: + loop = asyncio.get_running_loop() + # Avoid submitting to executor after shutdown + if getattr(self, '_shutdown', False) and getattr(self, 'thread_pool', None) and getattr(self.thread_pool, '_shutdown', False): + return None + return await loop.run_in_executor(self.thread_pool, lambda: func(*args, **kwargs)) + + def submit_activity(self, func, *args, **kwargs): + self.activity_queue.put_nowait((func, args, kwargs)) + + def submit_orchestration(self, func, *args, **kwargs): + self.orchestration_queue.put_nowait((func, args, kwargs)) + + def shutdown(self): + self._shutdown = True + self.thread_pool.shutdown(wait=True) diff --git a/tests/durabletask/test_concurrency_options.py b/tests/durabletask/test_concurrency_options.py index d963d92..df3ebec 100644 --- a/tests/durabletask/test_concurrency_options.py +++ b/tests/durabletask/test_concurrency_options.py @@ -12,10 +12,11 @@ def test_default_concurrency_options(): options = ConcurrencyOptions() processor_count = os.cpu_count() or 1 expected_default = 100 * processor_count + expected_workers = processor_count + 4 assert 
options.maximum_concurrent_activity_work_items == expected_default assert options.maximum_concurrent_orchestration_work_items == expected_default - assert options.max_total_workers == expected_default + assert options.maximum_thread_pool_workers == expected_workers def test_custom_concurrency_options(): @@ -23,45 +24,28 @@ def test_custom_concurrency_options(): options = ConcurrencyOptions( maximum_concurrent_activity_work_items=50, maximum_concurrent_orchestration_work_items=25, + maximum_thread_pool_workers=30, ) assert options.maximum_concurrent_activity_work_items == 50 assert options.maximum_concurrent_orchestration_work_items == 25 - assert options.max_total_workers == 50 # Max of both values + assert options.maximum_thread_pool_workers == 30 def test_partial_custom_options(): """Test that partially specified options use defaults for unspecified values.""" processor_count = os.cpu_count() or 1 expected_default = 100 * processor_count + expected_workers = processor_count + 4 options = ConcurrencyOptions( maximum_concurrent_activity_work_items=30 - # Leave other options as default ) assert options.maximum_concurrent_activity_work_items == 30 assert options.maximum_concurrent_orchestration_work_items == expected_default - assert ( - options.max_total_workers == expected_default - ) # Should be the default since it's larger - + assert options.maximum_thread_pool_workers == expected_workers -def test_max_total_workers_calculation(): - """Test that max_total_workers returns the maximum of all concurrency limits.""" - # Case 1: Activity is highest - options1 = ConcurrencyOptions( - maximum_concurrent_activity_work_items=100, - maximum_concurrent_orchestration_work_items=50, - ) - assert options1.max_total_workers == 100 - - # Case 2: Orchestration is highest - options2 = ConcurrencyOptions( - maximum_concurrent_activity_work_items=25, - maximum_concurrent_orchestration_work_items=100, - ) - assert options2.max_total_workers == 100 def 
test_worker_with_concurrency_options(): @@ -69,6 +53,7 @@ def test_worker_with_concurrency_options(): options = ConcurrencyOptions( maximum_concurrent_activity_work_items=10, maximum_concurrent_orchestration_work_items=20, + maximum_thread_pool_workers=15, ) worker = TaskHubGrpcWorker(concurrency_options=options) @@ -82,6 +67,7 @@ def test_worker_default_options(): processor_count = os.cpu_count() or 1 expected_default = 100 * processor_count + expected_workers = processor_count + 4 assert ( worker.concurrency_options.maximum_concurrent_activity_work_items == expected_default @@ -89,6 +75,7 @@ def test_worker_default_options(): assert ( worker.concurrency_options.maximum_concurrent_orchestration_work_items == expected_default ) + assert worker.concurrency_options.maximum_thread_pool_workers == expected_workers def test_concurrency_options_property_access(): @@ -96,6 +83,7 @@ def test_concurrency_options_property_access(): options = ConcurrencyOptions( maximum_concurrent_activity_work_items=15, maximum_concurrent_orchestration_work_items=25, + maximum_thread_pool_workers=30, ) worker = TaskHubGrpcWorker(concurrency_options=options) @@ -107,21 +95,5 @@ def test_concurrency_options_property_access(): # Should have correct values assert retrieved_options.maximum_concurrent_activity_work_items == 15 assert retrieved_options.maximum_concurrent_orchestration_work_items == 25 + assert retrieved_options.maximum_thread_pool_workers == 30 - -def test_edge_cases(): - """Test edge cases like zero or very large values.""" - # Test with zeros (should still work) - options_zero = ConcurrencyOptions( - maximum_concurrent_activity_work_items=0, - maximum_concurrent_orchestration_work_items=0, - ) - assert options_zero.max_total_workers == 0 - - # Test with very large values - options_large = ConcurrencyOptions( - maximum_concurrent_activity_work_items=999999, - maximum_concurrent_orchestration_work_items=1, - ) - assert options_large.max_total_workers == 999999 - assert 
options_large.max_total_workers == 999999 diff --git a/tests/durabletask/test_worker_concurrency_loop.py b/tests/durabletask/test_worker_concurrency_loop.py new file mode 100644 index 0000000..aab502e --- /dev/null +++ b/tests/durabletask/test_worker_concurrency_loop.py @@ -0,0 +1,135 @@ +import asyncio +import threading +import time + +from durabletask import ConcurrencyOptions +from durabletask.worker import TaskHubGrpcWorker + + +class DummyStub: + def __init__(self): + self.completed = [] + + def CompleteOrchestratorTask(self, res): + self.completed.append(('orchestrator', res)) + + def CompleteActivityTask(self, res): + self.completed.append(('activity', res)) + +class DummyRequest: + def __init__(self, kind, instance_id): + self.kind = kind + self.instanceId = instance_id + self.orchestrationInstance = type('O', (), {'instanceId': instance_id}) + self.name = 'dummy' + self.taskId = 1 + self.input = type('I', (), {'value': ''}) + self.pastEvents = [] + self.newEvents = [] + + def HasField(self, field): + return (field == 'orchestratorRequest' and self.kind == 'orchestrator') or \ + (field == 'activityRequest' and self.kind == 'activity') + + def WhichOneof(self, _): + return f'{self.kind}Request' + +class DummyCompletionToken: + pass + +def test_worker_concurrency_loop_sync(): + options = ConcurrencyOptions( + maximum_concurrent_activity_work_items=2, + maximum_concurrent_orchestration_work_items=1, + maximum_thread_pool_workers=2, + ) + worker = TaskHubGrpcWorker(concurrency_options=options) + stub = DummyStub() + + def dummy_orchestrator(req, stub, completionToken): + time.sleep(0.1) + stub.CompleteOrchestratorTask('ok') + + def dummy_activity(req, stub, completionToken): + time.sleep(0.1) + stub.CompleteActivityTask('ok') + + # Patch the worker's _execute_orchestrator and _execute_activity + worker._execute_orchestrator = dummy_orchestrator + worker._execute_activity = dummy_activity + + orchestrator_requests = [DummyRequest('orchestrator', f'orch{i}') for 
i in range(3)] + activity_requests = [DummyRequest('activity', f'act{i}') for i in range(4)] + + async def run_test(): + # Start the worker manager's run loop in the background + worker_task = asyncio.create_task(worker._async_worker_manager.run()) + for req in orchestrator_requests: + worker._async_worker_manager.submit_orchestration(dummy_orchestrator, req, stub, DummyCompletionToken()) + for req in activity_requests: + worker._async_worker_manager.submit_activity(dummy_activity, req, stub, DummyCompletionToken()) + await asyncio.sleep(1.0) + orchestrator_count = sum(1 for t, _ in stub.completed if t == 'orchestrator') + activity_count = sum(1 for t, _ in stub.completed if t == 'activity') + assert orchestrator_count == 3, f"Expected 3 orchestrator completions, got {orchestrator_count}" + assert activity_count == 4, f"Expected 4 activity completions, got {activity_count}" + worker._async_worker_manager._shutdown = True + await worker_task + asyncio.run(run_test()) + +# Dummy orchestrator and activity for sync context +def dummy_orchestrator(ctx, input): + # Simulate some work + time.sleep(0.1) + return "orchestrator-done" + +def dummy_activity(ctx, input): + # Simulate some work + time.sleep(0.1) + return "activity-done" + +def test_worker_concurrency_sync(): + # Use small concurrency to make test observable + options = ConcurrencyOptions( + maximum_concurrent_activity_work_items=2, + maximum_concurrent_orchestration_work_items=2, + maximum_thread_pool_workers=2, + ) + worker = TaskHubGrpcWorker(concurrency_options=options) + worker.add_orchestrator(dummy_orchestrator) + worker.add_activity(dummy_activity) + + # Simulate submitting work items to the queues directly (bypassing gRPC) + # We'll use the internal _async_worker_manager for this test + manager = worker._async_worker_manager + results = [] + lock = threading.Lock() + + def make_work(kind, idx): + def fn(*args, **kwargs): + time.sleep(0.1) + with lock: + results.append((kind, idx)) + return 
f"{kind}-{idx}-done" + return fn + + # Submit more work than concurrency allows + for i in range(5): + manager.submit_orchestration(make_work("orch", i)) + manager.submit_activity(make_work("act", i)) + + # Run the manager loop in a thread (sync context) + def run_manager(): + asyncio.run(manager.run()) + + t = threading.Thread(target=run_manager) + t.start() + time.sleep(1.5) # Let work process + manager.shutdown() + # Unblock the consumers by putting dummy items in the queues + manager.activity_queue.put_nowait((lambda: None, (), {})) + manager.orchestration_queue.put_nowait((lambda: None, (), {})) + t.join(timeout=2) + + # Check that all work items completed + assert len(results) == 10 \ No newline at end of file diff --git a/tests/durabletask/test_worker_concurrency_loop_async.py b/tests/durabletask/test_worker_concurrency_loop_async.py new file mode 100644 index 0000000..6cb25f0 --- /dev/null +++ b/tests/durabletask/test_worker_concurrency_loop_async.py @@ -0,0 +1,79 @@ +import asyncio + +from durabletask import ConcurrencyOptions +from durabletask.worker import TaskHubGrpcWorker + + +class DummyStub: + def __init__(self): + self.completed = [] + + def CompleteOrchestratorTask(self, res): + self.completed.append(('orchestrator', res)) + + def CompleteActivityTask(self, res): + self.completed.append(('activity', res)) + + +class DummyRequest: + def __init__(self, kind, instance_id): + self.kind = kind + self.instanceId = instance_id + self.orchestrationInstance = type('O', (), {'instanceId': instance_id}) + self.name = 'dummy' + self.taskId = 1 + self.input = type('I', (), {'value': ''}) + self.pastEvents = [] + self.newEvents = [] + + def HasField(self, field): + return (field == 'orchestratorRequest' and self.kind == 'orchestrator') or \ + (field == 'activityRequest' and self.kind == 'activity') + + def WhichOneof(self, _): + return f'{self.kind}Request' + + +class DummyCompletionToken: + pass + + +def test_worker_concurrency_loop_async(): + options = 
ConcurrencyOptions( + maximum_concurrent_activity_work_items=2, + maximum_concurrent_orchestration_work_items=1, + maximum_thread_pool_workers=2, + ) + worker = TaskHubGrpcWorker(concurrency_options=options) + stub = DummyStub() + + async def dummy_orchestrator(req, stub, completionToken): + await asyncio.sleep(0.1) + stub.CompleteOrchestratorTask('ok') + + async def dummy_activity(req, stub, completionToken): + await asyncio.sleep(0.1) + stub.CompleteActivityTask('ok') + + # Patch the worker's _execute_orchestrator and _execute_activity + worker._execute_orchestrator = dummy_orchestrator + worker._execute_activity = dummy_activity + + orchestrator_requests = [DummyRequest('orchestrator', f'orch{i}') for i in range(3)] + activity_requests = [DummyRequest('activity', f'act{i}') for i in range(4)] + + async def run_test(): + worker_task = asyncio.create_task(worker._async_worker_manager.run()) + for req in orchestrator_requests: + worker._async_worker_manager.submit_orchestration(dummy_orchestrator, req, stub, DummyCompletionToken()) + for req in activity_requests: + worker._async_worker_manager.submit_activity(dummy_activity, req, stub, DummyCompletionToken()) + await asyncio.sleep(1.0) + orchestrator_count = sum(1 for t, _ in stub.completed if t == 'orchestrator') + activity_count = sum(1 for t, _ in stub.completed if t == 'activity') + assert orchestrator_count == 3, f"Expected 3 orchestrator completions, got {orchestrator_count}" + assert activity_count == 4, f"Expected 4 activity completions, got {activity_count}" + worker._async_worker_manager._shutdown = True + await worker_task + asyncio.run(run_test()) + asyncio.run(run_test()) From 66cbcb24297e01ad38e66396c0ca87a2c67a903f Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Fri, 30 May 2025 16:34:08 -0700 Subject: [PATCH 05/18] more concurrency stuff --- durabletask/worker.py | 94 ++++++++++++++++--- .../test_worker_concurrency_loop_async.py | 2 + 2 files changed, 85 insertions(+), 11 deletions(-) diff --git 
a/durabletask/worker.py b/durabletask/worker.py index 8c0abc1..2a25711 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -219,7 +219,6 @@ def run_loop(): self._is_running = True async def _async_run_loop(self): - self._async_worker_manager = _AsyncWorkerManager(self._concurrency_options) worker_task = asyncio.create_task(self._async_worker_manager.run()) # Connection state management for retry fix current_channel = None @@ -1245,21 +1244,57 @@ def _is_suspendable(event: pb.HistoryEvent) -> bool: class _AsyncWorkerManager: def __init__(self, concurrency_options: ConcurrencyOptions): - self.activity_semaphore = asyncio.Semaphore( - concurrency_options.maximum_concurrent_activity_work_items - ) - self.orchestration_semaphore = asyncio.Semaphore( - concurrency_options.maximum_concurrent_orchestration_work_items - ) + self.concurrency_options = concurrency_options + self.activity_semaphore = None + self.orchestration_semaphore = None self.activity_queue: asyncio.Queue = asyncio.Queue() self.orchestration_queue: asyncio.Queue = asyncio.Queue() + self._queue_event_loop: Optional[asyncio.AbstractEventLoop] = None + # Try to capture the current event loop when queues are created + try: + self._queue_event_loop = asyncio.get_running_loop() + except RuntimeError: + # No event loop running when manager was created + pass self.thread_pool = ThreadPoolExecutor( max_workers=concurrency_options.maximum_thread_pool_workers, thread_name_prefix="DurableTask", ) self._shutdown = False + def _ensure_queues_for_current_loop(self): + """Ensure queues are bound to the current event loop.""" + try: + current_loop = asyncio.get_running_loop() + except RuntimeError: + # No event loop running, can't create queues + return + + if self._queue_event_loop is current_loop and hasattr(self, 'activity_queue') and hasattr(self, 'orchestration_queue'): + # Queues are already bound to the current loop and exist + return + + # Need to recreate queues for the current event loop + # Create 
fresh queues - any items from previous event loops are dropped + self.activity_queue = asyncio.Queue() + self.orchestration_queue = asyncio.Queue() + self._queue_event_loop = current_loop + async def run(self): + # Reset shutdown flag in case this manager is being reused + self._shutdown = False + + # Ensure queues are properly bound to the current event loop + self._ensure_queues_for_current_loop() + + # Create semaphores in the current event loop + self.activity_semaphore = asyncio.Semaphore( + self.concurrency_options.maximum_concurrent_activity_work_items + ) + self.orchestration_semaphore = asyncio.Semaphore( + self.concurrency_options.maximum_concurrent_orchestration_work_items + ) + # Start background consumers for each work type await asyncio.gather( self._consume_queue(self.activity_queue, self.activity_semaphore), @@ -1267,18 +1302,34 @@ async def run(self): ) async def _consume_queue(self, queue: asyncio.Queue, semaphore: asyncio.Semaphore): + # Set of in-flight tasks spawned by this consumer + running_tasks: set[asyncio.Task] = set() + while True: - # Exit if shutdown is set and the queue is empty - if self._shutdown and queue.empty(): + # Clean up completed tasks + done_tasks = {task for task in running_tasks if task.done()} + running_tasks -= done_tasks + + # Exit if shutdown is set and the queue is empty and no tasks are running + if self._shutdown and queue.empty() and not running_tasks: break + try: work = await asyncio.wait_for(queue.get(), timeout=1.0) except asyncio.TimeoutError: continue + func, args, kwargs = work - async with semaphore: + # Create a concurrent task for processing + task = asyncio.create_task(self._process_work_item(semaphore, queue, func, args, kwargs)) + running_tasks.add(task) + + async def _process_work_item(self, semaphore: asyncio.Semaphore, queue: asyncio.Queue, func, args, kwargs): + async with semaphore: + try: await self._run_func(func, *args, **kwargs) - queue.task_done() + finally: + queue.task_done() async def _run_func(self, func, 
*args, **kwargs): if inspect.iscoroutinefunction(func): @@ -1291,11 +1342,32 @@ async def _run_func(self, func, *args, **kwargs): return await loop.run_in_executor(self.thread_pool, lambda: func(*args, **kwargs)) def submit_activity(self, func, *args, **kwargs): + self._ensure_queues_for_current_loop() self.activity_queue.put_nowait((func, args, kwargs)) def submit_orchestration(self, func, *args, **kwargs): + self._ensure_queues_for_current_loop() self.orchestration_queue.put_nowait((func, args, kwargs)) def shutdown(self): self._shutdown = True self.thread_pool.shutdown(wait=True) + + def reset_for_new_run(self): + """Reset the manager state for a new run.""" + self._shutdown = False + # Drain any items still sitting in the existing queues + if hasattr(self, 'activity_queue'): + # Empty the queue in place (the queue object itself is kept) + # This ensures no items from previous runs remain + try: + while not self.activity_queue.empty(): + self.activity_queue.get_nowait() + except Exception: + pass + if hasattr(self, 'orchestration_queue'): + try: + while not self.orchestration_queue.empty(): + self.orchestration_queue.get_nowait() + except Exception: + pass diff --git a/tests/durabletask/test_worker_concurrency_loop_async.py b/tests/durabletask/test_worker_concurrency_loop_async.py index 6cb25f0..1a1f6c3 100644 --- a/tests/durabletask/test_worker_concurrency_loop_async.py +++ b/tests/durabletask/test_worker_concurrency_loop_async.py @@ -63,6 +63,8 @@ async def dummy_activity(req, stub, completionToken): activity_requests = [DummyRequest('activity', f'act{i}') for i in range(4)] async def run_test(): + # Clear stub state before each run + stub.completed.clear() worker_task = asyncio.create_task(worker._async_worker_manager.run()) for req in orchestrator_requests: worker._async_worker_manager.submit_orchestration(dummy_orchestrator, req, stub, DummyCompletionToken()) From 3044d28673376a0ef2104e331ee75985ff8e8786 Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Fri, 30 May 
2025 17:02:49 -0700 Subject: [PATCH 06/18] final touches --- .../durabletask/azuremanaged/worker.py | 53 ++++++++++-- durabletask-azuremanaged/pyproject.toml | 2 +- durabletask/worker.py | 84 ++++++++++++++++++- 3 files changed, 128 insertions(+), 11 deletions(-) diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/worker.py b/durabletask-azuremanaged/durabletask/azuremanaged/worker.py index fd3b1e4..0354bd8 100644 --- a/durabletask-azuremanaged/durabletask/azuremanaged/worker.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/worker.py @@ -5,19 +5,59 @@ from azure.core.credentials import TokenCredential -from durabletask.azuremanaged.internal.durabletask_grpc_interceptor import ( - DTSDefaultClientInterceptorImpl, -) -from durabletask.worker import TaskHubGrpcWorker +from durabletask.azuremanaged.internal.durabletask_grpc_interceptor import \ + DTSDefaultClientInterceptorImpl +from durabletask.worker import ConcurrencyOptions, TaskHubGrpcWorker # Worker class used for Durable Task Scheduler (DTS) class DurableTaskSchedulerWorker(TaskHubGrpcWorker): + """A worker implementation for Azure Durable Task Scheduler (DTS). + + This class extends TaskHubGrpcWorker to provide integration with Azure's + Durable Task Scheduler service. It handles authentication via Azure credentials + and configures the necessary gRPC interceptors for DTS communication. + + Args: + host_address (str): The gRPC endpoint address of the DTS service. + taskhub (str): The name of the task hub. Cannot be empty. + token_credential (Optional[TokenCredential]): Azure credential for authentication. + If None, anonymous authentication will be used. + secure_channel (bool, optional): Whether to use a secure gRPC channel (TLS). + Defaults to True. + concurrency_options (Optional[ConcurrencyOptions], optional): Configuration + for controlling worker concurrency limits. If None, default concurrency + settings will be used. + + Raises: + ValueError: If taskhub is empty or None. 
+ + Example: + >>> from azure.identity import DefaultAzureCredential + >>> from durabletask.azuremanaged import DurableTaskSchedulerWorker + >>> from durabletask import ConcurrencyOptions + >>> + >>> credential = DefaultAzureCredential() + >>> concurrency = ConcurrencyOptions(max_concurrent_activities=10) + >>> worker = DurableTaskSchedulerWorker( + ... host_address="my-dts-service.azure.com:443", + ... taskhub="my-task-hub", + ... token_credential=credential, + ... concurrency_options=concurrency + ... ) + + Note: + This worker automatically configures DTS-specific gRPC interceptors + for authentication and task hub routing. The parent class metadata + parameter is set to None since authentication is handled by the + DTS interceptor. + """ def __init__(self, *, host_address: str, taskhub: str, token_credential: Optional[TokenCredential], - secure_channel: bool = True): + secure_channel: bool = True, + concurrency_options: Optional[ConcurrencyOptions] = None): if not taskhub: raise ValueError("The taskhub value cannot be empty.") @@ -30,4 +70,5 @@ def __init__(self, *, host_address=host_address, secure_channel=secure_channel, metadata=None, - interceptors=interceptors) + interceptors=interceptors, + concurrency_options=concurrency_options) diff --git a/durabletask-azuremanaged/pyproject.toml b/durabletask-azuremanaged/pyproject.toml index 3de2b53..250cfcc 100644 --- a/durabletask-azuremanaged/pyproject.toml +++ b/durabletask-azuremanaged/pyproject.toml @@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta" [project] name = "durabletask.azuremanaged" -version = "0.1.5" +version = "0.2.0" description = "Durable Task Python SDK provider implementation for the Azure Durable Task Scheduler" keywords = [ "durable", diff --git a/durabletask/worker.py b/durabletask/worker.py index 2a25711..b89ff66 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -28,11 +28,10 @@ class ConcurrencyOptions: - """Configuration options for controlling concurrency of different 
work item types. + """Configuration options for controlling concurrency of different work item types and the thread pool size. - This class mirrors the .NET DurableTask SDK's ConcurrencyOptions class, - providing fine-grained control over concurrent processing limits for - activities, orchestrations, and entities. + This class provides fine-grained control over concurrent processing limits for + activities, orchestrations and the thread pool size. """ def __init__( @@ -134,6 +133,83 @@ class ActivityNotRegisteredError(ValueError): class TaskHubGrpcWorker: + """A gRPC-based worker for processing durable task orchestrations and activities. + + This worker connects to a Durable Task backend service via gRPC to receive and process + work items including orchestration functions and activity functions. It provides + concurrent execution capabilities with configurable limits and automatic retry handling. + + The worker manages the complete lifecycle: + - Registers orchestrator and activity functions + - Connects to the gRPC backend service + - Receives work items and executes them concurrently + - Handles failures, retries, and state management + - Provides logging and monitoring capabilities + + Args: + host_address (Optional[str], optional): The gRPC endpoint address of the backend service. + Defaults to the value from environment variables or localhost. + metadata (Optional[list[tuple[str, str]]], optional): gRPC metadata to include with + requests. Used for authentication and routing. Defaults to None. + log_handler (optional): Custom logging handler for worker logs. Defaults to None. + log_formatter (Optional[logging.Formatter], optional): Custom log formatter. + Defaults to None. + secure_channel (bool, optional): Whether to use a secure gRPC channel (TLS). + Defaults to False. + interceptors (Optional[Sequence[shared.ClientInterceptor]], optional): Custom gRPC + interceptors to apply to the channel. Defaults to None. 
+ concurrency_options (Optional[ConcurrencyOptions], optional): Configuration for + controlling worker concurrency limits. If None, default settings are used. + + Attributes: + concurrency_options (ConcurrencyOptions): The current concurrency configuration. + + Example: + Basic worker setup: + + >>> from durabletask import TaskHubGrpcWorker, ConcurrencyOptions + >>> + >>> # Create worker with custom concurrency settings + >>> concurrency = ConcurrencyOptions( + ... maximum_concurrent_activity_work_items=50, + ... maximum_concurrent_orchestration_work_items=20 + ... ) + >>> worker = TaskHubGrpcWorker( + ... host_address="localhost:4001", + ... concurrency_options=concurrency + ... ) + >>> + >>> # Register functions + >>> @worker.add_orchestrator + ... def my_orchestrator(context, input): + ... result = yield context.call_activity("my_activity", input="hello") + ... return result + >>> + >>> @worker.add_activity + ... def my_activity(context, input): + ... return f"Processed: {input}" + >>> + >>> # Start the worker + >>> worker.start() + >>> # ... worker runs in background thread + >>> worker.stop() + + Using as context manager: + + >>> with TaskHubGrpcWorker() as worker: + ... worker.add_orchestrator(my_orchestrator) + ... worker.add_activity(my_activity) + ... worker.start() + ... # Worker automatically stops when exiting context + + Raises: + RuntimeError: If attempting to add orchestrators/activities while the worker is running, + or if starting a worker that is already running. + OrchestratorNotRegisteredError: If an orchestration work item references an + unregistered orchestrator function. + ActivityNotRegisteredError: If an activity work item references an unregistered + activity function. 
+ """ _response_stream: Optional[grpc.Future] = None _interceptors: Optional[list[shared.ClientInterceptor]] = None From 5ce7a666d1b9f7f0ada0e1d42f721bb6d4ac4ddf Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Fri, 30 May 2025 17:06:37 -0700 Subject: [PATCH 07/18] fix import --- durabletask-azuremanaged/durabletask/azuremanaged/worker.py | 2 +- tests/durabletask/test_concurrency_options.py | 3 +-- tests/durabletask/test_worker_concurrency_loop.py | 3 +-- tests/durabletask/test_worker_concurrency_loop_async.py | 3 +-- 4 files changed, 4 insertions(+), 7 deletions(-) diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/worker.py b/durabletask-azuremanaged/durabletask/azuremanaged/worker.py index 0354bd8..46db4ad 100644 --- a/durabletask-azuremanaged/durabletask/azuremanaged/worker.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/worker.py @@ -35,7 +35,7 @@ class DurableTaskSchedulerWorker(TaskHubGrpcWorker): Example: >>> from azure.identity import DefaultAzureCredential >>> from durabletask.azuremanaged import DurableTaskSchedulerWorker - >>> from durabletask import ConcurrencyOptions + >>> from durabletask.worker import ConcurrencyOptions >>> >>> credential = DefaultAzureCredential() >>> concurrency = ConcurrencyOptions(max_concurrent_activities=10) diff --git a/tests/durabletask/test_concurrency_options.py b/tests/durabletask/test_concurrency_options.py index df3ebec..f884587 100644 --- a/tests/durabletask/test_concurrency_options.py +++ b/tests/durabletask/test_concurrency_options.py @@ -3,8 +3,7 @@ import os -from durabletask import ConcurrencyOptions -from durabletask.worker import TaskHubGrpcWorker +from durabletask.worker import ConcurrencyOptions, TaskHubGrpcWorker def test_default_concurrency_options(): diff --git a/tests/durabletask/test_worker_concurrency_loop.py b/tests/durabletask/test_worker_concurrency_loop.py index aab502e..8e6eccb 100644 --- a/tests/durabletask/test_worker_concurrency_loop.py +++ 
b/tests/durabletask/test_worker_concurrency_loop.py @@ -2,8 +2,7 @@ import threading import time -from durabletask import ConcurrencyOptions -from durabletask.worker import TaskHubGrpcWorker +from durabletask.worker import ConcurrencyOptions, TaskHubGrpcWorker class DummyStub: diff --git a/tests/durabletask/test_worker_concurrency_loop_async.py b/tests/durabletask/test_worker_concurrency_loop_async.py index 1a1f6c3..a6b4c5a 100644 --- a/tests/durabletask/test_worker_concurrency_loop_async.py +++ b/tests/durabletask/test_worker_concurrency_loop_async.py @@ -1,7 +1,6 @@ import asyncio -from durabletask import ConcurrencyOptions -from durabletask.worker import TaskHubGrpcWorker +from durabletask.worker import ConcurrencyOptions, TaskHubGrpcWorker class DummyStub: From 5d0d61bb93b59bed7a905750920166f78369acac Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Fri, 30 May 2025 17:10:30 -0700 Subject: [PATCH 08/18] update log level --- durabletask/worker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/durabletask/worker.py b/durabletask/worker.py index b89ff66..b9dd465 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -318,9 +318,9 @@ def create_fresh_connection(): current_stub = stubs.TaskHubSidecarServiceStub(current_channel) current_stub.Hello(empty_pb2.Empty()) conn_retry_count = 0 - self._logger.debug(f"Created fresh connection to {self._host_address}") + self._logger.info(f"Created fresh connection to {self._host_address}") except Exception as e: - self._logger.debug(f"Failed to create connection: {e}") + self._logger.warning(f"Failed to create connection: {e}") current_channel = None current_stub = None raise From f9713e3530ea240409015f1e48bbc3c125349091 Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Fri, 30 May 2025 17:19:23 -0700 Subject: [PATCH 09/18] fix exports --- durabletask/worker.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/durabletask/worker.py b/durabletask/worker.py index b9dd465..f58276d 100644 
--- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -1447,3 +1447,10 @@ def reset_for_new_run(self): self.orchestration_queue.get_nowait() except Exception: pass + + +# Export public API +__all__ = [ + 'ConcurrencyOptions', + 'TaskHubGrpcWorker' +] From 7d85c7647292a4bb9eea594df67dc6cf997f0480 Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Fri, 30 May 2025 17:28:17 -0700 Subject: [PATCH 10/18] more fixup --- .../durabletask/azuremanaged/worker.py | 2 +- durabletask/worker.py | 28 +++++++++++++++++-- .../test_worker_concurrency_loop.py | 7 ++++- 3 files changed, 33 insertions(+), 4 deletions(-) diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/worker.py b/durabletask-azuremanaged/durabletask/azuremanaged/worker.py index 46db4ad..1135ae7 100644 --- a/durabletask-azuremanaged/durabletask/azuremanaged/worker.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/worker.py @@ -36,7 +36,7 @@ class DurableTaskSchedulerWorker(TaskHubGrpcWorker): >>> from azure.identity import DefaultAzureCredential >>> from durabletask.azuremanaged import DurableTaskSchedulerWorker >>> from durabletask.worker import ConcurrencyOptions - >>> + >>> >>> credential = DefaultAzureCredential() >>> concurrency = ConcurrencyOptions(max_concurrent_activities=10) >>> worker = DurableTaskSchedulerWorker( diff --git a/durabletask/worker.py b/durabletask/worker.py index f58276d..ecffd18 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -185,7 +185,7 @@ class TaskHubGrpcWorker: ... result = yield context.call_activity("my_activity", input="hello") ... return result >>> - >>> @worker.add_activity + >>> @worker.add_activity ... def my_activity(context, input): ... 
return f"Processed: {input}" >>> @@ -1351,11 +1351,35 @@ def _ensure_queues_for_current_loop(self): return # Need to recreate queues for the current event loop - # Create fresh queues - any items from previous event loops are dropped + # First, preserve any existing work items + existing_activity_items = [] + existing_orchestration_items = [] + + if hasattr(self, 'activity_queue'): + try: + while not self.activity_queue.empty(): + existing_activity_items.append(self.activity_queue.get_nowait()) + except Exception: + pass + + if hasattr(self, 'orchestration_queue'): + try: + while not self.orchestration_queue.empty(): + existing_orchestration_items.append(self.orchestration_queue.get_nowait()) + except Exception: + pass + + # Create fresh queues for the current event loop self.activity_queue = asyncio.Queue() self.orchestration_queue = asyncio.Queue() self._queue_event_loop = current_loop + # Restore the work items to the new queues + for item in existing_activity_items: + self.activity_queue.put_nowait(item) + for item in existing_orchestration_items: + self.orchestration_queue.put_nowait(item) + async def run(self): # Reset shutdown flag in case this manager is being reused self._shutdown = False diff --git a/tests/durabletask/test_worker_concurrency_loop.py b/tests/durabletask/test_worker_concurrency_loop.py index 8e6eccb..fb13032 100644 --- a/tests/durabletask/test_worker_concurrency_loop.py +++ b/tests/durabletask/test_worker_concurrency_loop.py @@ -124,11 +124,16 @@ def run_manager(): t = threading.Thread(target=run_manager) t.start() time.sleep(1.5) # Let work process - manager.shutdown() + + # Signal shutdown but don't actually call shutdown() yet + manager._shutdown = True # Unblock the consumers by putting dummy items in the queues manager.activity_queue.put_nowait((lambda: None, (), {})) manager.orchestration_queue.put_nowait((lambda: None, (), {})) t.join(timeout=2) + # Now shutdown the thread pool + manager.thread_pool.shutdown(wait=True) + # Check that 
all work items completed assert len(results) == 10 \ No newline at end of file From 97359f7af9ed8bc3851744f2cba4ce03595a88dc Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Fri, 30 May 2025 17:43:00 -0700 Subject: [PATCH 11/18] test updateS --- tests/durabletask/test_concurrency_options.py | 30 +++++++++---------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/tests/durabletask/test_concurrency_options.py b/tests/durabletask/test_concurrency_options.py index f884587..160c3ac 100644 --- a/tests/durabletask/test_concurrency_options.py +++ b/tests/durabletask/test_concurrency_options.py @@ -3,12 +3,12 @@ import os -from durabletask.worker import ConcurrencyOptions, TaskHubGrpcWorker +from durabletask import worker def test_default_concurrency_options(): """Test that default concurrency options work correctly.""" - options = ConcurrencyOptions() + options = worker.ConcurrencyOptions() processor_count = os.cpu_count() or 1 expected_default = 100 * processor_count expected_workers = processor_count + 4 @@ -20,7 +20,7 @@ def test_default_concurrency_options(): def test_custom_concurrency_options(): """Test that custom concurrency options work correctly.""" - options = ConcurrencyOptions( + options = worker.ConcurrencyOptions( maximum_concurrent_activity_work_items=50, maximum_concurrent_orchestration_work_items=25, maximum_thread_pool_workers=30, @@ -37,7 +37,7 @@ def test_partial_custom_options(): expected_default = 100 * processor_count expected_workers = processor_count + 4 - options = ConcurrencyOptions( + options = worker.ConcurrencyOptions( maximum_concurrent_activity_work_items=30 ) @@ -46,47 +46,46 @@ def test_partial_custom_options(): assert options.maximum_thread_pool_workers == expected_workers - def test_worker_with_concurrency_options(): """Test that TaskHubGrpcWorker accepts concurrency options.""" - options = ConcurrencyOptions( + options = worker.ConcurrencyOptions( maximum_concurrent_activity_work_items=10, 
maximum_concurrent_orchestration_work_items=20, maximum_thread_pool_workers=15, ) - worker = TaskHubGrpcWorker(concurrency_options=options) + grpc_worker = worker.TaskHubGrpcWorker(concurrency_options=options) - assert worker.concurrency_options == options + assert grpc_worker.concurrency_options == options def test_worker_default_options(): """Test that TaskHubGrpcWorker uses default options when no parameters are provided.""" - worker = TaskHubGrpcWorker() + grpc_worker = worker.TaskHubGrpcWorker() processor_count = os.cpu_count() or 1 expected_default = 100 * processor_count expected_workers = processor_count + 4 assert ( - worker.concurrency_options.maximum_concurrent_activity_work_items == expected_default + grpc_worker.concurrency_options.maximum_concurrent_activity_work_items == expected_default ) assert ( - worker.concurrency_options.maximum_concurrent_orchestration_work_items == expected_default + grpc_worker.concurrency_options.maximum_concurrent_orchestration_work_items == expected_default ) - assert worker.concurrency_options.maximum_thread_pool_workers == expected_workers + assert grpc_worker.concurrency_options.maximum_thread_pool_workers == expected_workers def test_concurrency_options_property_access(): """Test that the concurrency_options property works correctly.""" - options = ConcurrencyOptions( + options = worker.ConcurrencyOptions( maximum_concurrent_activity_work_items=15, maximum_concurrent_orchestration_work_items=25, maximum_thread_pool_workers=30, ) - worker = TaskHubGrpcWorker(concurrency_options=options) - retrieved_options = worker.concurrency_options + grpc_worker = worker.TaskHubGrpcWorker(concurrency_options=options) + retrieved_options = grpc_worker.concurrency_options # Should be the same object assert retrieved_options is options @@ -95,4 +94,3 @@ def test_concurrency_options_property_access(): assert retrieved_options.maximum_concurrent_activity_work_items == 15 assert 
retrieved_options.maximum_concurrent_orchestration_work_items == 25 assert retrieved_options.maximum_thread_pool_workers == 30 - From 9a49240473ea0a4ef6c2e951fa89557ec6fd76e7 Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Fri, 30 May 2025 17:45:22 -0700 Subject: [PATCH 12/18] more test imports --- .../test_worker_concurrency_loop.py | 18 +++++++++--------- .../test_worker_concurrency_loop_async.py | 18 +++++++++--------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/tests/durabletask/test_worker_concurrency_loop.py b/tests/durabletask/test_worker_concurrency_loop.py index fb13032..ef0f291 100644 --- a/tests/durabletask/test_worker_concurrency_loop.py +++ b/tests/durabletask/test_worker_concurrency_loop.py @@ -2,7 +2,7 @@ import threading import time -from durabletask.worker import ConcurrencyOptions, TaskHubGrpcWorker +from durabletask import worker class DummyStub: @@ -37,12 +37,12 @@ class DummyCompletionToken: pass def test_worker_concurrency_loop_sync(): - options = ConcurrencyOptions( + options = worker.ConcurrencyOptions( maximum_concurrent_activity_work_items=2, maximum_concurrent_orchestration_work_items=1, maximum_thread_pool_workers=2, ) - worker = TaskHubGrpcWorker(concurrency_options=options) + grpc_worker = worker.TaskHubGrpcWorker(concurrency_options=options) stub = DummyStub() def dummy_orchestrator(req, stub, completionToken): @@ -54,25 +54,25 @@ def dummy_activity(req, stub, completionToken): stub.CompleteActivityTask('ok') # Patch the worker's _execute_orchestrator and _execute_activity - worker._execute_orchestrator = dummy_orchestrator - worker._execute_activity = dummy_activity + grpc_worker._execute_orchestrator = dummy_orchestrator + grpc_worker._execute_activity = dummy_activity orchestrator_requests = [DummyRequest('orchestrator', f'orch{i}') for i in range(3)] activity_requests = [DummyRequest('activity', f'act{i}') for i in range(4)] async def run_test(): # Start the worker manager's run loop in the background - 
worker_task = asyncio.create_task(worker._async_worker_manager.run()) + worker_task = asyncio.create_task(grpc_worker._async_worker_manager.run()) for req in orchestrator_requests: - worker._async_worker_manager.submit_orchestration(dummy_orchestrator, req, stub, DummyCompletionToken()) + grpc_worker._async_worker_manager.submit_orchestration(dummy_orchestrator, req, stub, DummyCompletionToken()) for req in activity_requests: - worker._async_worker_manager.submit_activity(dummy_activity, req, stub, DummyCompletionToken()) + grpc_worker._async_worker_manager.submit_activity(dummy_activity, req, stub, DummyCompletionToken()) await asyncio.sleep(1.0) orchestrator_count = sum(1 for t, _ in stub.completed if t == 'orchestrator') activity_count = sum(1 for t, _ in stub.completed if t == 'activity') assert orchestrator_count == 3, f"Expected 3 orchestrator completions, got {orchestrator_count}" assert activity_count == 4, f"Expected 4 activity completions, got {activity_count}" - worker._async_worker_manager._shutdown = True + grpc_worker._async_worker_manager._shutdown = True await worker_task asyncio.run(run_test()) diff --git a/tests/durabletask/test_worker_concurrency_loop_async.py b/tests/durabletask/test_worker_concurrency_loop_async.py index a6b4c5a..70ff2ad 100644 --- a/tests/durabletask/test_worker_concurrency_loop_async.py +++ b/tests/durabletask/test_worker_concurrency_loop_async.py @@ -1,6 +1,6 @@ import asyncio -from durabletask.worker import ConcurrencyOptions, TaskHubGrpcWorker +from durabletask import worker class DummyStub: @@ -38,12 +38,12 @@ class DummyCompletionToken: def test_worker_concurrency_loop_async(): - options = ConcurrencyOptions( + options = worker.ConcurrencyOptions( maximum_concurrent_activity_work_items=2, maximum_concurrent_orchestration_work_items=1, maximum_thread_pool_workers=2, ) - worker = TaskHubGrpcWorker(concurrency_options=options) + grpc_worker = worker.TaskHubGrpcWorker(concurrency_options=options) stub = DummyStub() async def 
dummy_orchestrator(req, stub, completionToken): @@ -55,8 +55,8 @@ async def dummy_activity(req, stub, completionToken): stub.CompleteActivityTask('ok') # Patch the worker's _execute_orchestrator and _execute_activity - worker._execute_orchestrator = dummy_orchestrator - worker._execute_activity = dummy_activity + grpc_worker._execute_orchestrator = dummy_orchestrator + grpc_worker._execute_activity = dummy_activity orchestrator_requests = [DummyRequest('orchestrator', f'orch{i}') for i in range(3)] activity_requests = [DummyRequest('activity', f'act{i}') for i in range(4)] @@ -64,17 +64,17 @@ async def dummy_activity(req, stub, completionToken): async def run_test(): # Clear stub state before each run stub.completed.clear() - worker_task = asyncio.create_task(worker._async_worker_manager.run()) + worker_task = asyncio.create_task(grpc_worker._async_worker_manager.run()) for req in orchestrator_requests: - worker._async_worker_manager.submit_orchestration(dummy_orchestrator, req, stub, DummyCompletionToken()) + grpc_worker._async_worker_manager.submit_orchestration(dummy_orchestrator, req, stub, DummyCompletionToken()) for req in activity_requests: - worker._async_worker_manager.submit_activity(dummy_activity, req, stub, DummyCompletionToken()) + grpc_worker._async_worker_manager.submit_activity(dummy_activity, req, stub, DummyCompletionToken()) await asyncio.sleep(1.0) orchestrator_count = sum(1 for t, _ in stub.completed if t == 'orchestrator') activity_count = sum(1 for t, _ in stub.completed if t == 'activity') assert orchestrator_count == 3, f"Expected 3 orchestrator completions, got {orchestrator_count}" assert activity_count == 4, f"Expected 4 activity completions, got {activity_count}" - worker._async_worker_manager._shutdown = True + grpc_worker._async_worker_manager._shutdown = True await worker_task asyncio.run(run_test()) asyncio.run(run_test()) From 38eeabcb796b13d827e42db3715e33b568681d23 Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Fri, 30 May 
2025 17:55:48 -0700 Subject: [PATCH 13/18] fix github workflow pytest --- .github/workflows/pr-validation.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/pr-validation.yml b/.github/workflows/pr-validation.yml index dddcc53..1d14d83 100644 --- a/.github/workflows/pr-validation.yml +++ b/.github/workflows/pr-validation.yml @@ -25,11 +25,12 @@ jobs: uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - - name: Install durabletask dependencies + - name: Install durabletask dependencies and the library itself in editable mode run: | python -m pip install --upgrade pip pip install flake8 pytest pip install -r requirements.txt + pip install -e . - name: Install durabletask-azuremanaged dependencies working-directory: examples/dts run: | From cb9695ec0508c5c489bc5c5c62a74fa7c43a9bf5 Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Fri, 30 May 2025 17:56:00 -0700 Subject: [PATCH 14/18] cleanup tests --- tests/durabletask/test_concurrency_options.py | 28 ++++++++-------- .../test_worker_concurrency_loop.py | 33 ++++++++++--------- .../test_worker_concurrency_loop_async.py | 6 ++-- 3 files changed, 34 insertions(+), 33 deletions(-) diff --git a/tests/durabletask/test_concurrency_options.py b/tests/durabletask/test_concurrency_options.py index 160c3ac..b49b7ec 100644 --- a/tests/durabletask/test_concurrency_options.py +++ b/tests/durabletask/test_concurrency_options.py @@ -3,12 +3,12 @@ import os -from durabletask import worker +from durabletask.worker import ConcurrencyOptions, TaskHubGrpcWorker def test_default_concurrency_options(): """Test that default concurrency options work correctly.""" - options = worker.ConcurrencyOptions() + options = ConcurrencyOptions() processor_count = os.cpu_count() or 1 expected_default = 100 * processor_count expected_workers = processor_count + 4 @@ -20,7 +20,7 @@ def test_default_concurrency_options(): def test_custom_concurrency_options(): """Test that custom 
concurrency options work correctly.""" - options = worker.ConcurrencyOptions( + options = ConcurrencyOptions( maximum_concurrent_activity_work_items=50, maximum_concurrent_orchestration_work_items=25, maximum_thread_pool_workers=30, @@ -37,7 +37,7 @@ def test_partial_custom_options(): expected_default = 100 * processor_count expected_workers = processor_count + 4 - options = worker.ConcurrencyOptions( + options = ConcurrencyOptions( maximum_concurrent_activity_work_items=30 ) @@ -48,44 +48,44 @@ def test_partial_custom_options(): def test_worker_with_concurrency_options(): """Test that TaskHubGrpcWorker accepts concurrency options.""" - options = worker.ConcurrencyOptions( + options = ConcurrencyOptions( maximum_concurrent_activity_work_items=10, maximum_concurrent_orchestration_work_items=20, maximum_thread_pool_workers=15, ) - grpc_worker = worker.TaskHubGrpcWorker(concurrency_options=options) + worker = TaskHubGrpcWorker(concurrency_options=options) - assert grpc_worker.concurrency_options == options + assert worker.concurrency_options == options def test_worker_default_options(): """Test that TaskHubGrpcWorker uses default options when no parameters are provided.""" - grpc_worker = worker.TaskHubGrpcWorker() + worker = TaskHubGrpcWorker() processor_count = os.cpu_count() or 1 expected_default = 100 * processor_count expected_workers = processor_count + 4 assert ( - grpc_worker.concurrency_options.maximum_concurrent_activity_work_items == expected_default + worker.concurrency_options.maximum_concurrent_activity_work_items == expected_default ) assert ( - grpc_worker.concurrency_options.maximum_concurrent_orchestration_work_items == expected_default + worker.concurrency_options.maximum_concurrent_orchestration_work_items == expected_default ) - assert grpc_worker.concurrency_options.maximum_thread_pool_workers == expected_workers + assert worker.concurrency_options.maximum_thread_pool_workers == expected_workers def test_concurrency_options_property_access(): 
"""Test that the concurrency_options property works correctly.""" - options = worker.ConcurrencyOptions( + options = ConcurrencyOptions( maximum_concurrent_activity_work_items=15, maximum_concurrent_orchestration_work_items=25, maximum_thread_pool_workers=30, ) - grpc_worker = worker.TaskHubGrpcWorker(concurrency_options=options) - retrieved_options = grpc_worker.concurrency_options + worker = TaskHubGrpcWorker(concurrency_options=options) + retrieved_options = worker.concurrency_options # Should be the same object assert retrieved_options is options diff --git a/tests/durabletask/test_worker_concurrency_loop.py b/tests/durabletask/test_worker_concurrency_loop.py index ef0f291..de6753b 100644 --- a/tests/durabletask/test_worker_concurrency_loop.py +++ b/tests/durabletask/test_worker_concurrency_loop.py @@ -2,7 +2,7 @@ import threading import time -from durabletask import worker +from durabletask.worker import ConcurrencyOptions, TaskHubGrpcWorker class DummyStub: @@ -15,6 +15,7 @@ def CompleteOrchestratorTask(self, res): def CompleteActivityTask(self, res): self.completed.append(('activity', res)) + class DummyRequest: def __init__(self, kind, instance_id): self.kind = kind @@ -33,16 +34,18 @@ def HasField(self, field): def WhichOneof(self, _): return f'{self.kind}Request' + class DummyCompletionToken: pass + def test_worker_concurrency_loop_sync(): - options = worker.ConcurrencyOptions( + options = ConcurrencyOptions( maximum_concurrent_activity_work_items=2, maximum_concurrent_orchestration_work_items=1, maximum_thread_pool_workers=2, ) - grpc_worker = worker.TaskHubGrpcWorker(concurrency_options=options) + worker = TaskHubGrpcWorker(concurrency_options=options) stub = DummyStub() def dummy_orchestrator(req, stub, completionToken): @@ -54,39 +57,42 @@ def dummy_activity(req, stub, completionToken): stub.CompleteActivityTask('ok') # Patch the worker's _execute_orchestrator and _execute_activity - grpc_worker._execute_orchestrator = dummy_orchestrator - 
grpc_worker._execute_activity = dummy_activity + worker._execute_orchestrator = dummy_orchestrator + worker._execute_activity = dummy_activity orchestrator_requests = [DummyRequest('orchestrator', f'orch{i}') for i in range(3)] activity_requests = [DummyRequest('activity', f'act{i}') for i in range(4)] async def run_test(): # Start the worker manager's run loop in the background - worker_task = asyncio.create_task(grpc_worker._async_worker_manager.run()) + worker_task = asyncio.create_task(worker._async_worker_manager.run()) for req in orchestrator_requests: - grpc_worker._async_worker_manager.submit_orchestration(dummy_orchestrator, req, stub, DummyCompletionToken()) + worker._async_worker_manager.submit_orchestration(dummy_orchestrator, req, stub, DummyCompletionToken()) for req in activity_requests: - grpc_worker._async_worker_manager.submit_activity(dummy_activity, req, stub, DummyCompletionToken()) + worker._async_worker_manager.submit_activity(dummy_activity, req, stub, DummyCompletionToken()) await asyncio.sleep(1.0) orchestrator_count = sum(1 for t, _ in stub.completed if t == 'orchestrator') activity_count = sum(1 for t, _ in stub.completed if t == 'activity') assert orchestrator_count == 3, f"Expected 3 orchestrator completions, got {orchestrator_count}" assert activity_count == 4, f"Expected 4 activity completions, got {activity_count}" - grpc_worker._async_worker_manager._shutdown = True + worker._async_worker_manager._shutdown = True await worker_task asyncio.run(run_test()) + # Dummy orchestrator and activity for sync context def dummy_orchestrator(ctx, input): # Simulate some work time.sleep(0.1) return "orchestrator-done" + def dummy_activity(ctx, input): # Simulate some work time.sleep(0.1) return "activity-done" + def test_worker_concurrency_sync(): # Use small concurrency to make test observable options = ConcurrencyOptions( @@ -124,16 +130,11 @@ def run_manager(): t = threading.Thread(target=run_manager) t.start() time.sleep(1.5) # Let work 
process - - # Signal shutdown but don't actually call shutdown() yet - manager._shutdown = True + manager.shutdown() # Unblock the consumers by putting dummy items in the queues manager.activity_queue.put_nowait((lambda: None, (), {})) manager.orchestration_queue.put_nowait((lambda: None, (), {})) t.join(timeout=2) - # Now shutdown the thread pool - manager.thread_pool.shutdown(wait=True) - # Check that all work items completed - assert len(results) == 10 \ No newline at end of file + assert len(results) == 10 diff --git a/tests/durabletask/test_worker_concurrency_loop_async.py b/tests/durabletask/test_worker_concurrency_loop_async.py index 70ff2ad..c7ba238 100644 --- a/tests/durabletask/test_worker_concurrency_loop_async.py +++ b/tests/durabletask/test_worker_concurrency_loop_async.py @@ -1,6 +1,6 @@ import asyncio -from durabletask import worker +from durabletask.worker import ConcurrencyOptions, TaskHubGrpcWorker class DummyStub: @@ -38,12 +38,12 @@ class DummyCompletionToken: def test_worker_concurrency_loop_async(): - options = worker.ConcurrencyOptions( + options = ConcurrencyOptions( maximum_concurrent_activity_work_items=2, maximum_concurrent_orchestration_work_items=1, maximum_thread_pool_workers=2, ) - grpc_worker = worker.TaskHubGrpcWorker(concurrency_options=options) + grpc_worker = TaskHubGrpcWorker(concurrency_options=options) stub = DummyStub() async def dummy_orchestrator(req, stub, completionToken): From fed58ced5439e8bd5955cc4c03bad90d0fc69f85 Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Fri, 30 May 2025 18:09:41 -0700 Subject: [PATCH 15/18] Python 3.9 specific test fix --- durabletask/worker.py | 95 +++++++++++++++++++++++++++++-------------- 1 file changed, 65 insertions(+), 30 deletions(-) diff --git a/durabletask/worker.py b/durabletask/worker.py index ecffd18..b4f974c 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -210,6 +210,7 @@ class TaskHubGrpcWorker: ActivityNotRegisteredError: If an activity work item references 
an unregistered activity function. """ + _response_stream: Optional[grpc.Future] = None _interceptors: Optional[list[shared.ClientInterceptor]] = None @@ -1323,15 +1324,13 @@ def __init__(self, concurrency_options: ConcurrencyOptions): self.concurrency_options = concurrency_options self.activity_semaphore = None self.orchestration_semaphore = None - self.activity_queue: asyncio.Queue = asyncio.Queue() - self.orchestration_queue: asyncio.Queue = asyncio.Queue() + # Don't create queues here - defer until we have an event loop + self.activity_queue: Optional[asyncio.Queue] = None + self.orchestration_queue: Optional[asyncio.Queue] = None self._queue_event_loop: Optional[asyncio.AbstractEventLoop] = None - # Try to capture the current event loop when queues are created - try: - self._queue_event_loop = asyncio.get_running_loop() - except RuntimeError: - # No event loop running when manager was created - pass + # Store work items when no event loop is available + self._pending_activity_work: list = [] + self._pending_orchestration_work: list = [] self.thread_pool = ThreadPoolExecutor( max_workers=concurrency_options.maximum_thread_pool_workers, thread_name_prefix="DurableTask", @@ -1346,26 +1345,30 @@ def _ensure_queues_for_current_loop(self): # No event loop running, can't create queues return - if self._queue_event_loop is current_loop and hasattr(self, 'activity_queue') and hasattr(self, 'orchestration_queue'): - # Queues are already bound to the current loop and exist - return + # Check if queues are already properly set up for current loop + if self._queue_event_loop is current_loop: + if self.activity_queue is not None and self.orchestration_queue is not None: + # Queues are already bound to the current loop and exist + return # Need to recreate queues for the current event loop # First, preserve any existing work items existing_activity_items = [] existing_orchestration_items = [] - if hasattr(self, 'activity_queue'): + if self.activity_queue is not None: try: 
while not self.activity_queue.empty(): existing_activity_items.append(self.activity_queue.get_nowait()) except Exception: pass - if hasattr(self, 'orchestration_queue'): + if self.orchestration_queue is not None: try: while not self.orchestration_queue.empty(): - existing_orchestration_items.append(self.orchestration_queue.get_nowait()) + existing_orchestration_items.append( + self.orchestration_queue.get_nowait() + ) except Exception: pass @@ -1380,6 +1383,16 @@ def _ensure_queues_for_current_loop(self): for item in existing_orchestration_items: self.orchestration_queue.put_nowait(item) + # Move pending work items to the queues + for item in self._pending_activity_work: + self.activity_queue.put_nowait(item) + for item in self._pending_orchestration_work: + self.orchestration_queue.put_nowait(item) + + # Clear the pending work lists + self._pending_activity_work.clear() + self._pending_orchestration_work.clear() + async def run(self): # Reset shutdown flag in case this manager is being reused self._shutdown = False @@ -1396,10 +1409,13 @@ async def run(self): ) # Start background consumers for each work type - await asyncio.gather( - self._consume_queue(self.activity_queue, self.activity_semaphore), - self._consume_queue(self.orchestration_queue, self.orchestration_semaphore), - ) + if self.activity_queue is not None and self.orchestration_queue is not None: + await asyncio.gather( + self._consume_queue(self.activity_queue, self.activity_semaphore), + self._consume_queue( + self.orchestration_queue, self.orchestration_semaphore + ), + ) async def _consume_queue(self, queue: asyncio.Queue, semaphore: asyncio.Semaphore): # List to track running tasks @@ -1421,10 +1437,14 @@ async def _consume_queue(self, queue: asyncio.Queue, semaphore: asyncio.Semaphor func, args, kwargs = work # Create a concurrent task for processing - task = asyncio.create_task(self._process_work_item(semaphore, queue, func, args, kwargs)) + task = asyncio.create_task( + 
self._process_work_item(semaphore, queue, func, args, kwargs) + ) running_tasks.add(task) - async def _process_work_item(self, semaphore: asyncio.Semaphore, queue: asyncio.Queue, func, args, kwargs): + async def _process_work_item( + self, semaphore: asyncio.Semaphore, queue: asyncio.Queue, func, args, kwargs + ): async with semaphore: try: await self._run_func(func, *args, **kwargs) @@ -1437,17 +1457,32 @@ async def _run_func(self, func, *args, **kwargs): else: loop = asyncio.get_running_loop() # Avoid submitting to executor after shutdown - if getattr(self, '_shutdown', False) and getattr(self, 'thread_pool', None) and getattr(self.thread_pool, '_shutdown', False): + if ( + getattr(self, "_shutdown", False) and getattr(self, "thread_pool", None) and getattr( + self.thread_pool, "_shutdown", False) + ): return None - return await loop.run_in_executor(self.thread_pool, lambda: func(*args, **kwargs)) + return await loop.run_in_executor( + self.thread_pool, lambda: func(*args, **kwargs) + ) def submit_activity(self, func, *args, **kwargs): + work_item = (func, args, kwargs) self._ensure_queues_for_current_loop() - self.activity_queue.put_nowait((func, args, kwargs)) + if self.activity_queue is not None: + self.activity_queue.put_nowait(work_item) + else: + # No event loop running, store in pending list + self._pending_activity_work.append(work_item) def submit_orchestration(self, func, *args, **kwargs): + work_item = (func, args, kwargs) self._ensure_queues_for_current_loop() - self.orchestration_queue.put_nowait((func, args, kwargs)) + if self.orchestration_queue is not None: + self.orchestration_queue.put_nowait(work_item) + else: + # No event loop running, store in pending list + self._pending_orchestration_work.append(work_item) def shutdown(self): self._shutdown = True @@ -1457,7 +1492,7 @@ def reset_for_new_run(self): """Reset the manager state for a new run.""" self._shutdown = False # Clear any existing queues - they'll be recreated when needed - if 
hasattr(self, 'activity_queue'): + if self.activity_queue is not None: # Clear existing queue by creating a new one # This ensures no items from previous runs remain try: @@ -1465,16 +1500,16 @@ def reset_for_new_run(self): self.activity_queue.get_nowait() except Exception: pass - if hasattr(self, 'orchestration_queue'): + if self.orchestration_queue is not None: try: while not self.orchestration_queue.empty(): self.orchestration_queue.get_nowait() except Exception: pass + # Clear pending work lists + self._pending_activity_work.clear() + self._pending_orchestration_work.clear() # Export public API -__all__ = [ - 'ConcurrencyOptions', - 'TaskHubGrpcWorker' -] +__all__ = ["ConcurrencyOptions", "TaskHubGrpcWorker"] From 2156df6b58fa8484f7647770d02eb2fc51f43843 Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Fri, 30 May 2025 18:38:04 -0700 Subject: [PATCH 16/18] fixup reconnection for new concurrency model --- durabletask/worker.py | 44 +++++++++++++++++++++++++++++++++++-------- 1 file changed, 36 insertions(+), 8 deletions(-) diff --git a/durabletask/worker.py b/durabletask/worker.py index b4f974c..4f36c5d 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -300,6 +300,7 @@ async def _async_run_loop(self): # Connection state management for retry fix current_channel = None current_stub = None + current_reader_thread = None conn_retry_count = 0 conn_max_retry_delay = 60 @@ -327,7 +328,26 @@ def create_fresh_connection(): raise def invalidate_connection(): - nonlocal current_channel, current_stub + nonlocal current_channel, current_stub, current_reader_thread + # Cancel the response stream first to signal the reader thread to stop + if self._response_stream is not None: + try: + self._response_stream.cancel() + except Exception: + pass + self._response_stream = None + + # Wait for the reader thread to finish + if current_reader_thread is not None: + try: + current_reader_thread.join(timeout=2) + if current_reader_thread.is_alive(): + 
self._logger.warning("Stream reader thread did not shut down gracefully") + except Exception: + pass + current_reader_thread = None + + # Close the channel if current_channel: try: current_channel.close() @@ -389,8 +409,8 @@ def stream_reader(): import threading - reader_thread = threading.Thread(target=stream_reader, daemon=True) - reader_thread.start() + current_reader_thread = threading.Thread(target=stream_reader, daemon=True) + current_reader_thread.start() loop = asyncio.get_running_loop() while not self._shutdown.is_set(): try: @@ -423,21 +443,29 @@ def stream_reader(): ) except Exception as e: self._logger.warning(f"Error in work item stream: {e}") - break - reader_thread.join(timeout=1) + raise e + current_reader_thread.join(timeout=1) self._logger.info("Work item stream ended normally") except grpc.RpcError as rpc_error: should_invalidate = should_invalidate_connection(rpc_error) if should_invalidate: invalidate_connection() error_code = rpc_error.code() # type: ignore + error_details = str(rpc_error) + if error_code == grpc.StatusCode.CANCELLED: self._logger.info(f"Disconnected from {self._host_address}") break elif error_code == grpc.StatusCode.UNAVAILABLE: - self._logger.warning( - f"The sidecar at address {self._host_address} is unavailable - will continue retrying" - ) + # Check if this is a connection timeout scenario + if "Timeout occurred" in error_details or "Failed to connect to remote host" in error_details: + self._logger.warning( + f"Connection timeout to {self._host_address}: {error_details} - will retry with fresh connection" + ) + else: + self._logger.warning( + f"The sidecar at address {self._host_address} is unavailable: {error_details} - will continue retrying" + ) elif should_invalidate: self._logger.warning( f"Connection-level gRPC error ({error_code}): {rpc_error} - resetting connection" From 00f8d1cd01e97975c3b7a7c77a62ce01894b8fe1 Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Fri, 30 May 2025 18:51:13 -0700 Subject: [PATCH 
17/18] autopep8 --- durabletask/worker.py | 138 +++++++++++++++++++++--------------------- 1 file changed, 69 insertions(+), 69 deletions(-) diff --git a/durabletask/worker.py b/durabletask/worker.py index 4f36c5d..9e443e0 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -35,10 +35,10 @@ class ConcurrencyOptions: """ def __init__( - self, - maximum_concurrent_activity_work_items: Optional[int] = None, - maximum_concurrent_orchestration_work_items: Optional[int] = None, - maximum_thread_pool_workers: Optional[int] = None, + self, + maximum_concurrent_activity_work_items: Optional[int] = None, + maximum_concurrent_orchestration_work_items: Optional[int] = None, + maximum_thread_pool_workers: Optional[int] = None, ): """Initialize concurrency options. @@ -167,7 +167,7 @@ class TaskHubGrpcWorker: Example: Basic worker setup: - >>> from durabletask import TaskHubGrpcWorker, ConcurrencyOptions + >>> from durabletask.worker import TaskHubGrpcWorker, ConcurrencyOptions >>> >>> # Create worker with custom concurrency settings >>> concurrency = ConcurrencyOptions( @@ -215,15 +215,15 @@ class TaskHubGrpcWorker: _interceptors: Optional[list[shared.ClientInterceptor]] = None def __init__( - self, - *, - host_address: Optional[str] = None, - metadata: Optional[list[tuple[str, str]]] = None, - log_handler=None, - log_formatter: Optional[logging.Formatter] = None, - secure_channel: bool = False, - interceptors: Optional[Sequence[shared.ClientInterceptor]] = None, - concurrency_options: Optional[ConcurrencyOptions] = None, + self, + *, + host_address: Optional[str] = None, + metadata: Optional[list[tuple[str, str]]] = None, + log_handler=None, + log_formatter: Optional[logging.Formatter] = None, + secure_channel: bool = False, + interceptors: Optional[Sequence[shared.ClientInterceptor]] = None, + concurrency_options: Optional[ConcurrencyOptions] = None, ): self._registry = _Registry() self._host_address = ( @@ -500,10 +500,10 @@ def stop(self): self._is_running = 
False def _execute_orchestrator( - self, - req: pb.OrchestratorRequest, - stub: stubs.TaskHubSidecarServiceStub, - completionToken, + self, + req: pb.OrchestratorRequest, + stub: stubs.TaskHubSidecarServiceStub, + completionToken, ): try: executor = _OrchestrationExecutor(self._registry, self._logger) @@ -538,10 +538,10 @@ def _execute_orchestrator( ) def _execute_activity( - self, - req: pb.ActivityRequest, - stub: stubs.TaskHubSidecarServiceStub, - completionToken, + self, + req: pb.ActivityRequest, + stub: stubs.TaskHubSidecarServiceStub, + completionToken, ): instance_id = req.orchestrationInstance.instanceId try: @@ -626,10 +626,10 @@ def resume(self): self._previous_task = next_task def set_complete( - self, - result: Any, - status: pb.OrchestrationStatus, - is_result_encoded: bool = False, + self, + result: Any, + status: pb.OrchestrationStatus, + is_result_encoded: bool = False, ): if self._is_complete: return @@ -731,9 +731,9 @@ def create_timer(self, fire_at: Union[datetime, timedelta]) -> task.Task: return self.create_timer_internal(fire_at) def create_timer_internal( - self, - fire_at: Union[datetime, timedelta], - retryable_task: Optional[task.RetryableTask] = None, + self, + fire_at: Union[datetime, timedelta], + retryable_task: Optional[task.RetryableTask] = None, ) -> task.Task: id = self.next_sequence_number() if isinstance(fire_at, timedelta): @@ -748,11 +748,11 @@ def create_timer_internal( return timer_task def call_activity( - self, - activity: Union[task.Activity[TInput, TOutput], str], - *, - input: Optional[TInput] = None, - retry_policy: Optional[task.RetryPolicy] = None, + self, + activity: Union[task.Activity[TInput, TOutput], str], + *, + input: Optional[TInput] = None, + retry_policy: Optional[task.RetryPolicy] = None, ) -> task.Task[TOutput]: id = self.next_sequence_number() @@ -762,12 +762,12 @@ def call_activity( return self._pending_tasks.get(id, task.CompletableTask()) def call_sub_orchestrator( - self, - orchestrator: 
task.Orchestrator[TInput, TOutput], - *, - input: Optional[TInput] = None, - instance_id: Optional[str] = None, - retry_policy: Optional[task.RetryPolicy] = None, + self, + orchestrator: task.Orchestrator[TInput, TOutput], + *, + input: Optional[TInput] = None, + instance_id: Optional[str] = None, + retry_policy: Optional[task.RetryPolicy] = None, ) -> task.Task[TOutput]: id = self.next_sequence_number() orchestrator_name = task.get_name(orchestrator) @@ -782,15 +782,15 @@ def call_sub_orchestrator( return self._pending_tasks.get(id, task.CompletableTask()) def call_activity_function_helper( - self, - id: Optional[int], - activity_function: Union[task.Activity[TInput, TOutput], str], - *, - input: Optional[TInput] = None, - retry_policy: Optional[task.RetryPolicy] = None, - is_sub_orch: bool = False, - instance_id: Optional[str] = None, - fn_task: Optional[task.CompletableTask[TOutput]] = None, + self, + id: Optional[int], + activity_function: Union[task.Activity[TInput, TOutput], str], + *, + input: Optional[TInput] = None, + retry_policy: Optional[task.RetryPolicy] = None, + is_sub_orch: bool = False, + instance_id: Optional[str] = None, + fn_task: Optional[task.CompletableTask[TOutput]] = None, ): if id is None: id = self.next_sequence_number() @@ -865,7 +865,7 @@ class ExecutionResults: encoded_custom_status: Optional[str] def __init__( - self, actions: list[pb.OrchestratorAction], encoded_custom_status: Optional[str] + self, actions: list[pb.OrchestratorAction], encoded_custom_status: Optional[str] ): self.actions = actions self.encoded_custom_status = encoded_custom_status @@ -881,10 +881,10 @@ def __init__(self, registry: _Registry, logger: logging.Logger): self._suspended_events: list[pb.HistoryEvent] = [] def execute( - self, - instance_id: str, - old_events: Sequence[pb.HistoryEvent], - new_events: Sequence[pb.HistoryEvent], + self, + instance_id: str, + old_events: Sequence[pb.HistoryEvent], + new_events: Sequence[pb.HistoryEvent], ) -> ExecutionResults: 
if not new_events: raise task.OrchestrationStateError( @@ -922,7 +922,7 @@ def execute( f"{instance_id}: Orchestrator yielded with {task_count} task(s) and {event_count} event(s) outstanding." ) elif ( - ctx._completion_status and ctx._completion_status is not pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW + ctx._completion_status and ctx._completion_status is not pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW ): completion_status_str = pbh.get_orchestration_status_str( ctx._completion_status @@ -941,7 +941,7 @@ def execute( ) def process_event( - self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEvent + self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEvent ) -> None: if self._is_suspended and _is_suspendable(event): # We are suspended, so we need to buffer this event until we are resumed @@ -963,7 +963,7 @@ def process_event( # deserialize the input, if any input = None if ( - event.executionStarted.input is not None and event.executionStarted.input.value != "" + event.executionStarted.input is not None and event.executionStarted.input.value != "" ): input = shared.from_json(event.executionStarted.input.value) @@ -1105,7 +1105,7 @@ def process_event( task_id, expected_method_name, action ) elif ( - action.createSubOrchestration.name != event.subOrchestrationInstanceCreated.name + action.createSubOrchestration.name != event.subOrchestrationInstanceCreated.name ): raise _get_wrong_action_name_error( task_id, @@ -1229,11 +1229,11 @@ def __init__(self, registry: _Registry, logger: logging.Logger): self._logger = logger def execute( - self, - orchestration_id: str, - name: str, - task_id: int, - encoded_input: Optional[str], + self, + orchestration_id: str, + name: str, + task_id: int, + encoded_input: Optional[str], ) -> Optional[str]: """Executes an activity function and returns the serialized result, if any.""" self._logger.debug( @@ -1262,7 +1262,7 @@ def execute( def _get_non_determinism_error( - task_id: int, action_name: str + task_id: int, action_name: 
str ) -> task.NonDeterminismError: return task.NonDeterminismError( f"A previous execution called {action_name} with ID={task_id}, but the current " @@ -1273,7 +1273,7 @@ def _get_non_determinism_error( def _get_wrong_action_type_error( - task_id: int, expected_method_name: str, action: pb.OrchestratorAction + task_id: int, expected_method_name: str, action: pb.OrchestratorAction ) -> task.NonDeterminismError: unexpected_method_name = _get_method_name_for_action(action) return task.NonDeterminismError( @@ -1286,7 +1286,7 @@ def _get_wrong_action_type_error( def _get_wrong_action_name_error( - task_id: int, method_name: str, expected_task_name: str, actual_task_name: str + task_id: int, method_name: str, expected_task_name: str, actual_task_name: str ) -> task.NonDeterminismError: return task.NonDeterminismError( f"Failed to restore orchestration state due to a history mismatch: A previous execution called " @@ -1471,7 +1471,7 @@ async def _consume_queue(self, queue: asyncio.Queue, semaphore: asyncio.Semaphor running_tasks.add(task) async def _process_work_item( - self, semaphore: asyncio.Semaphore, queue: asyncio.Queue, func, args, kwargs + self, semaphore: asyncio.Semaphore, queue: asyncio.Queue, func, args, kwargs ): async with semaphore: try: @@ -1486,8 +1486,8 @@ async def _run_func(self, func, *args, **kwargs): loop = asyncio.get_running_loop() # Avoid submitting to executor after shutdown if ( - getattr(self, "_shutdown", False) and getattr(self, "thread_pool", None) and getattr( - self.thread_pool, "_shutdown", False) + getattr(self, "_shutdown", False) and getattr(self, "thread_pool", None) and getattr( + self.thread_pool, "_shutdown", False) ): return None return await loop.run_in_executor( From 305bbbd44c7b5be64b2394728a6c87fce1add8e1 Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Mon, 2 Jun 2025 10:10:13 -0700 Subject: [PATCH 18/18] Remove existing duplicate import --- durabletask/worker.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 
deletions(-) diff --git a/durabletask/worker.py b/durabletask/worker.py index 9e443e0..b433a83 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -15,7 +15,6 @@ import grpc from google.protobuf import empty_pb2 -import durabletask.internal.helpers as pbh import durabletask.internal.helpers as ph import durabletask.internal.orchestrator_service_pb2 as pb import durabletask.internal.orchestrator_service_pb2_grpc as stubs @@ -511,16 +510,16 @@ def _execute_orchestrator( res = pb.OrchestratorResponse( instanceId=req.instanceId, actions=result.actions, - customStatus=pbh.get_string_value(result.encoded_custom_status), + customStatus=ph.get_string_value(result.encoded_custom_status), completionToken=completionToken, ) except Exception as ex: self._logger.exception( f"An error occurred while trying to execute instance '{req.instanceId}': {ex}" ) - failure_details = pbh.new_failure_details(ex) + failure_details = ph.new_failure_details(ex) actions = [ - pbh.new_complete_orchestration_action( + ph.new_complete_orchestration_action( -1, pb.ORCHESTRATION_STATUS_FAILED, "", failure_details ) ] @@ -552,14 +551,14 @@ def _execute_activity( res = pb.ActivityResponse( instanceId=instance_id, taskId=req.taskId, - result=pbh.get_string_value(result), + result=ph.get_string_value(result), completionToken=completionToken, ) except Exception as ex: res = pb.ActivityResponse( instanceId=instance_id, taskId=req.taskId, - failureDetails=pbh.new_failure_details(ex), + failureDetails=ph.new_failure_details(ex), completionToken=completionToken, ) @@ -924,7 +923,7 @@ def execute( elif ( ctx._completion_status and ctx._completion_status is not pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW ): - completion_status_str = pbh.get_orchestration_status_str( + completion_status_str = ph.get_orchestration_status_str( ctx._completion_status ) self._logger.info(