Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,23 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## v0.2.0 (Unreleased)
## v0.3.0

### New

- Added configurable worker concurrency via a new `max_workers` parameter in the `TaskHubGrpcWorker` constructor, which sets the `ThreadPoolExecutor` size (default: 16 workers)

### Fixed

- Fixed an issue where a worker could not recover after its connection was interrupted or severed

## v0.2.1

### New

- Added `set_custom_status` orchestrator API ([#31](https://github.com/microsoft/durabletask-python/pull/31)) - contributed by [@famarting](https://github.com/famarting)
- Added `purge_orchestration` client API ([#34](https://github.com/microsoft/durabletask-python/pull/34)) - contributed by [@famarting](https://github.com/famarting)
- Added new `durabletask-azuremanaged` package for use with the [Durable Task Scheduler](https://techcommunity.microsoft.com/blog/appsonazureblog/announcing-limited-early-access-of-the-durable-task-scheduler-for-azure-durable-/4286526) - by [@RyanLettieri](https://github.com/RyanLettieri)
- Added new `durabletask-azuremanaged` package for use with the [Durable Task Scheduler](https://learn.microsoft.com/azure/azure-functions/durable/durable-task-scheduler/durable-task-scheduler) - by [@RyanLettieri](https://github.com/RyanLettieri)

### Changes

Expand Down
6 changes: 3 additions & 3 deletions durabletask-azuremanaged/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ build-backend = "setuptools.build_meta"

[project]
name = "durabletask.azuremanaged"
version = "0.1.4"
description = "Extensions for the Durable Task Python SDK for integrating with the Durable Task Scheduler in Azure"
version = "0.1.5"
description = "Durable Task Python SDK provider implementation for the Azure Durable Task Scheduler"
keywords = [
"durable",
"task",
Expand All @@ -26,7 +26,7 @@ requires-python = ">=3.9"
license = {file = "LICENSE"}
readme = "README.md"
dependencies = [
"durabletask>=0.2.1",
"durabletask>=0.3.0",
"azure-identity>=1.19.0"
]

Expand Down
163 changes: 135 additions & 28 deletions durabletask/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import concurrent.futures
import logging
import random
from datetime import datetime, timedelta
from threading import Event, Thread
from types import GeneratorType
Expand All @@ -11,8 +12,8 @@
import grpc
from google.protobuf import empty_pb2

import durabletask.internal.helpers as ph
import durabletask.internal.helpers as pbh
import durabletask.internal.helpers as ph
import durabletask.internal.orchestrator_service_pb2 as pb
import durabletask.internal.orchestrator_service_pb2_grpc as stubs
import durabletask.internal.shared as shared
Expand Down Expand Up @@ -91,13 +92,15 @@ def __init__(self, *,
log_handler=None,
log_formatter: Optional[logging.Formatter] = None,
secure_channel: bool = False,
interceptors: Optional[Sequence[shared.ClientInterceptor]] = None):
interceptors: Optional[Sequence[shared.ClientInterceptor]] = None,
max_workers: Optional[int] = None):
self._registry = _Registry()
self._host_address = host_address if host_address else shared.get_default_host_address()
self._logger = shared.get_logger("worker", log_handler, log_formatter)
self._shutdown = Event()
self._is_running = False
self._secure_channel = secure_channel
self._max_workers = max_workers if max_workers is not None else 16

# Determine the interceptors to use
if interceptors is not None:
Expand Down Expand Up @@ -129,31 +132,117 @@ def add_activity(self, fn: task.Activity) -> str:

def start(self):
"""Starts the worker on a background thread and begins listening for work items."""
channel = shared.get_grpc_channel(self._host_address, self._secure_channel, self._interceptors)
stub = stubs.TaskHubSidecarServiceStub(channel)

if self._is_running:
raise RuntimeError('The worker is already running.')

def run_loop():
"""Enhanced run loop with better connection management and retry logic."""

# Connection state management for retry fix
current_channel: Optional[grpc.Channel] = None
current_stub: Optional[stubs.TaskHubSidecarServiceStub] = None
conn_retry_count = 0
conn_max_retry_delay = 60

def create_fresh_connection() -> None:
"""Create a new gRPC channel and stub, invalidating any existing ones.

Raises:
Exception: If connection creation or testing fails.
"""
nonlocal current_channel, current_stub, conn_retry_count

# Close existing connection if any
if current_channel:
try:
current_channel.close()
except Exception:
pass

current_channel = None
current_stub = None

try:
# Create new connection
current_channel = shared.get_grpc_channel(self._host_address, self._secure_channel, self._interceptors)
current_stub = stubs.TaskHubSidecarServiceStub(current_channel)

# Test the connection
current_stub.Hello(empty_pb2.Empty())
conn_retry_count = 0 # Reset on successful connection
self._logger.debug(f"Created fresh connection to {self._host_address}")

except Exception as e:
self._logger.debug(f"Failed to create connection: {e}")
current_channel = None
current_stub = None
raise # Re-raise the original exception

def invalidate_connection() -> None:
"""Mark current connection as invalid."""
nonlocal current_channel, current_stub
if current_channel:
try:
current_channel.close()
except Exception:
pass
current_channel = None
current_stub = None

def should_invalidate_connection(rpc_error: grpc.RpcError) -> bool:
    """Report whether a gRPC error should cause the connection to be dropped.

    Returns True when the error's status code is one of the codes treated
    here as connection-level, and False for every other status code.
    """
    status = rpc_error.code()  # type: ignore
    return status in {
        grpc.StatusCode.UNAVAILABLE,        # Server down/unreachable
        grpc.StatusCode.DEADLINE_EXCEEDED,  # Timeout, likely network issue
        grpc.StatusCode.CANCELLED,          # Connection cancelled
        grpc.StatusCode.UNAUTHENTICATED,    # Auth failed, may need new connection
        grpc.StatusCode.ABORTED,            # Transaction aborted, connection may be bad
    }

# TODO: Investigate whether asyncio could be used to enable greater concurrency for async activity
# functions. We'd need to know ahead of time whether a function is async or not.
# TODO: Max concurrency configuration settings
with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
with concurrent.futures.ThreadPoolExecutor(max_workers=self._max_workers, thread_name_prefix="DurableTask") as executor:
while not self._shutdown.is_set():
try:
# send a "Hello" message to the sidecar to ensure that it's listening
stub.Hello(empty_pb2.Empty())
# Ensure we have a valid connection before attempting work
if current_stub is None:
try:
create_fresh_connection()
except Exception:
# Connection failed, implement exponential backoff
conn_retry_count += 1
delay = min(conn_max_retry_delay, (2 ** min(conn_retry_count, 6)) + random.uniform(0, 1))
self._logger.warning(f'Connection failed, retrying in {delay:.2f} seconds (attempt {conn_retry_count})')
if self._shutdown.wait(delay):
break # Shutdown requested during wait
continue

# stream work items
try:
# Stream work items with the current connection
# Type assertion since we know current_stub is not None at this point
assert current_stub is not None, "current_stub should not be None at this point"
stub = current_stub # Local reference for type safety
self._response_stream = stub.GetWorkItems(pb.GetWorkItemsRequest())
self._logger.info(f'Successfully connected to {self._host_address}. Waiting for work items...')

# The stream blocks until either a work item is received or the stream is canceled
# by another thread (see the stop() method).
# Process work items concurrently as they arrive
for work_item in self._response_stream: # type: ignore
if self._shutdown.is_set():
break

request_type = work_item.WhichOneof('request')
self._logger.debug(f'Received "{request_type}" work item')

# Submit work items to thread pool for concurrent processing
if work_item.HasField('orchestratorRequest'):
executor.submit(self._execute_orchestrator, work_item.orchestratorRequest, stub, work_item.completionToken)
elif work_item.HasField('activityRequest'):
Expand All @@ -163,21 +252,39 @@ def run_loop():
else:
self._logger.warning(f'Unexpected work item type: {request_type}')

# Stream ended normally (shouldn't happen unless server closes)
self._logger.info("Work item stream ended normally")

except grpc.RpcError as rpc_error:
if rpc_error.code() == grpc.StatusCode.CANCELLED: # type: ignore
# Intelligently decide whether to invalidate connection based on error type
should_invalidate = should_invalidate_connection(rpc_error)
if should_invalidate:
invalidate_connection()

error_code = rpc_error.code() # type: ignore
if error_code == grpc.StatusCode.CANCELLED:
self._logger.info(f'Disconnected from {self._host_address}')
elif rpc_error.code() == grpc.StatusCode.UNAVAILABLE: # type: ignore
self._logger.warning(
f'The sidecar at address {self._host_address} is unavailable - will continue retrying')
break # Likely shutdown
elif error_code == grpc.StatusCode.UNAVAILABLE:
self._logger.warning(f'The sidecar at address {self._host_address} is unavailable - will continue retrying')
elif should_invalidate:
self._logger.warning(f'Connection-level gRPC error ({error_code}): {rpc_error} - invalidating connection')
else:
self._logger.warning(f'Unexpected error: {rpc_error}')
self._logger.warning(f'Application-level gRPC error ({error_code}): {rpc_error} - keeping connection')

# Brief pause before retry
self._shutdown.wait(1)

except Exception as ex:
# Unexpected error, invalidate connection and retry
invalidate_connection()
self._logger.warning(f'Unexpected error: {ex}')
self._shutdown.wait(1)

# CONSIDER: exponential backoff
self._shutdown.wait(5)
self._logger.info("No longer listening for work items")
return
# Final cleanup
invalidate_connection()

self._logger.info("No longer listening for work items")

self._logger.info(f"Starting gRPC worker that connects to {self._host_address}")
self._runLoop = Thread(target=run_loop)
Expand Down Expand Up @@ -367,14 +474,14 @@ def instance_id(self) -> str:
def current_utc_datetime(self) -> datetime:
return self._current_utc_datetime

@property
def is_replaying(self) -> bool:
return self._is_replaying

@current_utc_datetime.setter
def current_utc_datetime(self, value: datetime):
self._current_utc_datetime = value

@property
def is_replaying(self) -> bool:
return self._is_replaying

def set_custom_status(self, custom_status: Any) -> None:
self._encoded_custom_status = shared.to_json(custom_status) if custom_status is not None else None

Expand All @@ -389,7 +496,7 @@ def create_timer_internal(self, fire_at: Union[datetime, timedelta],
action = ph.new_create_timer_action(id, fire_at)
self._pending_actions[id] = action

timer_task = task.TimerTask()
timer_task: task.TimerTask = task.TimerTask()
if retryable_task is not None:
timer_task.set_retryable_parent(retryable_task)
self._pending_tasks[id] = timer_task
Expand Down Expand Up @@ -457,7 +564,7 @@ def wait_for_external_event(self, name: str) -> task.Task:
# event with the given name so that we can resume the generator when it
# arrives. If there are multiple events with the same name, we return
# them in the order they were received.
external_event_task = task.CompletableTask()
external_event_task: task.CompletableTask = task.CompletableTask()
event_name = name.casefold()
event_list = self._received_events.get(event_name, None)
if event_list:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "durabletask"
version = "0.2.1"
version = "0.3.0"
description = "A Durable Task Client SDK for Python"
keywords = [
"durable",
Expand Down
Loading