Skip to content

Commit f8a9cd7

Browse files
committed
feat(queue): add audit trail and retry policies
1 parent 505abaa commit f8a9cd7

File tree

3 files changed

+148
-8
lines changed

3 files changed

+148
-8
lines changed

agent_pm/storage/redis.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
QUEUE_KEY = "agent_pm:tasks"
1818
RESULT_KEY = "agent_pm:results"
1919
DEAD_LETTER_KEY = "agent_pm:dead_letter"
20+
DEAD_LETTER_AUDIT_KEY = "agent_pm:dead_letter:audit"
21+
RETRY_POLICY_KEY = "agent_pm:retry_policy"
2022
HEARTBEAT_KEY = "agent_pm:worker_heartbeats"
2123

2224

@@ -32,6 +34,14 @@ def _dead_letter_key() -> str:
3234
return DEAD_LETTER_KEY
3335

3436

37+
def _dead_letter_audit_key() -> str:
38+
return DEAD_LETTER_AUDIT_KEY
39+
40+
41+
def _retry_policy_key() -> str:
42+
return RETRY_POLICY_KEY
43+
44+
3545
def _heartbeat_key() -> str:
3646
return HEARTBEAT_KEY
3747

@@ -122,6 +132,57 @@ async def clear_dead_letter(client: redis.Redis, task_id: str) -> None:
122132
await client.hdel(_dead_letter_key(), task_id)
123133

124134

135+
async def append_dead_letter_audit(client: redis.Redis, entry: dict[str, Any], *, max_entries: int = 1000) -> None:
136+
payload = entry.copy()
137+
payload.setdefault("timestamp", datetime.now(UTC).isoformat())
138+
await client.lpush(_dead_letter_audit_key(), json.dumps(payload))
139+
await client.ltrim(_dead_letter_audit_key(), 0, max_entries - 1)
140+
141+
142+
async def fetch_dead_letter_audit(client: redis.Redis, limit: int = 100) -> list[dict[str, Any]]:
143+
raw_entries = await client.lrange(_dead_letter_audit_key(), 0, max(0, limit - 1))
144+
entries: list[dict[str, Any]] = []
145+
for value in raw_entries:
146+
try:
147+
entries.append(json.loads(value))
148+
except json.JSONDecodeError:
149+
entries.append({"raw": value})
150+
return entries
151+
152+
153+
async def set_retry_policy(client: redis.Redis, task_name: str, policy: dict[str, Any]) -> None:
154+
if not policy:
155+
await client.hdel(_retry_policy_key(), task_name)
156+
return
157+
await client.hset(_retry_policy_key(), task_name, json.dumps(policy))
158+
159+
160+
async def get_retry_policy(client: redis.Redis, task_name: str) -> dict[str, Any] | None:
161+
value = await client.hget(_retry_policy_key(), task_name)
162+
if not value:
163+
return None
164+
try:
165+
data = json.loads(value)
166+
except json.JSONDecodeError:
167+
return None
168+
return data
169+
170+
171+
async def delete_retry_policy(client: redis.Redis, task_name: str) -> None:
172+
await client.hdel(_retry_policy_key(), task_name)
173+
174+
175+
async def list_retry_policies(client: redis.Redis) -> dict[str, dict[str, Any]]:
176+
values = await client.hgetall(_retry_policy_key())
177+
policies: dict[str, dict[str, Any]] = {}
178+
for name, raw in values.items():
179+
try:
180+
policies[name] = json.loads(raw)
181+
except json.JSONDecodeError:
182+
continue
183+
return policies
184+
185+
125186
async def purge_dead_letters(client: redis.Redis, *, older_than: datetime | None = None) -> int:
126187
if older_than is None:
127188
count = await client.hlen(_dead_letter_key())

agent_pm/storage/tasks.py

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,22 @@
2525
from ..settings import settings
2626
from ..utils.datetime import utc_now
2727
from .redis import (
28+
append_dead_letter_audit,
2829
clear_dead_letter,
2930
count_dead_letters,
31+
delete_retry_policy,
3032
enqueue_task as redis_enqueue_task,
33+
fetch_dead_letter_audit,
3134
fetch_dead_letters,
3235
get_dead_letter,
3336
get_redis_client,
37+
get_retry_policy,
3438
list_heartbeats,
39+
list_retry_policies,
3540
pop_task,
3641
purge_dead_letters,
3742
record_dead_letter,
43+
set_retry_policy,
3844
set_task_result,
3945
write_heartbeat,
4046
)
@@ -217,6 +223,21 @@ async def list_dead_letters(
217223
) -> tuple[list[dict[str, Any]], int]:
218224
return [], 0
219225

226+
async def list_dead_letter_audit(self, limit: int = 100) -> list[dict[str, Any]]:
227+
return []
228+
229+
async def get_retry_policy(self, task_name: str) -> dict[str, Any] | None:
230+
return None
231+
232+
async def set_retry_policy(self, task_name: str, policy: dict[str, Any]) -> None:
233+
return None
234+
235+
async def delete_retry_policy(self, task_name: str) -> None:
236+
return None
237+
238+
async def list_retry_policies(self) -> dict[str, dict[str, Any]]:
239+
return {}
240+
220241
async def delete_dead_letter(self, task_id: str) -> None:
221242
return None
222243

@@ -314,23 +335,23 @@ async def _worker(self, worker_id: int):
314335
await record_dead_letter(self._redis, payload)
315336
dead_letter_recorded_total.labels(
316337
queue=self.queue_name,
317-
error_type="TimeoutError",
318-
).inc()
319-
dead_letter_recorded_total.labels(
320-
queue=self.queue_name,
321-
error_type=payload.get("error_type", payload.get("last_error", "unknown")),
338+
error_type=payload.get("error_type", "MissingCallable"),
322339
).inc()
323340
record_task_completion(self.queue_name, TaskStatus.FAILED.value)
324341
continue
325342

326343
retries = payload.get("retry_count", 0)
327-
max_retries = payload.get("max_retries", 3)
344+
base_max_retries = payload.get("max_retries", 3)
345+
346+
policy = await get_retry_policy(self._redis, name) or {}
347+
timeout = float(policy.get("timeout", self._task_timeout))
348+
max_retries = int(policy.get("max_retries", base_max_retries))
328349

329350
start = utc_now()
330351
try:
331352
result = await asyncio.wait_for(
332353
coro_fn(*payload.get("args", ()), **payload.get("kwargs", {})),
333-
timeout=self._task_timeout,
354+
timeout=timeout,
334355
)
335356
except TimeoutError:
336357
payload.setdefault("metadata", {})
@@ -343,6 +364,10 @@ async def _worker(self, worker_id: int):
343364
payload["stack_trace"] = None
344365
payload["worker_id"] = worker_id
345366
await record_dead_letter(self._redis, payload)
367+
dead_letter_recorded_total.labels(
368+
queue=self.queue_name,
369+
error_type="TimeoutError",
370+
).inc()
346371
record_task_completion(self.queue_name, TaskStatus.FAILED.value)
347372
record_task_latency(self.queue_name, (utc_now() - start).total_seconds())
348373
continue
@@ -365,7 +390,10 @@ async def _worker(self, worker_id: int):
365390
record_task_latency(self.queue_name, (utc_now() - start).total_seconds())
366391
continue
367392

368-
backoff = min(self._backoff_base**retries, self._backoff_max)
393+
backoff = min(
394+
float(policy.get("backoff_base", self._backoff_base)) ** retries,
395+
float(policy.get("backoff_max", self._backoff_max)),
396+
)
369397
await asyncio.sleep(backoff)
370398
await redis_enqueue_task(self._redis, name, payload)
371399
continue
@@ -404,6 +432,21 @@ async def list_dead_letters(
404432
dead_letter_active_gauge.labels(queue=self.queue_name).set(total_filtered)
405433
return window, total_filtered
406434

435+
async def list_dead_letter_audit(self, limit: int = 100) -> list[dict[str, Any]]:
436+
return await fetch_dead_letter_audit(self._redis, limit)
437+
438+
async def get_retry_policy(self, task_name: str) -> dict[str, Any] | None:
439+
return await get_retry_policy(self._redis, task_name)
440+
441+
async def set_retry_policy(self, task_name: str, policy: dict[str, Any]) -> None:
442+
await set_retry_policy(self._redis, task_name, policy)
443+
444+
async def delete_retry_policy(self, task_name: str) -> None:
445+
await delete_retry_policy(self._redis, task_name)
446+
447+
async def list_retry_policies(self) -> dict[str, dict[str, Any]]:
448+
return await list_retry_policies(self._redis)
449+
407450
async def delete_dead_letter(self, task_id: str) -> None:
408451
await clear_dead_letter(self._redis, task_id)
409452

tests/storage/test_redis.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ def __init__(self):
1212
self.results: dict[str, str] = {}
1313
self.dead: dict[str, str] = {}
1414
self.heartbeats: dict[str, str] = {}
15+
self.audit: list[str] = []
16+
self.retry_policies: dict[str, str] = {}
1517

1618
async def rpush(self, key: str, value: str) -> None:
1719
self.items.append(value)
@@ -29,28 +31,38 @@ async def hset(self, key: str, field: str, value: str) -> None:
2931
self.dead[field] = value
3032
elif key.endswith("heartbeats"):
3133
self.heartbeats[field] = value
34+
elif key.endswith("retry_policy"):
35+
self.retry_policies[field] = value
3236
else:
3337
self.results[field] = value
3438

3539
async def hget(self, key: str, field: str):
3640
if key.endswith("dead_letter"):
3741
return self.dead.get(field)
42+
if key.endswith("retry_policy"):
43+
return self.retry_policies.get(field)
3844
return self.results.get(field)
3945

4046
async def hgetall(self, key: str):
4147
if key.endswith("dead_letter"):
4248
return self.dead
4349
if key.endswith("heartbeats"):
4450
return self.heartbeats
51+
if key.endswith("retry_policy"):
52+
return self.retry_policies
4553
return self.results
4654

4755
async def hdel(self, key: str, field: str):
4856
if key.endswith("dead_letter"):
4957
self.dead.pop(field, None)
58+
elif key.endswith("retry_policy"):
59+
self.retry_policies.pop(field, None)
5060

5161
async def hlen(self, key: str) -> int:
5262
if key.endswith("dead_letter"):
5363
return len(self.dead)
64+
if key.endswith("retry_policy"):
65+
return len(self.retry_policies)
5466
return len(self.results)
5567

5668
async def expire(self, key: str, ttl: int):
@@ -59,6 +71,18 @@ async def expire(self, key: str, ttl: int):
5971
async def delete(self, key: str):
6072
if key.endswith("dead_letter"):
6173
self.dead.clear()
74+
async def lpush(self, key: str, value: str):
75+
if key.endswith("dead_letter:audit"):
76+
self.audit.insert(0, value)
77+
78+
async def ltrim(self, key: str, start: int, stop: int):
79+
if key.endswith("dead_letter:audit"):
80+
self.audit = self.audit[start : stop + 1]
81+
82+
async def lrange(self, key: str, start: int, stop: int):
83+
if key.endswith("dead_letter:audit"):
84+
return self.audit[start : stop + 1]
85+
return []
6286

6387

6488
@pytest.mark.asyncio
@@ -119,3 +143,15 @@ async def test_dead_letter_and_heartbeat_helpers():
119143
await redis.write_heartbeat(client, "worker:1", {"status": "ok"}, ttl=60)
120144
beats = await redis.list_heartbeats(client)
121145
assert beats["worker:1"]["status"] == "ok"
146+
147+
await redis.append_dead_letter_audit(client, {"event": "record", "task_id": "abc"}, max_entries=10)
148+
audit_entries = await redis.fetch_dead_letter_audit(client, limit=1)
149+
assert audit_entries[0]["task_id"] == "abc"
150+
151+
await redis.set_retry_policy(client, "job", {"timeout": 123, "backoff_base": 2.5})
152+
policy = await redis.get_retry_policy(client, "job")
153+
assert policy == {"timeout": 123, "backoff_base": 2.5}
154+
policies = await redis.list_retry_policies(client)
155+
assert "job" in policies
156+
await redis.delete_retry_policy(client, "job")
157+
assert await redis.get_retry_policy(client, "job") is None

0 commit comments

Comments
 (0)