1010
1111import asyncio
1212import queue
13- import random
1413import threading
1514import time
1615from collections .abc import Awaitable , Callable
2423
2524# Polling configuration
2625DEFAULT_POLL_WAIT_TIME = 0.05 # 50ms
27- DEFAULT_POLL_SLEEP_TIME = 0.5 # 1 second
26+ DEFAULT_POLL_SLEEP_TIME = 0.5 # 500ms
2827
2928
3029class TaskQueueFullError (RuntimeError ):
3130 """Raised when an AsyncTaskRunner queue is full."""
3231
3332
3433@dataclass
35- class _TimedResult (Generic [T ]):
36- """Internal wrapper for results with creation timestamp."""
34+ class TimedResult (Generic [T ]):
35+ """Wrapper for task results with creation timestamp.
3736
38- create_time : int # nanoseconds from time.monotonic_ns()
37+ Attributes
38+ ----------
39+ create_time : int
40+ Task creation time in nanoseconds from time.monotonic_ns().
41+ data : T
42+ The actual result data from the completed task.
43+ """
44+
45+ create_time : int
3946 data : T
4047
4148
@@ -72,13 +79,13 @@ class AsyncTaskRunner(Generic[T]):
7279 ----------
7380 max_queue_size : int
7481 Maximum size for input and output queues. Tasks submitted when
75- the input queue is full will raise RuntimeError .
82+ the input queue is full will raise TaskQueueFullError .
7683 poll_wait_time : float, optional
7784 Time in seconds to wait for task completion during each poll
7885 cycle. Default is 0.05 (50ms).
7986 poll_sleep_time : float, optional
8087 Time in seconds to sleep between poll cycles.
81- Default is 1.0 second .
88+ Default is 0.5 seconds .
8289 enable_tracing : bool, optional
8390 Enable detailed logging of task submission and completion.
8491 Default is False.
@@ -163,7 +170,7 @@ def __init__(
163170 Default is 0.05.
164171 poll_sleep_time : float, optional
165172 Time in seconds to sleep between poll cycles.
166- Default is 1.0 .
173+ Default is 0.5 .
167174 enable_tracing : bool, optional
168175 Enable detailed logging. Default is False.
169176 """
@@ -180,13 +187,10 @@ def __init__(
180187 self .input_queue : queue .Queue [_TaskInput [T ]] = queue .Queue (
181188 maxsize = max_queue_size
182189 )
183- self .output_queue : queue .Queue [_TimedResult [T ]] = queue .Queue (
190+ self .output_queue : queue .Queue [TimedResult [T ]] = queue .Queue (
184191 maxsize = max_queue_size
185192 )
186193
187- # Cache for results to support wait() with arbitrary counts
188- self .result_cache : list [_TimedResult [T ]] = []
189-
190194 # Thread exception handling
191195 self ._thread_exception_lock = threading .Lock ()
192196 self ._thread_exception : Exception | None = None
@@ -335,7 +339,7 @@ async def _run_async_loop(self):
335339 try :
336340 # Place result in output queue
337341 self .output_queue .put_nowait (
338- _TimedResult (create_time = task_obj .create_time , data = result )
342+ TimedResult (create_time = task_obj .create_time , data = result )
339343 )
340344 if self .enable_tracing and self .logger :
341345 self .logger .info (
@@ -355,6 +359,7 @@ async def _run_async_loop(self):
355359 raise TaskQueueFullError (
356360 "Output queue full. Please increase max_queue_size."
357361 )
362+ # Sleep to avoid busy-waiting
358363 await asyncio .sleep (self .poll_sleep_time )
359364 finally :
360365 # Cancel all remaining tasks on shutdown
@@ -390,9 +395,10 @@ def submit(
390395
391396 Raises
392397 ------
398+ TaskQueueFullError
399+ If the input queue is full.
393400 RuntimeError
394- If the input queue is full or if the background thread
395- has died.
401+ If the background thread has died.
396402
397403 Examples
398404 --------
@@ -417,12 +423,13 @@ def submit(
417423 "wait for tasks to complete."
418424 )
419425
420- def wait (self , count : int , timeout : float | None = None ) -> list [T ]:
426+ def wait (
427+ self , count : int , timeout : float | None = None , with_timing : bool = False
428+ ) -> list [TimedResult [T ]] | list [T ]:
421429 """Wait for a specified number of task results.
422430
423431 This method blocks until at least `count` results are available
424- or the timeout expires. Results are returned in random order
425- (shuffled).
432+ or the timeout expires.
426433
427434 Parameters
428435 ----------
@@ -431,11 +438,15 @@ def wait(self, count: int, timeout: float | None = None) -> list[T]:
431438 timeout : float | None, optional
432439 Maximum time in seconds to wait. If None, waits indefinitely
433440 (up to 7 days). Default is None.
441+ with_timing : bool, optional
442+ If True, return TimedResult objects with creation timestamps.
443+ If False, return only the data values. Default is False.
434444
435445 Returns
436446 -------
437- List[T]
438- List of task results, shuffled randomly.
447+ list[TimedResult[T]] | list[T]
448+ If with_timing=True, returns list of TimedResult objects.
449+ If with_timing=False, returns list of result data.
439450
440451 Raises
441452 ------
@@ -460,16 +471,7 @@ def wait(self, count: int, timeout: float | None = None) -> list[T]:
460471 # Check thread health
461472 self ._check_thread_health ()
462473
463- # Drain all available results from output queue
464- while True :
465- try :
466- timed_result = self .output_queue .get_nowait ()
467- self .result_cache .append (timed_result )
468- except queue .Empty :
469- break
470-
471- # Check if we have enough results
472- if len (self .result_cache ) >= count :
474+ if self .get_output_queue_size () >= count :
473475 break
474476
475477 # Sleep briefly to avoid busy waiting
@@ -480,23 +482,16 @@ def wait(self, count: int, timeout: float | None = None) -> list[T]:
480482 self ._check_thread_health ()
481483 raise RuntimeError ("AsyncTaskRunner is exiting, cannot wait for results." )
482484
483- accepted = len ( self .result_cache )
485+ accepted = self .get_output_queue_size ( )
484486 if accepted < count :
485487 raise TimeoutError (
486488 f"Timed out waiting for { count } results, only received { accepted } ."
487489 )
488490
489- # Sort by creation time for deterministic ordering
490- self .result_cache .sort (key = lambda x : x .create_time )
491-
492- # Extract the requested number of results
493- results_to_return = self .result_cache [:count ]
494- self .result_cache = self .result_cache [count :]
495-
496- # Shuffle for randomness (helps with data diversity in ML)
497- random .shuffle (results_to_return )
498-
499- # Extract just the data (remove timing metadata)
491+ # Extract the requested number of results, sorted by return time
492+ results_to_return = [self .output_queue .get () for _ in range (count )]
493+ if with_timing :
494+ return results_to_return
500495 return [r .data for r in results_to_return ]
501496
502497 def submit_batch (
0 commit comments