33
44import concurrent .futures
55import logging
6+ import random
67from datetime import datetime , timedelta
78from threading import Event , Thread
89from types import GeneratorType
1112import grpc
1213from google .protobuf import empty_pb2
1314
14- import durabletask .internal .helpers as ph
1515import durabletask .internal .helpers as pbh
16+ import durabletask .internal .helpers as ph
1617import durabletask .internal .orchestrator_service_pb2 as pb
1718import durabletask .internal .orchestrator_service_pb2_grpc as stubs
1819import durabletask .internal .shared as shared
@@ -91,13 +92,15 @@ def __init__(self, *,
9192 log_handler = None ,
9293 log_formatter : Optional [logging .Formatter ] = None ,
9394 secure_channel : bool = False ,
94- interceptors : Optional [Sequence [shared .ClientInterceptor ]] = None ):
95+ interceptors : Optional [Sequence [shared .ClientInterceptor ]] = None ,
96+ max_workers : Optional [int ] = None ):
9597 self ._registry = _Registry ()
9698 self ._host_address = host_address if host_address else shared .get_default_host_address ()
9799 self ._logger = shared .get_logger ("worker" , log_handler , log_formatter )
98100 self ._shutdown = Event ()
99101 self ._is_running = False
100102 self ._secure_channel = secure_channel
103+ self ._max_workers = max_workers if max_workers is not None else 16
101104
102105 # Determine the interceptors to use
103106 if interceptors is not None :
@@ -129,31 +132,117 @@ def add_activity(self, fn: task.Activity) -> str:
129132
130133 def start (self ):
131134 """Starts the worker on a background thread and begins listening for work items."""
132- channel = shared .get_grpc_channel (self ._host_address , self ._secure_channel , self ._interceptors )
133- stub = stubs .TaskHubSidecarServiceStub (channel )
134-
135135 if self ._is_running :
136136 raise RuntimeError ('The worker is already running.' )
137137
138138 def run_loop ():
139+ """Enhanced run loop with better connection management and retry logic."""
140+
141+ # Connection state management for retry fix
142+ current_channel : Optional [grpc .Channel ] = None
143+ current_stub : Optional [stubs .TaskHubSidecarServiceStub ] = None
144+ conn_retry_count = 0
145+ conn_max_retry_delay = 60
146+
147+ def create_fresh_connection () -> None :
148+ """Create a new gRPC channel and stub, invalidating any existing ones.
149+
150+ Raises:
151+ Exception: If connection creation or testing fails.
152+ """
153+ nonlocal current_channel , current_stub , conn_retry_count
154+
155+ # Close existing connection if any
156+ if current_channel :
157+ try :
158+ current_channel .close ()
159+ except Exception :
160+ pass
161+
162+ current_channel = None
163+ current_stub = None
164+
165+ try :
166+ # Create new connection
167+ current_channel = shared .get_grpc_channel (self ._host_address , self ._secure_channel , self ._interceptors )
168+ current_stub = stubs .TaskHubSidecarServiceStub (current_channel )
169+
170+ # Test the connection
171+ current_stub .Hello (empty_pb2 .Empty ())
172+ conn_retry_count = 0 # Reset on successful connection
173+ self ._logger .debug (f"Created fresh connection to { self ._host_address } " )
174+
175+ except Exception as e :
176+ self ._logger .debug (f"Failed to create connection: { e } " )
177+ current_channel = None
178+ current_stub = None
179+ raise # Re-raise the original exception
180+
181+ def invalidate_connection () -> None :
182+ """Mark current connection as invalid."""
183+ nonlocal current_channel , current_stub
184+ if current_channel :
185+ try :
186+ current_channel .close ()
187+ except Exception :
188+ pass
189+ current_channel = None
190+ current_stub = None
191+
192+ def should_invalidate_connection (rpc_error : grpc .RpcError ) -> bool :
193+ """Determine if a gRPC error should trigger connection invalidation.
194+
195+ Connection-level errors (network, authentication, server unavailable)
196+ should invalidate the connection, while application-level errors
197+ (bad requests, not found, etc.) should not.
198+ """
199+ error_code = rpc_error .code () # type: ignore
200+
201+ # Connection-level errors that warrant invalidation
202+ connection_level_errors = {
203+ grpc .StatusCode .UNAVAILABLE , # Server down/unreachable
204+ grpc .StatusCode .DEADLINE_EXCEEDED , # Timeout, likely network issue
205+ grpc .StatusCode .CANCELLED , # Connection cancelled
206+ grpc .StatusCode .UNAUTHENTICATED , # Auth failed, may need new connection
207+ grpc .StatusCode .ABORTED , # Transaction aborted, connection may be bad
208+ }
209+
210+ return error_code in connection_level_errors
211+
139212 # TODO: Investigate whether asyncio could be used to enable greater concurrency for async activity
140213 # functions. We'd need to know ahead of time whether a function is async or not.
141- # TODO: Max concurrency configuration settings
142- with concurrent .futures .ThreadPoolExecutor (max_workers = 16 ) as executor :
214+ with concurrent .futures .ThreadPoolExecutor (max_workers = self ._max_workers , thread_name_prefix = "DurableTask" ) as executor :
143215 while not self ._shutdown .is_set ():
144- try :
145- # send a "Hello" message to the sidecar to ensure that it's listening
146- stub .Hello (empty_pb2 .Empty ())
216+ # Ensure we have a valid connection before attempting work
217+ if current_stub is None :
218+ try :
219+ create_fresh_connection ()
220+ except Exception :
221+ # Connection failed, implement exponential backoff
222+ conn_retry_count += 1
223+ delay = min (conn_max_retry_delay , (2 ** min (conn_retry_count , 6 )) + random .uniform (0 , 1 ))
224+ self ._logger .warning (f'Connection failed, retrying in { delay :.2f} seconds (attempt { conn_retry_count } )' )
225+ if self ._shutdown .wait (delay ):
226+ break # Shutdown requested during wait
227+ continue
147228
148- # stream work items
229+ try :
230+ # Stream work items with the current connection
231+ # Type assertion since we know current_stub is not None at this point
232+ assert current_stub is not None , "current_stub should not be None at this point"
233+ stub = current_stub # Local reference for type safety
149234 self ._response_stream = stub .GetWorkItems (pb .GetWorkItemsRequest ())
150235 self ._logger .info (f'Successfully connected to { self ._host_address } . Waiting for work items...' )
151236
152- # The stream blocks until either a work item is received or the stream is canceled
153- # by another thread (see the stop() method).
237+ # Process work items concurrently as they arrive
154238 for work_item in self ._response_stream : # type: ignore
239+ if self ._shutdown .is_set ():
240+ break
241+
155242 request_type = work_item .WhichOneof ('request' )
156243 self ._logger .debug (f'Received "{ request_type } " work item' )
244+
245+ # Submit work items to thread pool for concurrent processing
157246 if work_item .HasField ('orchestratorRequest' ):
158247 executor .submit (self ._execute_orchestrator , work_item .orchestratorRequest , stub , work_item .completionToken )
159248 elif work_item .HasField ('activityRequest' ):
@@ -163,21 +252,39 @@ def run_loop():
163252 else :
164253 self ._logger .warning (f'Unexpected work item type: { request_type } ' )
165254
255+ # Stream ended normally (shouldn't happen unless server closes)
256+ self ._logger .info ("Work item stream ended normally" )
257+
166258 except grpc .RpcError as rpc_error :
167- if rpc_error .code () == grpc .StatusCode .CANCELLED : # type: ignore
259+ # Intelligently decide whether to invalidate connection based on error type
260+ should_invalidate = should_invalidate_connection (rpc_error )
261+ if should_invalidate :
262+ invalidate_connection ()
263+
264+ error_code = rpc_error .code () # type: ignore
265+ if error_code == grpc .StatusCode .CANCELLED :
168266 self ._logger .info (f'Disconnected from { self ._host_address } ' )
169- elif rpc_error .code () == grpc .StatusCode .UNAVAILABLE : # type: ignore
170- self ._logger .warning (
171- f'The sidecar at address { self ._host_address } is unavailable - will continue retrying' )
267+ break # Likely shutdown
268+ elif error_code == grpc .StatusCode .UNAVAILABLE :
269+ self ._logger .warning (f'The sidecar at address { self ._host_address } is unavailable - will continue retrying' )
270+ elif should_invalidate :
271+ self ._logger .warning (f'Connection-level gRPC error ({ error_code } ): { rpc_error } - invalidating connection' )
172272 else :
173- self ._logger .warning (f'Unexpected error: { rpc_error } ' )
273+ self ._logger .warning (f'Application-level gRPC error ({ error_code } ): { rpc_error } - keeping connection' )
274+
275+ # Brief pause before retry
276+ self ._shutdown .wait (1 )
277+
174278 except Exception as ex :
279+ # Unexpected error, invalidate connection and retry
280+ invalidate_connection ()
175281 self ._logger .warning (f'Unexpected error: { ex } ' )
282+ self ._shutdown .wait (1 )
176283
177- # CONSIDER: exponential backoff
178- self . _shutdown . wait ( 5 )
179- self . _logger . info ( "No longer listening for work items" )
180- return
284+ # Final cleanup
285+ invalidate_connection ( )
286+
287+ self . _logger . info ( "No longer listening for work items" )
181288
182289 self ._logger .info (f"Starting gRPC worker that connects to { self ._host_address } " )
183290 self ._runLoop = Thread (target = run_loop )
@@ -367,14 +474,14 @@ def instance_id(self) -> str:
367474 def current_utc_datetime (self ) -> datetime :
368475 return self ._current_utc_datetime
369476
370- @property
371- def is_replaying (self ) -> bool :
372- return self ._is_replaying
373-
374477 @current_utc_datetime .setter
375478 def current_utc_datetime (self , value : datetime ):
376479 self ._current_utc_datetime = value
377480
481+ @property
482+ def is_replaying (self ) -> bool :
483+ return self ._is_replaying
484+
378485 def set_custom_status (self , custom_status : Any ) -> None :
379486 self ._encoded_custom_status = shared .to_json (custom_status ) if custom_status is not None else None
380487
@@ -389,7 +496,7 @@ def create_timer_internal(self, fire_at: Union[datetime, timedelta],
389496 action = ph .new_create_timer_action (id , fire_at )
390497 self ._pending_actions [id ] = action
391498
392- timer_task = task .TimerTask ()
499+ timer_task : task . TimerTask = task .TimerTask ()
393500 if retryable_task is not None :
394501 timer_task .set_retryable_parent (retryable_task )
395502 self ._pending_tasks [id ] = timer_task
@@ -457,7 +564,7 @@ def wait_for_external_event(self, name: str) -> task.Task:
457564 # event with the given name so that we can resume the generator when it
458565 # arrives. If there are multiple events with the same name, we return
459566 # them in the order they were received.
460- external_event_task = task .CompletableTask ()
567+ external_event_task : task . CompletableTask = task .CompletableTask ()
461568 event_name = name .casefold ()
462569 event_list = self ._received_events .get (event_name , None )
463570 if event_list :
0 commit comments