Skip to content

Commit 03d27bd

Browse files
committed
Reconnect upon connection error
1 parent 1459a55 commit 03d27bd

File tree

4 files changed

+151
-34
lines changed

4 files changed

+151
-34
lines changed

CHANGELOG.md

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,23 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8-
## v0.2.0 (Unreleased)
8+
## v0.3.0
9+
10+
### New
11+
12+
- Added configurable worker concurrency with `max_workers` parameter in `TaskHubGrpcWorker` constructor - allows customization of ThreadPoolExecutor size (default: 16 workers)
13+
14+
### Fixed
15+
16+
- Fixed an issue where a worker could not recover after its connection was interrupted or severed
17+
18+
## v0.2.1
919

1020
### New
1121

1222
- Added `set_custom_status` orchestrator API ([#31](https://github.com/microsoft/durabletask-python/pull/31)) - contributed by [@famarting](https://github.com/famarting)
1323
- Added `purge_orchestration` client API ([#34](https://github.com/microsoft/durabletask-python/pull/34)) - contributed by [@famarting](https://github.com/famarting)
14-
- Added new `durabletask-azuremanaged` package for use with the [Durable Task Scheduler](https://techcommunity.microsoft.com/blog/appsonazureblog/announcing-limited-early-access-of-the-durable-task-scheduler-for-azure-durable-/4286526) - by [@RyanLettieri](https://github.com/RyanLettieri)
24+
- Added new `durabletask-azuremanaged` package for use with the [Durable Task Scheduler](https://learn.microsoft.com/azure/azure-functions/durable/durable-task-scheduler/durable-task-scheduler) - by [@RyanLettieri](https://github.com/RyanLettieri)
1525

1626
### Changes
1727

durabletask-azuremanaged/pyproject.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ build-backend = "setuptools.build_meta"
99

1010
[project]
1111
name = "durabletask.azuremanaged"
12-
version = "0.1.4"
13-
description = "Extensions for the Durable Task Python SDK for integrating with the Durable Task Scheduler in Azure"
12+
version = "0.1.5"
13+
description = "Durable Task Python SDK provider implementation for the Azure Durable Task Scheduler"
1414
keywords = [
1515
"durable",
1616
"task",
@@ -26,7 +26,7 @@ requires-python = ">=3.9"
2626
license = {file = "LICENSE"}
2727
readme = "README.md"
2828
dependencies = [
29-
"durabletask>=0.2.1",
29+
"durabletask>=0.3.0",
3030
"azure-identity>=1.19.0"
3131
]
3232

durabletask/worker.py

Lines changed: 135 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import concurrent.futures
55
import logging
6+
import random
67
from datetime import datetime, timedelta
78
from threading import Event, Thread
89
from types import GeneratorType
@@ -11,8 +12,8 @@
1112
import grpc
1213
from google.protobuf import empty_pb2
1314

14-
import durabletask.internal.helpers as ph
1515
import durabletask.internal.helpers as pbh
16+
import durabletask.internal.helpers as ph
1617
import durabletask.internal.orchestrator_service_pb2 as pb
1718
import durabletask.internal.orchestrator_service_pb2_grpc as stubs
1819
import durabletask.internal.shared as shared
@@ -91,13 +92,15 @@ def __init__(self, *,
9192
log_handler=None,
9293
log_formatter: Optional[logging.Formatter] = None,
9394
secure_channel: bool = False,
94-
interceptors: Optional[Sequence[shared.ClientInterceptor]] = None):
95+
interceptors: Optional[Sequence[shared.ClientInterceptor]] = None,
96+
max_workers: Optional[int] = None):
9597
self._registry = _Registry()
9698
self._host_address = host_address if host_address else shared.get_default_host_address()
9799
self._logger = shared.get_logger("worker", log_handler, log_formatter)
98100
self._shutdown = Event()
99101
self._is_running = False
100102
self._secure_channel = secure_channel
103+
self._max_workers = max_workers if max_workers is not None else 16
101104

102105
# Determine the interceptors to use
103106
if interceptors is not None:
@@ -129,31 +132,117 @@ def add_activity(self, fn: task.Activity) -> str:
129132

130133
def start(self):
131134
"""Starts the worker on a background thread and begins listening for work items."""
132-
channel = shared.get_grpc_channel(self._host_address, self._secure_channel, self._interceptors)
133-
stub = stubs.TaskHubSidecarServiceStub(channel)
134-
135135
if self._is_running:
136136
raise RuntimeError('The worker is already running.')
137137

138138
def run_loop():
139+
"""Enhanced run loop with better connection management and retry logic."""
140+
141+
# Connection state management for retry fix
142+
current_channel: Optional[grpc.Channel] = None
143+
current_stub: Optional[stubs.TaskHubSidecarServiceStub] = None
144+
conn_retry_count = 0
145+
conn_max_retry_delay = 60
146+
147+
def create_fresh_connection() -> None:
148+
"""Create a new gRPC channel and stub, invalidating any existing ones.
149+
150+
Raises:
151+
Exception: If connection creation or testing fails.
152+
"""
153+
nonlocal current_channel, current_stub, conn_retry_count
154+
155+
# Close existing connection if any
156+
if current_channel:
157+
try:
158+
current_channel.close()
159+
except Exception:
160+
pass
161+
162+
current_channel = None
163+
current_stub = None
164+
165+
try:
166+
# Create new connection
167+
current_channel = shared.get_grpc_channel(self._host_address, self._secure_channel, self._interceptors)
168+
current_stub = stubs.TaskHubSidecarServiceStub(current_channel)
169+
170+
# Test the connection
171+
current_stub.Hello(empty_pb2.Empty())
172+
conn_retry_count = 0 # Reset on successful connection
173+
self._logger.debug(f"Created fresh connection to {self._host_address}")
174+
175+
except Exception as e:
176+
self._logger.debug(f"Failed to create connection: {e}")
177+
current_channel = None
178+
current_stub = None
179+
raise # Re-raise the original exception
180+
181+
def invalidate_connection() -> None:
182+
"""Mark current connection as invalid."""
183+
nonlocal current_channel, current_stub
184+
if current_channel:
185+
try:
186+
current_channel.close()
187+
except Exception:
188+
pass
189+
current_channel = None
190+
current_stub = None
191+
192+
def should_invalidate_connection(rpc_error: grpc.RpcError) -> bool:
193+
"""Determine if a gRPC error should trigger connection invalidation.
194+
195+
Connection-level errors (network, authentication, server unavailable)
196+
should invalidate the connection, while application-level errors
197+
(bad requests, not found, etc.) should not.
198+
"""
199+
error_code = rpc_error.code() # type: ignore
200+
201+
# Connection-level errors that warrant invalidation
202+
connection_level_errors = {
203+
grpc.StatusCode.UNAVAILABLE, # Server down/unreachable
204+
grpc.StatusCode.DEADLINE_EXCEEDED, # Timeout, likely network issue
205+
grpc.StatusCode.CANCELLED, # Connection cancelled
206+
grpc.StatusCode.UNAUTHENTICATED, # Auth failed, may need new connection
207+
grpc.StatusCode.ABORTED, # Transaction aborted, connection may be bad
208+
}
209+
210+
return error_code in connection_level_errors
211+
139212
# TODO: Investigate whether asyncio could be used to enable greater concurrency for async activity
140213
# functions. We'd need to know ahead of time whether a function is async or not.
141-
# TODO: Max concurrency configuration settings
142-
with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
214+
with concurrent.futures.ThreadPoolExecutor(max_workers=self._max_workers, thread_name_prefix="DurableTask") as executor:
143215
while not self._shutdown.is_set():
144-
try:
145-
# send a "Hello" message to the sidecar to ensure that it's listening
146-
stub.Hello(empty_pb2.Empty())
216+
# Ensure we have a valid connection before attempting work
217+
if current_stub is None:
218+
try:
219+
create_fresh_connection()
220+
except Exception:
221+
# Connection failed, implement exponential backoff
222+
conn_retry_count += 1
223+
delay = min(conn_max_retry_delay, (2 ** min(conn_retry_count, 6)) + random.uniform(0, 1))
224+
self._logger.warning(f'Connection failed, retrying in {delay:.2f} seconds (attempt {conn_retry_count})')
225+
if self._shutdown.wait(delay):
226+
break # Shutdown requested during wait
227+
continue
147228

148-
# stream work items
229+
try:
230+
# Stream work items with the current connection
231+
# Type assertion since we know current_stub is not None at this point
232+
assert current_stub is not None, "current_stub should not be None at this point"
233+
stub = current_stub # Local reference for type safety
149234
self._response_stream = stub.GetWorkItems(pb.GetWorkItemsRequest())
150235
self._logger.info(f'Successfully connected to {self._host_address}. Waiting for work items...')
151236

152-
# The stream blocks until either a work item is received or the stream is canceled
153-
# by another thread (see the stop() method).
237+
# Process work items concurrently as they arrive
154238
for work_item in self._response_stream: # type: ignore
239+
if self._shutdown.is_set():
240+
break
241+
155242
request_type = work_item.WhichOneof('request')
156243
self._logger.debug(f'Received "{request_type}" work item')
244+
245+
# Submit work items to thread pool for concurrent processing
157246
if work_item.HasField('orchestratorRequest'):
158247
executor.submit(self._execute_orchestrator, work_item.orchestratorRequest, stub, work_item.completionToken)
159248
elif work_item.HasField('activityRequest'):
@@ -163,21 +252,39 @@ def run_loop():
163252
else:
164253
self._logger.warning(f'Unexpected work item type: {request_type}')
165254

255+
# Stream ended normally (shouldn't happen unless server closes)
256+
self._logger.info("Work item stream ended normally")
257+
166258
except grpc.RpcError as rpc_error:
167-
if rpc_error.code() == grpc.StatusCode.CANCELLED: # type: ignore
259+
# Intelligently decide whether to invalidate connection based on error type
260+
should_invalidate = should_invalidate_connection(rpc_error)
261+
if should_invalidate:
262+
invalidate_connection()
263+
264+
error_code = rpc_error.code() # type: ignore
265+
if error_code == grpc.StatusCode.CANCELLED:
168266
self._logger.info(f'Disconnected from {self._host_address}')
169-
elif rpc_error.code() == grpc.StatusCode.UNAVAILABLE: # type: ignore
170-
self._logger.warning(
171-
f'The sidecar at address {self._host_address} is unavailable - will continue retrying')
267+
break # Likely shutdown
268+
elif error_code == grpc.StatusCode.UNAVAILABLE:
269+
self._logger.warning(f'The sidecar at address {self._host_address} is unavailable - will continue retrying')
270+
elif should_invalidate:
271+
self._logger.warning(f'Connection-level gRPC error ({error_code}): {rpc_error} - invalidating connection')
172272
else:
173-
self._logger.warning(f'Unexpected error: {rpc_error}')
273+
self._logger.warning(f'Application-level gRPC error ({error_code}): {rpc_error} - keeping connection')
274+
275+
# Brief pause before retry
276+
self._shutdown.wait(1)
277+
174278
except Exception as ex:
279+
# Unexpected error, invalidate connection and retry
280+
invalidate_connection()
175281
self._logger.warning(f'Unexpected error: {ex}')
282+
self._shutdown.wait(1)
176283

177-
# CONSIDER: exponential backoff
178-
self._shutdown.wait(5)
179-
self._logger.info("No longer listening for work items")
180-
return
284+
# Final cleanup
285+
invalidate_connection()
286+
287+
self._logger.info("No longer listening for work items")
181288

182289
self._logger.info(f"Starting gRPC worker that connects to {self._host_address}")
183290
self._runLoop = Thread(target=run_loop)
@@ -367,14 +474,14 @@ def instance_id(self) -> str:
367474
def current_utc_datetime(self) -> datetime:
368475
return self._current_utc_datetime
369476

370-
@property
371-
def is_replaying(self) -> bool:
372-
return self._is_replaying
373-
374477
@current_utc_datetime.setter
375478
def current_utc_datetime(self, value: datetime):
376479
self._current_utc_datetime = value
377480

481+
@property
482+
def is_replaying(self) -> bool:
483+
return self._is_replaying
484+
378485
def set_custom_status(self, custom_status: Any) -> None:
379486
self._encoded_custom_status = shared.to_json(custom_status) if custom_status is not None else None
380487

@@ -389,7 +496,7 @@ def create_timer_internal(self, fire_at: Union[datetime, timedelta],
389496
action = ph.new_create_timer_action(id, fire_at)
390497
self._pending_actions[id] = action
391498

392-
timer_task = task.TimerTask()
499+
timer_task: task.TimerTask = task.TimerTask()
393500
if retryable_task is not None:
394501
timer_task.set_retryable_parent(retryable_task)
395502
self._pending_tasks[id] = timer_task
@@ -457,7 +564,7 @@ def wait_for_external_event(self, name: str) -> task.Task:
457564
# event with the given name so that we can resume the generator when it
458565
# arrives. If there are multiple events with the same name, we return
459566
# them in the order they were received.
460-
external_event_task = task.CompletableTask()
567+
external_event_task: task.CompletableTask = task.CompletableTask()
461568
event_name = name.casefold()
462569
event_list = self._received_events.get(event_name, None)
463570
if event_list:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta"
99

1010
[project]
1111
name = "durabletask"
12-
version = "0.2.1"
12+
version = "0.3.0"
1313
description = "A Durable Task Client SDK for Python"
1414
keywords = [
1515
"durable",

0 commit comments

Comments
 (0)