
Commit 3f0dec1

feat(queue): add redis worker orchestration

1 parent 14a3473 · commit 3f0dec1

9 files changed: +444, −8 lines

agent_pm/observability/metrics.py
34 additions, 0 deletions

@@ -65,6 +65,25 @@
     labelnames=("source",),
 )
 
+task_queue_enqueued_total = Counter(
+    "task_queue_enqueued_total",
+    "Tasks enqueued to background queue",
+    labelnames=("queue",),
+)
+
+task_queue_completed_total = Counter(
+    "task_queue_completed_total",
+    "Tasks completed grouped by status",
+    labelnames=("queue", "status"),
+)
+
+task_queue_latency_seconds = Histogram(
+    "task_queue_latency_seconds",
+    "Task execution latency in seconds",
+    labelnames=("queue",),
+    buckets=(0.05, 0.1, 0.25, 0.5, 1, 2, 5, 10, 30, 60),
+)
+
 plugin_hook_invocations_total = Counter(
     "plugin_hook_invocations_total",
     "Plugin hook invocation counts",
@@ -158,6 +177,18 @@ def record_client_call(client: str):
     client_latency_seconds.labels(client=client).observe(perf_counter() - start)
 
 
+def record_task_enqueued(queue: str) -> None:
+    task_queue_enqueued_total.labels(queue=queue).inc()
+
+
+def record_task_completion(queue: str, status: str) -> None:
+    task_queue_completed_total.labels(queue=queue, status=status).inc()
+
+
+def record_task_latency(queue: str, duration: float) -> None:
+    task_queue_latency_seconds.labels(queue=queue).observe(duration)
+
+
 def latest_metrics() -> bytes:
     return generate_latest()
 
@@ -175,5 +206,8 @@ def latest_metrics() -> bytes:
     "record_plugin_hook_invocation",
     "record_plugin_hook_failure",
     "record_client_call",
+    "record_task_enqueued",
+    "record_task_completion",
+    "record_task_latency",
     "latest_metrics",
 ]
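
Taken together, the three helpers give a queue backend one call per lifecycle event. A minimal usage sketch, assuming the package is importable; the queue label and the simulated work below are illustrative, not part of the commit:

    from time import perf_counter, sleep

    from agent_pm.observability.metrics import (
        record_task_completion,
        record_task_enqueued,
        record_task_latency,
    )

    record_task_enqueued("redis")  # task_queue_enqueued_total{queue="redis"} += 1

    start = perf_counter()
    sleep(0.2)  # stand-in for actual task work
    record_task_completion("redis", "completed")  # increments the (queue, status) counter
    record_task_latency("redis", perf_counter() - start)  # ~0.2 s falls in the 0.25 bucket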

agent_pm/settings.py
5 additions, 0 deletions

@@ -71,6 +71,11 @@ def _parse_google_scopes(cls, value):
     log_format: str = Field("json", alias="LOG_FORMAT")  # json or text
     task_queue_workers: int = Field(5, alias="TASK_QUEUE_WORKERS")
     task_queue_backend: Literal["memory", "redis"] = Field("memory", alias="TASK_QUEUE_BACKEND")
+    task_queue_poll_interval: float = Field(0.2, alias="TASK_QUEUE_POLL_INTERVAL")
+    task_queue_task_timeout: int = Field(300, alias="TASK_QUEUE_TASK_TIMEOUT")
+    task_queue_retry_backoff_base: float = Field(2.0, alias="TASK_QUEUE_RETRY_BACKOFF_BASE")
+    task_queue_retry_backoff_max: float = Field(60.0, alias="TASK_QUEUE_RETRY_BACKOFF_MAX")
+    task_queue_worker_heartbeat_ttl: int = Field(60, alias="TASK_QUEUE_WORKER_HEARTBEAT_TTL")
     database_url: str | None = Field("sqlite+aiosqlite:///./data/agent_pm.db", alias="DATABASE_URL")
     database_echo: bool = Field(False, alias="DATABASE_ECHO")
     redis_url: str = Field("redis://localhost:6379", alias="REDIS_URL")
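
Each field reads its value from the environment via its alias. A sketch of overriding the new knobs, with illustrative values; this assumes the usual pydantic-settings behavior of reading the environment when Settings is instantiated:

    import os

    os.environ["TASK_QUEUE_BACKEND"] = "redis"
    os.environ["TASK_QUEUE_POLL_INTERVAL"] = "0.5"        # sleep between polls on an empty queue
    os.environ["TASK_QUEUE_TASK_TIMEOUT"] = "120"         # per-task ceiling enforced via asyncio.wait_for
    os.environ["TASK_QUEUE_RETRY_BACKOFF_BASE"] = "2.0"   # retry delay = min(base ** retries, max)
    os.environ["TASK_QUEUE_RETRY_BACKOFF_MAX"] = "30.0"   # cap on any single retry delay
    os.environ["TASK_QUEUE_WORKER_HEARTBEAT_TTL"] = "60"  # seconds before heartbeat data expires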

agent_pm/storage/redis.py
50 additions, 0 deletions

@@ -15,6 +15,8 @@
 
 QUEUE_KEY = "agent_pm:tasks"
 RESULT_KEY = "agent_pm:results"
+DEAD_LETTER_KEY = "agent_pm:dead_letter"
+HEARTBEAT_KEY = "agent_pm:worker_heartbeats"
 
 
 def _queue_key() -> str:
@@ -25,6 +27,14 @@ def _result_key() -> str:
     return RESULT_KEY
 
 
+def _dead_letter_key() -> str:
+    return DEAD_LETTER_KEY
+
+
+def _heartbeat_key() -> str:
+    return HEARTBEAT_KEY
+
+
 async def get_redis_client() -> redis.Redis:
     return redis.from_url(settings.redis_url, decode_responses=True)
 
@@ -57,3 +67,43 @@ async def get_task_result(client: redis.Redis, task_id: str) -> dict[str, Any] | None:
     if not item:
         return None
     return json.loads(item)
+
+
+async def record_dead_letter(client: redis.Redis, payload: dict[str, Any]) -> None:
+    task_id = payload.get("task_id", uuid.uuid4().hex)
+    await client.hset(_dead_letter_key(), task_id, json.dumps(payload))
+
+
+async def fetch_dead_letters(client: redis.Redis, limit: int = 100) -> list[dict[str, Any]]:
+    items = await client.hgetall(_dead_letter_key())
+    tasks: list[dict[str, Any]] = []
+    for task_id, value in items.items():
+        if len(tasks) >= limit:
+            break
+        try:
+            data = json.loads(value)
+            data.setdefault("task_id", task_id)
+            tasks.append(data)
+        except json.JSONDecodeError:
+            tasks.append({"task_id": task_id, "raw": value})
+    return tasks
+
+
+async def clear_dead_letter(client: redis.Redis, task_id: str) -> None:
+    await client.hdel(_dead_letter_key(), task_id)
+
+
+async def write_heartbeat(client: redis.Redis, worker_id: str, payload: dict[str, Any], ttl: int) -> None:
+    await client.hset(_heartbeat_key(), worker_id, json.dumps(payload))
+    await client.expire(_heartbeat_key(), ttl)
+
+
+async def list_heartbeats(client: redis.Redis) -> dict[str, dict[str, Any]]:
+    items = await client.hgetall(_heartbeat_key())
+    heartbeats: dict[str, dict[str, Any]] = {}
+    for worker_id, value in items.items():
+        try:
+            heartbeats[worker_id] = json.loads(value)
+        except json.JSONDecodeError:
+            heartbeats[worker_id] = {"raw": value}
+    return heartbeats
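
Dead letters and heartbeats are plain Redis hashes, so they can be inspected directly with the new helpers. A small sketch, assuming a reachable Redis at settings.redis_url; the printed fields are illustrative:

    import asyncio

    from agent_pm.storage.redis import (
        fetch_dead_letters,
        get_redis_client,
        list_heartbeats,
    )

    async def inspect() -> None:
        client = await get_redis_client()
        for task in await fetch_dead_letters(client, limit=10):
            # last_error / error are set by the worker paths shown in tasks.py below
            print(task["task_id"], task.get("last_error") or task.get("error"))
        for worker_id, beat in (await list_heartbeats(client)).items():
            print(worker_id, beat.get("completed_at"))

    asyncio.run(inspect())

One design note: write_heartbeat applies the TTL to the whole heartbeat hash rather than to individual fields, so any worker's refresh extends every entry; per-worker expiry would need one key per worker.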

agent_pm/storage/tasks.py
133 additions, 3 deletions

@@ -12,10 +12,24 @@
 from enum import Enum
 from typing import Any
 
+from ..observability.metrics import (
+    record_task_completion,
+    record_task_enqueued,
+    record_task_latency,
+)
 from ..settings import settings
 from ..utils.datetime import utc_now
-from .redis import enqueue_task as redis_enqueue_task
-from .redis import get_redis_client
+from .redis import (
+    clear_dead_letter,
+    enqueue_task as redis_enqueue_task,
+    fetch_dead_letters,
+    get_redis_client,
+    list_heartbeats,
+    pop_task,
+    record_dead_letter,
+    set_task_result,
+    write_heartbeat,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -61,6 +75,7 @@ def __init__(self, max_workers: int = 5):
         self.workers: list[asyncio.Task[None]] = []
         self.running = False
         self._lock = asyncio.Lock()
+        self.queue_name = "memory"
 
     async def start(self):
         """Start background workers."""
@@ -102,6 +117,7 @@ async def enqueue(
         async with self._lock:
             self.queue.append(task)
             self.tasks[task_id] = task
+        record_task_enqueued(self.queue_name)
         logger.info("Task enqueued: %s (id=%s)", name, task_id)
         return task_id
 
@@ -153,6 +169,8 @@ async def _execute_task(self, task: Task):
             task.status = TaskStatus.COMPLETED
             task.result = result
             task.completed_at = utc_now()
+            record_task_completion(self.queue_name, task.status.value)
+            record_task_latency(self.queue_name, (task.completed_at - task.started_at).total_seconds())
             logger.info("Task completed: %s (id=%s)", task.name, task.task_id)
         except Exception as exc:
             task.error = str(exc)
@@ -175,8 +193,19 @@ async def _execute_task(self, task: Task):
             else:
                 task.status = TaskStatus.FAILED
                 task.completed_at = utc_now()
+                record_task_completion(self.queue_name, task.status.value)
+                record_task_latency(self.queue_name, (task.completed_at - task.started_at).total_seconds())
                 logger.error("Task permanently failed: %s (id=%s)", task.name, task.task_id)
 
+    async def list_dead_letters(self, limit: int = 100) -> list[dict[str, Any]]:
+        return []
+
+    async def delete_dead_letter(self, task_id: str) -> None:
+        return None
+
+    async def worker_heartbeats(self) -> dict[str, Any]:
+        return {}
+
 
 # Global task queue instance
 _task_queue: TaskQueue | None = None
@@ -191,6 +220,20 @@ async def get_task_queue() -> TaskQueue:
         client = await get_redis_client()
 
         class RedisTaskQueue(TaskQueue):
+            def __init__(self, max_workers: int = 5):
+                super().__init__(max_workers=max_workers)
+                self.queue_name = "redis"
+                self._redis = client
+                self._registry: dict[str, Callable[..., Coroutine[Any, Any, Any]]] = {}
+                self._poll_interval = settings.task_queue_poll_interval
+                self._task_timeout = settings.task_queue_task_timeout
+                self._backoff_base = settings.task_queue_retry_backoff_base
+                self._backoff_max = settings.task_queue_retry_backoff_max
+                self._heartbeat_ttl = settings.task_queue_worker_heartbeat_ttl
+
+            def register(self, name: str, coro_fn: Callable[..., Coroutine[Any, Any, Any]]) -> None:
+                self._registry[name] = coro_fn
+
             async def enqueue(  # type: ignore[override]
                 self,
                 name: str,
@@ -199,14 +242,101 @@ async def enqueue(  # type: ignore[override]
                 max_retries: int = 3,
                 **kwargs: Any,
             ) -> str:
+                self.register(name, coro_fn)
                 payload = {
                     "task_id": uuid.uuid4().hex,
                     "name": name,
                     "args": args,
                     "kwargs": kwargs,
                     "max_retries": max_retries,
+                    "enqueued_at": utc_now().isoformat(),
                 }
-                return await redis_enqueue_task(client, name, payload)
+                await redis_enqueue_task(self._redis, name, payload)
+                record_task_enqueued(self.queue_name)
+                logger.info("Redis task enqueued: %s (id=%s)", name, payload["task_id"])
+                return payload["task_id"]
+
+            async def pop(self) -> dict[str, Any] | None:
+                payload = await pop_task(self._redis)
+                if not payload:
+                    return None
+                return payload
+
+            async def _worker(self, worker_id: int):
+                logger.info("Redis worker %d started", worker_id)
+                while self.running:
+                    payload = await self.pop()
+                    if not payload:
+                        await asyncio.sleep(self._poll_interval)
+                        continue
+
+                    task_id = payload.get("task_id", "unknown")
+                    name = payload.get("name")
+                    coro_fn = self._registry.get(name)
+                    if coro_fn is None:
+                        logger.error("No registered task callable for %s", name)
+                        await record_dead_letter(
+                            self._redis,
+                            {**payload, "error": "missing_callable", "worker_id": worker_id},
+                        )
+                        record_task_completion(self.queue_name, TaskStatus.FAILED.value)
+                        continue
+
+                    retries = payload.get("retry_count", 0)
+                    max_retries = payload.get("max_retries", 3)
+
+                    start = utc_now()
+                    try:
+                        result = await asyncio.wait_for(
+                            coro_fn(*payload.get("args", ()), **payload.get("kwargs", {})),
+                            timeout=self._task_timeout,
+                        )
+                    except asyncio.TimeoutError:
+                        payload["retry_count"] = retries + 1
+                        payload["last_error"] = "timeout"
+                        payload["worker_id"] = worker_id
+                        await record_dead_letter(self._redis, payload)
+                        record_task_completion(self.queue_name, TaskStatus.FAILED.value)
+                        record_task_latency(self.queue_name, (utc_now() - start).total_seconds())
+                        continue
+                    except Exception as exc:  # pylint: disable=broad-except
+                        retries += 1
+                        payload["retry_count"] = retries
+                        payload["last_error"] = str(exc)
+                        if retries >= max_retries:
+                            payload["worker_id"] = worker_id
+                            await record_dead_letter(self._redis, payload)
+                            record_task_completion(self.queue_name, TaskStatus.FAILED.value)
+                            record_task_latency(self.queue_name, (utc_now() - start).total_seconds())
+                            continue
+
+                        backoff = min(self._backoff_base**retries, self._backoff_max)
+                        await asyncio.sleep(backoff)
+                        await redis_enqueue_task(self._redis, name, payload)
+                        continue
+
+                    await set_task_result(self._redis, task_id, {"status": "completed", "result": result})
+                    record_task_completion(self.queue_name, TaskStatus.COMPLETED.value)
+                    record_task_latency(self.queue_name, (utc_now() - start).total_seconds())
+
+                    heartbeat_payload = {
+                        "worker_id": worker_id,
+                        "task_id": task_id,
+                        "name": name,
+                        "completed_at": utc_now().isoformat(),
+                    }
+                    await write_heartbeat(self._redis, f"worker:{worker_id}", heartbeat_payload, self._heartbeat_ttl)
+
+                logger.info("Redis worker %d stopped", worker_id)
+
+            async def list_dead_letters(self, limit: int = 100) -> list[dict[str, Any]]:
+                return await fetch_dead_letters(self._redis, limit)
+
+            async def delete_dead_letter(self, task_id: str) -> None:
+                await clear_dead_letter(self._redis, task_id)
+
+            async def worker_heartbeats(self) -> dict[str, dict[str, Any]]:
+                return await list_heartbeats(self._redis)
 
         _task_queue = RedisTaskQueue(max_workers=settings.task_queue_workers)
     else:
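
On the Redis path, enqueue both registers the coroutine in an in-process registry and pushes the JSON payload, so producers and workers must share a process for the callable lookup to succeed. A usage sketch; send_digest and its argument are hypothetical, and TASK_QUEUE_BACKEND=redis is assumed so get_task_queue() returns the RedisTaskQueue:

    import asyncio

    from agent_pm.storage.tasks import get_task_queue

    async def send_digest(plan_id: str) -> dict[str, str]:
        # hypothetical task body; any coroutine returning JSON-serializable data works
        return {"plan_id": plan_id, "status": "sent"}

    async def main() -> None:
        queue = await get_task_queue()
        task_id = await queue.enqueue("send_digest", send_digest, "plan-123", max_retries=3)
        print("enqueued", task_id)

    asyncio.run(main())

With the default backoff settings (base 2.0, cap 60.0), a task that keeps raising waits 2 s after the first failure and 4 s after the second, then lands in the dead-letter hash on the third, since retries >= max_retries is checked before another backoff is scheduled.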

app.py
29 additions, 4 deletions

@@ -41,7 +41,10 @@
 )
 from agent_pm.observability.export import schedule_trace_export
 from agent_pm.observability.logging import configure_logging
-from agent_pm.observability.metrics import latest_metrics, record_alignment_export
+from agent_pm.observability.metrics import (
+    latest_metrics,
+    record_alignment_export,
+)
 from agent_pm.observability.structured import (
     configure_structured_logging,
     get_correlation_id,
@@ -86,14 +89,13 @@ class FollowupUpdate(BaseModel):
 async def startup_event():
     global _task_queue
     _task_queue = await get_task_queue()
-    if settings.task_queue_backend == "memory":
-        await _task_queue.start()
+    await _task_queue.start()
     logger.info("Agent PM service started")
 
 
 @app.on_event("shutdown")
 async def shutdown_event():
-    if settings.task_queue_backend == "memory" and _task_queue:
+    if _task_queue:
         await _task_queue.stop()
     logger.info("Agent PM service stopped")
 
@@ -473,6 +475,29 @@ async def list_tasks(status: str | None = None, limit: int = 50, _admin_key: AdminKeyDep = None):
     }
 
 
+@app.get("/tasks/dead-letter")
+async def list_dead_letter(limit: int = 50, _admin_key: AdminKeyDep = None) -> dict[str, Any]:
+    if not _task_queue:
+        raise HTTPException(status_code=503, detail="Task queue not initialized")
+    items = await _task_queue.list_dead_letters(limit)
+    return {"dead_letter": items, "total": len(items)}
+
+
+@app.delete("/tasks/dead-letter/{task_id}")
+async def delete_dead_letter(task_id: str, _admin_key: AdminKeyDep = None) -> dict[str, Any]:
+    if not _task_queue:
+        raise HTTPException(status_code=503, detail="Task queue not initialized")
+    await _task_queue.delete_dead_letter(task_id)
+    return {"task_id": task_id, "status": "deleted"}
+
+
+@app.get("/tasks/workers")
+async def worker_status(_admin_key: AdminKeyDep = None) -> dict[str, Any]:
+    if not _task_queue:
+        raise HTTPException(status_code=503, detail="Task queue not initialized")
+    return {"workers": await _task_queue.worker_heartbeats()}
+
+
 @app.post("/prd/{plan_id}/versions")
 async def create_prd_version(
     plan_id: str,
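
The three new admin routes can be exercised with any HTTP client. A sketch using httpx; the base URL and the X-Admin-Key header name are assumptions, since the diff does not show which header AdminKeyDep reads:

    import httpx

    headers = {"X-Admin-Key": "change-me"}  # assumed header name

    with httpx.Client(base_url="http://localhost:8000", headers=headers) as client:
        dead = client.get("/tasks/dead-letter", params={"limit": 10}).json()
        print(dead["total"], "dead-lettered tasks")

        print(client.get("/tasks/workers").json()["workers"])

        if dead["dead_letter"]:
            task_id = dead["dead_letter"][0]["task_id"]
            print(client.delete(f"/tasks/dead-letter/{task_id}").json())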
