@@ -13,6 +13,7 @@
 from enum import Enum
 from typing import Any
 
+from ..clients import slack_client
 from ..observability.metrics import (
     dead_letter_active_gauge,
     dead_letter_alert_total,
@@ -25,15 +26,12 @@
     record_task_latency,
 )
 from ..settings import settings
-from ..utils.datetime import utc_now
-from ..clients import pagerduty_client, slack_client
 from ..tasks.playbooks import run_playbook
+from ..utils.datetime import utc_now
 from .redis import (
-    append_dead_letter_audit,
     clear_dead_letter,
     count_dead_letters,
     delete_retry_policy,
-    enqueue_task as redis_enqueue_task,
     fetch_dead_letter_audit,
     fetch_dead_letters,
     get_dead_letter,
@@ -48,6 +46,9 @@
     set_task_result,
     write_heartbeat,
 )
+from .redis import (
+    enqueue_task as redis_enqueue_task,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -332,12 +333,22 @@ async def _worker(self, worker_id: int):
         alert_channel = settings.task_queue_alert_channel or settings.slack_status_channel
         cooldown = timedelta(minutes=settings.task_queue_alert_cooldown_minutes)
 
-        def _should_auto_requeue(err_type: str | None) -> bool:
+        def _should_auto_requeue(
+            err_type: str | None,
+            *,
+            auto_errors: set[str] = auto_errors,
+        ) -> bool:
             if not err_type:
                 return False
             return err_type in auto_errors
 
-        def _record_failure(err_type: str | None, task_identifier: str) -> bool:
+        def _record_failure(
+            err_type: str | None,
+            task_identifier: str,
+            *,
+            alert_window: timedelta = alert_window,
+            alert_threshold: int = alert_threshold,
+        ) -> bool:
             if not err_type:
                 return False
             now = utc_now()
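The two rewritten helpers above use the same trick: values that used to be captured from the enclosing _worker scope (auto_errors, alert_window, alert_threshold) become keyword-only parameters whose defaults are evaluated once, at definition time, so each helper holds a snapshot instead of re-resolving the name on every call. A minimal sketch of the difference, with hypothetical names, not from this codebase:

# A closure reads its free variables when it is called, so a later rebind
# in the enclosing scope silently changes its behavior; a keyword-only
# default freezes the value at definition time instead.
threshold = 3

def check_late(n: int) -> bool:
    return n >= threshold  # looked up at call time

def check_bound(n: int, *, threshold: int = threshold) -> bool:
    return n >= threshold  # snapshot taken at definition time

threshold = 100
print(check_late(5))   # False - sees the rebound 100
print(check_bound(5))  # True - still compares against the original 3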
@@ -348,13 +359,17 @@ def _record_failure(err_type: str | None, task_identifier: str) -> bool:
             recent_failures[key] = [ts for ts in entries if ts >= cutoff]
             return len(recent_failures[key]) >= alert_threshold
 
-        async def _send_alert(error_type: str, payload: dict[str, Any]) -> None:
+        async def _send_alert(
+            error_type: str,
+            payload: dict[str, Any],
+            *,
+            alert_channel: str | None = alert_channel,
+            cooldown: timedelta = cooldown,
+        ) -> None:
             if not alert_channel:
                 return
             if settings.dry_run or not slack_client.enabled:
-                logger.warning(
-                    "Slack alert skipped (dry run): %s", error_type
-                )
+                logger.warning("Slack alert skipped (dry run): %s", error_type)
                 return
             now = utc_now()
             last_sent = last_alert_sent.get(error_type)
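The tail of _record_failure shown above is a sliding window: append a timestamp per failure, prune entries older than alert_window, and report True once alert_threshold survivors remain, at which point _send_alert fires subject to its cooldown. A self-contained sketch of that window check; the key shape, window, and threshold values here are placeholders, not the real settings:

from datetime import datetime, timedelta, timezone

recent_failures: dict[str, list[datetime]] = {}

def record_failure(key: str, *, window: timedelta = timedelta(minutes=10),
                   threshold: int = 3) -> bool:
    # Record this failure, drop anything outside the window, and signal
    # an alert only when enough recent failures remain.
    now = datetime.now(timezone.utc)
    entries = recent_failures.setdefault(key, [])
    entries.append(now)
    cutoff = now - window
    recent_failures[key] = [ts for ts in entries if ts >= cutoff]
    return len(recent_failures[key]) >= threshold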
@@ -373,19 +388,24 @@ async def _send_alert(error_type: str, payload: dict[str, Any]) -> None:
             except Exception as exc:  # pragma: no cover - logging
                 logger.error("Failed to send Slack alert: %s", exc)
 
-        async def _apply_adaptive_policy(task_name: str) -> None:
+        async def _apply_adaptive_policy(
+            task_name: str,
+            task_payload: dict[str, Any],
+        ) -> None:
             samples = failure_metrics.get(task_name, [])
             if len(samples) < settings.task_queue_adaptive_min_samples:
                 return
             failure_rate = 1 - (sum(1 for success in samples if success) / len(samples))
             if failure_rate < settings.task_queue_adaptive_failure_threshold:
                 return
             policy = await get_retry_policy(self._redis, task_name) or {}
-            policy.setdefault("max_retries", payload.get("max_retries", 3))
+            policy.setdefault("max_retries", task_payload.get("max_retries", 3))
             policy.setdefault("timeout", self._task_timeout)
             policy.setdefault("backoff_base", self._backoff_base)
             policy.setdefault("backoff_max", self._backoff_max)
-            policy["max_retries"] = min(int(policy["max_retries"]) + 1, settings.task_queue_max_auto_requeues)
+            policy["max_retries"] = min(
+                int(policy["max_retries"]) + 1, settings.task_queue_max_auto_requeues
+            )
             policy["timeout"] = float(policy.get("timeout", self._task_timeout)) + 5.0
             await set_retry_policy(self._redis, task_name, policy)
             failure_metrics[task_name] = []
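_apply_adaptive_policy now takes the payload explicitly as task_payload rather than reading the loop's payload variable through the closure, making the data flow explicit. The escalation math itself is small; a worked example with made-up numbers (the sample history, min-samples gate, failure threshold, and retry cap are all assumptions, the real values come from settings):

samples = [True, False, False, True, False]  # one bool per finished task
failure_rate = 1 - sum(1 for ok in samples if ok) / len(samples)  # 0.6

if len(samples) >= 5 and failure_rate >= 0.5:  # assumed min_samples / threshold
    policy = {"max_retries": 3, "timeout": 30.0}
    policy["max_retries"] = min(int(policy["max_retries"]) + 1, 5)  # assumed cap of 5
    policy["timeout"] = float(policy["timeout"]) + 5.0  # grant 5 extra seconds
    # -> {"max_retries": 4, "timeout": 35.0}; the sample list is then reset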
@@ -462,7 +482,9 @@ async def _apply_adaptive_policy(task_name: str) -> None:
                 ).inc()
                 record_task_completion(self.queue_name, TaskStatus.FAILED.value)
                 record_task_latency(self.queue_name, (utc_now() - start).total_seconds())
-                identifier = payload.get("metadata", {}).get("workflow_id") or payload.get("name", "unknown")
+                identifier = payload.get("metadata", {}).get("workflow_id") or payload.get(
+                    "name", "unknown"
+                )
                 playbook_name = settings.task_queue_playbooks.get(error_type)
                 if playbook_name:
                     await run_playbook(playbook_name, payload, self, error_type)
@@ -480,7 +502,7 @@ async def _apply_adaptive_policy(task_name: str) -> None:
                     payload = auto_payload
                 metrics = failure_metrics.setdefault(name, [])
                 metrics.append(False)
-                await _apply_adaptive_policy(name)
+                await _apply_adaptive_policy(name, payload)
                 if _record_failure(error_type, identifier):
                     await _send_alert(error_type, payload)
                 continue
@@ -492,13 +514,13 @@ async def _apply_adaptive_policy(task_name: str) -> None:
                 await asyncio.sleep(backoff)
                 metrics = failure_metrics.setdefault(name, [])
                 metrics.append(False)
-                await _apply_adaptive_policy(name)
+                await _apply_adaptive_policy(name, payload)
                 await redis_enqueue_task(self._redis, name, payload)
                 continue
 
             metrics = failure_metrics.setdefault(name, [])
             metrics.append(True)
-            await _apply_adaptive_policy(name)
+            await _apply_adaptive_policy(name, payload)
             await set_task_result(self._redis, task_id, {"status": "completed", "result": result})
             record_task_completion(self.queue_name, TaskStatus.COMPLETED.value)
             record_task_latency(self.queue_name, (utc_now() - start).total_seconds())
@@ -588,7 +610,9 @@ async def purge_dead_letters(self, *, older_than: timedelta | None = None) -> int:
         if older_than is None:
             deleted = await purge_dead_letters(self._redis)
             dead_letter_purged_total.labels(queue=self.queue_name, mode="all").inc(deleted)
-            dead_letter_active_gauge.labels(queue=self.queue_name).set(await count_dead_letters(self._redis))
+            dead_letter_active_gauge.labels(queue=self.queue_name).set(
+                await count_dead_letters(self._redis)
+            )
             return deleted
         cutoff = utc_now() - older_than
         deleted = await purge_dead_letters(self._redis, older_than=cutoff)
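A hypothetical call site for the purge path above, assuming queue is a constructed instance of this task-queue class (the class name is outside this excerpt): pass older_than to trim by age, or omit it to clear everything and refresh the active-count gauge as shown.

from datetime import timedelta

async def weekly_cleanup(queue) -> int:
    # Assumed maintenance job: drop dead letters older than a week.
    return await queue.purge_dead_letters(older_than=timedelta(days=7))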