1212from enum import Enum
1313from typing import Any
1414
15+ from ..observability .metrics import (
16+ record_task_completion ,
17+ record_task_enqueued ,
18+ record_task_latency ,
19+ )
1520from ..settings import settings
1621from ..utils .datetime import utc_now
17- from .redis import enqueue_task as redis_enqueue_task
18- from .redis import get_redis_client
22+ from .redis import (
23+ clear_dead_letter ,
24+ enqueue_task as redis_enqueue_task ,
25+ fetch_dead_letters ,
26+ get_redis_client ,
27+ list_heartbeats ,
28+ pop_task ,
29+ record_dead_letter ,
30+ set_task_result ,
31+ write_heartbeat ,
32+ )
1933
2034logger = logging .getLogger (__name__ )
2135
@@ -61,6 +75,7 @@ def __init__(self, max_workers: int = 5):
6175 self .workers : list [asyncio .Task [None ]] = []
6276 self .running = False
6377 self ._lock = asyncio .Lock ()
78+ self .queue_name = "memory"
6479
6580 async def start (self ):
6681 """Start background workers."""
@@ -102,6 +117,7 @@ async def enqueue(
102117 async with self ._lock :
103118 self .queue .append (task )
104119 self .tasks [task_id ] = task
120+ record_task_enqueued (self .queue_name )
105121 logger .info ("Task enqueued: %s (id=%s)" , name , task_id )
106122 return task_id
107123
@@ -153,6 +169,8 @@ async def _execute_task(self, task: Task):
153169 task .status = TaskStatus .COMPLETED
154170 task .result = result
155171 task .completed_at = utc_now ()
172+ record_task_completion (self .queue_name , task .status .value )
173+ record_task_latency (self .queue_name , (task .completed_at - task .started_at ).total_seconds ())
156174 logger .info ("Task completed: %s (id=%s)" , task .name , task .task_id )
157175 except Exception as exc :
158176 task .error = str (exc )
@@ -175,8 +193,19 @@ async def _execute_task(self, task: Task):
175193 else :
176194 task .status = TaskStatus .FAILED
177195 task .completed_at = utc_now ()
196+ record_task_completion (self .queue_name , task .status .value )
197+ record_task_latency (self .queue_name , (task .completed_at - task .started_at ).total_seconds ())
178198 logger .error ("Task permanently failed: %s (id=%s)" , task .name , task .task_id )
179199
async def list_dead_letters(self, limit: int = 100) -> list[dict[str, Any]]:
    """Return the dead-lettered task payloads known to this queue.

    This implementation always returns an empty list; ``limit`` is accepted
    for signature compatibility and is not consulted.
    """
    no_entries: list[dict[str, Any]] = []
    return no_entries
202+
async def delete_dead_letter(self, task_id: str) -> None:
    """Accept a dead-letter task id and do nothing.

    This implementation performs no work and always returns ``None``.
    """
    return
205+
async def worker_heartbeats(self) -> dict[str, Any]:
    """Return worker heartbeat records.

    This implementation always returns an empty mapping.
    """
    no_heartbeats: dict[str, Any] = {}
    return no_heartbeats
208+
180209
181210# Global task queue instance
182211_task_queue : TaskQueue | None = None
@@ -191,6 +220,20 @@ async def get_task_queue() -> TaskQueue:
191220 client = await get_redis_client ()
192221
193222 class RedisTaskQueue (TaskQueue ):
def __init__(self, max_workers: int = 5):
    """Initialise the parent queue, then attach Redis state and tunables.

    Args:
        max_workers: Forwarded unchanged to the parent initializer.
    """
    super().__init__(max_workers=max_workers)

    # Label handed to the metrics helpers by this queue's methods.
    self.queue_name = "redis"
    # ``client`` comes from the enclosing scope, not from a parameter.
    self._redis = client
    # Task name -> coroutine function, filled in by ``register``.
    self._registry: dict[str, Callable[..., Coroutine[Any, Any, Any]]] = {}

    # Snapshot the tunables once so later reads do not go back to settings.
    cfg = settings
    self._poll_interval = cfg.task_queue_poll_interval
    self._task_timeout = cfg.task_queue_task_timeout
    self._backoff_base = cfg.task_queue_retry_backoff_base
    self._backoff_max = cfg.task_queue_retry_backoff_max
    self._heartbeat_ttl = cfg.task_queue_worker_heartbeat_ttl
233+
def register(self, name: str, coro_fn: Callable[..., Coroutine[Any, Any, Any]]) -> None:
    """Store ``coro_fn`` in the registry under ``name``.

    An existing entry with the same name is replaced.
    """
    registry = self._registry
    registry[name] = coro_fn
236+
194237 async def enqueue ( # type: ignore[override]
195238 self ,
196239 name : str ,
@@ -199,14 +242,101 @@ async def enqueue( # type: ignore[override]
199242 max_retries : int = 3 ,
200243 ** kwargs : Any ,
201244 ) -> str :
245+ self .register (name , coro_fn )
202246 payload = {
203247 "task_id" : uuid .uuid4 ().hex ,
204248 "name" : name ,
205249 "args" : args ,
206250 "kwargs" : kwargs ,
207251 "max_retries" : max_retries ,
252+ "enqueued_at" : utc_now ().isoformat (),
208253 }
209- return await redis_enqueue_task (client , name , payload )
254+ await redis_enqueue_task (self ._redis , name , payload )
255+ record_task_enqueued (self .queue_name )
256+ logger .info ("Redis task enqueued: %s (id=%s)" , name , payload ["task_id" ])
257+ return payload ["task_id" ]
258+
async def pop(self) -> dict[str, Any] | None:
    """Fetch the next task payload via ``pop_task``.

    Returns:
        The payload, or ``None`` when ``pop_task`` yields a falsy value.
    """
    # Any falsy result is normalised to None.
    return (await pop_task(self._redis)) or None
264+
async def _worker(self, worker_id: int) -> None:
    """Poll for task payloads and execute them while ``self.running`` is true.

    Each iteration pops one payload and hands it to ``_run_payload``. When
    nothing is available the worker sleeps for the poll interval.

    Unexpected errors raised by the queue plumbing (pop, dead-letter write,
    result write, heartbeat write, re-enqueue) are logged and the loop keeps
    going. Previously such an error escaped the loop and ended this worker
    for good. Cancellation is unaffected: ``asyncio.CancelledError`` is not
    an ``Exception`` subclass, so it still propagates.

    Args:
        worker_id: Numeric id used in log lines, dead-letter payloads and
            the heartbeat key.
    """
    logger.info("Redis worker %d started", worker_id)
    while self.running:
        try:
            payload = await self.pop()
            if not payload:
                await asyncio.sleep(self._poll_interval)
                continue
            await self._run_payload(worker_id, payload)
        except Exception:  # pylint: disable=broad-except
            logger.exception(
                "Redis worker %d hit an unexpected error; continuing", worker_id
            )
            # Back off briefly so a persistent outage does not spin the loop.
            await asyncio.sleep(self._poll_interval)

    logger.info("Redis worker %d stopped", worker_id)

async def _run_payload(self, worker_id: int, payload: dict[str, Any]) -> None:
    """Execute one popped payload and record its outcome.

    Outcomes:
        * no registered callable: dead-lettered with ``error`` set to
          ``"missing_callable"``;
        * timeout: dead-lettered immediately;
        * other exception: re-enqueued after a backoff sleep, or
          dead-lettered once ``retry_count`` reaches ``max_retries``;
        * success: result stored, metrics recorded, heartbeat written.
    """
    task_id = payload.get("task_id", "unknown")
    name = payload.get("name")
    coro_fn = self._registry.get(name)
    if coro_fn is None:
        logger.error("No registered task callable for %s", name)
        await record_dead_letter(
            self._redis,
            {**payload, "error": "missing_callable", "worker_id": worker_id},
        )
        record_task_completion(self.queue_name, TaskStatus.FAILED.value)
        return

    retries = payload.get("retry_count", 0)
    max_retries = payload.get("max_retries", 3)

    start = utc_now()
    try:
        result = await asyncio.wait_for(
            coro_fn(*payload.get("args", ()), **payload.get("kwargs", {})),
            timeout=self._task_timeout,
        )
    except asyncio.TimeoutError:
        # NOTE(review): timeouts are dead-lettered without consulting
        # max_retries, unlike other failures -- confirm this is intended.
        payload["retry_count"] = retries + 1
        payload["last_error"] = "timeout"
        logger.warning("Redis task timed out: %s (id=%s)", name, task_id)
        await self._dead_letter_payload(worker_id, payload, start)
        return
    except Exception as exc:  # pylint: disable=broad-except
        retries += 1
        payload["retry_count"] = retries
        payload["last_error"] = str(exc)
        if retries >= max_retries:
            logger.error(
                "Redis task permanently failed: %s (id=%s): %s", name, task_id, exc
            )
            await self._dead_letter_payload(worker_id, payload, start)
            return

        backoff = min(self._backoff_base ** retries, self._backoff_max)
        logger.warning(
            "Redis task failed, retrying in %ss: %s (id=%s, attempt=%d)",
            backoff,
            name,
            task_id,
            retries,
        )
        # NOTE(review): this sleep holds the worker for the whole backoff,
        # so it processes nothing else in the meantime.
        await asyncio.sleep(backoff)
        await redis_enqueue_task(self._redis, name, payload)
        return

    await set_task_result(self._redis, task_id, {"status": "completed", "result": result})
    record_task_completion(self.queue_name, TaskStatus.COMPLETED.value)
    record_task_latency(self.queue_name, (utc_now() - start).total_seconds())

    # The heartbeat is written only after a successful task.
    heartbeat_payload = {
        "worker_id": worker_id,
        "task_id": task_id,
        "name": name,
        "completed_at": utc_now().isoformat(),
    }
    await write_heartbeat(
        self._redis, f"worker:{worker_id}", heartbeat_payload, self._heartbeat_ttl
    )

async def _dead_letter_payload(self, worker_id: int, payload: dict[str, Any], start: Any) -> None:
    """Dead-letter ``payload`` and record failure and latency metrics.

    ``start`` is the ``utc_now()`` value captured just before the task ran;
    latency is measured from it to now.
    """
    payload["worker_id"] = worker_id
    await record_dead_letter(self._redis, payload)
    record_task_completion(self.queue_name, TaskStatus.FAILED.value)
    record_task_latency(self.queue_name, (utc_now() - start).total_seconds())
331+
async def list_dead_letters(self, limit: int = 100) -> list[dict[str, Any]]:
    """Return dead-letter entries obtained from ``fetch_dead_letters``.

    Args:
        limit: Passed through to ``fetch_dead_letters`` unchanged.
    """
    entries = await fetch_dead_letters(self._redis, limit)
    return entries
334+
async def delete_dead_letter(self, task_id: str) -> None:
    """Delegate removal of the entry for ``task_id`` to ``clear_dead_letter``.

    The helper's return value is discarded; this method returns ``None``.
    """
    redis_conn = self._redis
    await clear_dead_letter(redis_conn, task_id)
337+
async def worker_heartbeats(self) -> dict[str, dict[str, Any]]:
    """Return the heartbeat records reported by ``list_heartbeats``."""
    heartbeats = await list_heartbeats(self._redis)
    return heartbeats
210340
211341 _task_queue = RedisTaskQueue (max_workers = settings .task_queue_workers )
212342 else :
0 commit comments