Skip to content

Commit 15d59cc

Browse files
committed
feat(queue): add adaptive retry tuning
1 parent 7836b8e commit 15d59cc

File tree

3 files changed

+73
-0
lines changed

3 files changed

+73
-0
lines changed

agent_pm/settings.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ def _parse_google_scopes(cls, value):
8686
pagerduty_routing_key: str | None = Field(None, alias="PAGERDUTY_ROUTING_KEY")
8787
task_queue_playbooks: dict[str, str] = Field(default_factory=dict, alias="TASK_QUEUE_PLAYBOOKS")
8888
pagerduty_service_name: str | None = Field(None, alias="PAGERDUTY_SERVICE_NAME")
89+
task_queue_adaptive_failure_threshold: float = Field(0.6, alias="TASK_QUEUE_ADAPTIVE_FAILURE_THRESHOLD")
90+
task_queue_adaptive_min_samples: int = Field(10, alias="TASK_QUEUE_ADAPTIVE_MIN_SAMPLES")
8991
database_url: str | None = Field("sqlite+aiosqlite:///./data/agent_pm.db", alias="DATABASE_URL")
9092
database_echo: bool = Field(False, alias="DATABASE_ECHO")
9193
redis_url: str = Field("redis://localhost:6379", alias="REDIS_URL")

agent_pm/storage/tasks.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,7 @@ async def _worker(self, worker_id: int):
323323
recent_failures: dict[str, list[datetime]] = {}
324324
auto_requeue_counts: dict[str, int] = {}
325325
last_alert_sent: dict[str, datetime] = {}
326+
failure_metrics: dict[str, list[bool]] = {}
326327

327328
while self.running:
328329
auto_errors = set(settings.task_queue_auto_requeue_errors)
@@ -372,6 +373,23 @@ async def _send_alert(error_type: str, payload: dict[str, Any]) -> None:
372373
except Exception as exc: # pragma: no cover - logging
373374
logger.error("Failed to send Slack alert: %s", exc)
374375

376+
async def _apply_adaptive_policy(task_name: str) -> None:
377+
samples = failure_metrics.get(task_name, [])
378+
if len(samples) < settings.task_queue_adaptive_min_samples:
379+
return
380+
failure_rate = 1 - (sum(1 for success in samples if success) / len(samples))
381+
if failure_rate < settings.task_queue_adaptive_failure_threshold:
382+
return
383+
policy = await get_retry_policy(self._redis, task_name) or {}
384+
policy.setdefault("max_retries", payload.get("max_retries", 3))
385+
policy.setdefault("timeout", self._task_timeout)
386+
policy.setdefault("backoff_base", self._backoff_base)
387+
policy.setdefault("backoff_max", self._backoff_max)
388+
policy["max_retries"] = min(int(policy["max_retries"]) + 1, settings.task_queue_max_auto_requeues)
389+
policy["timeout"] = float(policy.get("timeout", self._task_timeout)) + 5.0
390+
await set_retry_policy(self._redis, task_name, policy)
391+
failure_metrics[task_name] = []
392+
375393
payload = await self.pop()
376394
if not payload:
377395
await asyncio.sleep(self._poll_interval)
@@ -460,6 +478,9 @@ async def _send_alert(error_type: str, payload: dict[str, Any]) -> None:
460478
auto_requeue_counts[key] = count + 1
461479
if auto_payload:
462480
payload = auto_payload
481+
metrics = failure_metrics.setdefault(name, [])
482+
metrics.append(False)
483+
await _apply_adaptive_policy(name)
463484
if _record_failure(error_type, identifier):
464485
await _send_alert(error_type, payload)
465486
continue
@@ -469,9 +490,15 @@ async def _send_alert(error_type: str, payload: dict[str, Any]) -> None:
469490
float(policy.get("backoff_max", self._backoff_max)),
470491
)
471492
await asyncio.sleep(backoff)
493+
metrics = failure_metrics.setdefault(name, [])
494+
metrics.append(False)
495+
await _apply_adaptive_policy(name)
472496
await redis_enqueue_task(self._redis, name, payload)
473497
continue
474498

499+
metrics = failure_metrics.setdefault(name, [])
500+
metrics.append(True)
501+
await _apply_adaptive_policy(name)
475502
await set_task_result(self._redis, task_id, {"status": "completed", "result": result})
476503
record_task_completion(self.queue_name, TaskStatus.COMPLETED.value)
477504
record_task_latency(self.queue_name, (utc_now() - start).total_seconds())

tests/storage/test_redis_worker.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
from typing import Any
23
from collections import deque
34

45
import pytest
@@ -204,3 +205,46 @@ async def flaky_task() -> str:
204205
assert auto_metric == 1
205206
assert alert_metric == 1
206207
assert calls
208+
209+
210+
@pytest.mark.asyncio
211+
async def test_adaptive_retry_policy_updates(redis_queue, monkeypatch):
212+
queue, fake = redis_queue
213+
214+
monkeypatch.setattr(tasks_module.settings, "task_queue_adaptive_min_samples", 2)
215+
monkeypatch.setattr(tasks_module.settings, "task_queue_adaptive_failure_threshold", 0.5)
216+
monkeypatch.setattr(tasks_module.settings, "task_queue_max_auto_requeues", 10)
217+
218+
policies: dict[str, dict[str, Any]] = {}
219+
220+
async def fake_get_policy(client, task_name: str):
221+
return policies.get(task_name)
222+
223+
async def fake_set_policy(client, task_name: str, policy: dict[str, Any]):
224+
policies[task_name] = policy
225+
226+
monkeypatch.setattr(tasks_module, "get_retry_policy", fake_get_policy)
227+
monkeypatch.setattr(tasks_module, "set_retry_policy", fake_set_policy)
228+
229+
attempts = {"count": 0}
230+
231+
async def unreliable_task() -> str:
232+
attempts["count"] += 1
233+
if attempts["count"] <= 2:
234+
raise RuntimeError("flaky")
235+
return "ok"
236+
237+
task_id = await queue.enqueue("unstable", unreliable_task, max_retries=5)
238+
239+
result = None
240+
for _ in range(100):
241+
result = await redis_helpers.get_task_result(fake, task_id)
242+
if result:
243+
break
244+
await asyncio.sleep(0.05)
245+
246+
assert result is not None
247+
assert result["status"] == "completed"
248+
assert "unstable" in policies
249+
assert policies["unstable"]["max_retries"] == 6
250+
assert policies["unstable"]["timeout"] == tasks_module.settings.task_queue_task_timeout + 5.0

0 commit comments

Comments (0)