@@ -197,6 +197,8 @@ def __init__(self, queue_size: int = DEFAULT_QUEUE_SIZE) -> None:
197
197
# Event loop needs to remain in the same process
198
198
self ._task_for_pid : Optional [int ] = None
199
199
self ._loop : Optional [asyncio .AbstractEventLoop ] = None
200
+ # Track active callback tasks so they have a strong reference and can be cancelled on kill
201
+ self ._active_tasks : set [asyncio .Task ] = set ()
200
202
201
203
@property
202
204
def is_alive (self ) -> bool :
@@ -211,6 +213,12 @@ def kill(self) -> None:
211
213
self ._task .cancel ()
212
214
self ._task = None
213
215
self ._task_for_pid = None
216
+ # Also cancel any active callback tasks
217
+ # Avoid modifying the set while cancelling tasks
218
+ tasks_to_cancel = set (self ._active_tasks )
219
+ for task in tasks_to_cancel :
220
+ task .cancel ()
221
+ self ._active_tasks .clear ()
214
222
self ._loop = None
215
223
216
224
def start (self ) -> None :
@@ -272,16 +280,30 @@ def submit(self, callback: Callable[[], Any]) -> bool:
272
280
async def _target (self ) -> None :
273
281
while True :
274
282
callback = await self ._queue .get ()
275
- try :
276
- if inspect .iscoroutinefunction (callback ):
277
- # Callback is an async coroutine, need to await it
278
- await callback ()
279
- else :
280
- # Callback is a sync function, need to call it
281
- callback ()
282
- except Exception :
283
- logger .error ("Failed processing job" , exc_info = True )
284
- finally :
285
- self ._queue .task_done ()
283
+ # Firing tasks instead of awaiting them allows for concurrent requests
284
+ task = asyncio .create_task (self ._process_callback (callback ))
285
+ # Create a strong reference to the task so it can be cancelled on kill
286
+ # and does not get garbage collected while running
287
+ self ._active_tasks .add (task )
288
+ task .add_done_callback (self ._on_task_complete )
286
289
# Yield to let the event loop run other tasks
287
290
await asyncio .sleep (0 )
291
+
292
+ async def _process_callback (self , callback : Callable [[], Any ]) -> None :
293
+ if inspect .iscoroutinefunction (callback ):
294
+ # Callback is an async coroutine, need to await it
295
+ await callback ()
296
+ else :
297
+ # Callback is a sync function, need to call it
298
+ callback ()
299
+
300
+ def _on_task_complete (self , task : asyncio .Task [None ]) -> None :
301
+ try :
302
+ task .result ()
303
+ except Exception :
304
+ logger .error ("Failed processing job" , exc_info = True )
305
+ finally :
306
+ # Mark the task as done and remove it from the active tasks set
307
+ # This happens only after the task has completed
308
+ self ._queue .task_done ()
309
+ self ._active_tasks .discard (task )
0 commit comments