Skip to content

Commit b06cc59

Browse files
committed
feat: migrate task queue to native redis client
1 parent 7c3c8cc commit b06cc59

File tree

5 files changed

+86
-96
lines changed

5 files changed

+86
-96
lines changed

agent_pm/storage/redis.py

Lines changed: 35 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,76 +1,59 @@
1-
"""Redis-backed distributed task queue using ARQ."""
1+
"""Redis-backed task queue using redis.asyncio."""
22

33
from __future__ import annotations
44

5+
import json
56
import logging
7+
import uuid
68
from typing import Any
79

8-
from arq import create_pool
9-
from arq.connections import ArqRedis, RedisSettings
10+
import redis.asyncio as redis
1011

1112
from agent_pm.settings import settings
1213

1314
logger = logging.getLogger(__name__)
1415

16+
QUEUE_KEY = "agent_pm:tasks"
17+
RESULT_KEY = "agent_pm:results"
1518

16-
async def get_redis_pool() -> ArqRedis:
17-
"""Get or create Redis connection pool for ARQ."""
18-
# Parse Redis URL
19-
url = settings.redis_url
20-
if url.startswith("redis://"):
21-
host_port = url.replace("redis://", "").split("/")[0]
22-
host, port = host_port.split(":") if ":" in host_port else (host_port, 6379)
23-
else:
24-
host, port = "localhost", 6379
2519

26-
redis_settings = RedisSettings(host=host, port=int(port))
27-
return await create_pool(redis_settings)
20+
def _queue_key() -> str:
21+
return QUEUE_KEY
2822

2923

30-
async def enqueue_task(
31-
pool: ArqRedis,
32-
function_name: str,
33-
*args: Any,
34-
**kwargs: Any,
35-
) -> str:
36-
"""Enqueue a task to Redis queue.
24+
def _result_key() -> str:
25+
return RESULT_KEY
3726

38-
Returns:
39-
Job ID
40-
"""
41-
job = await pool.enqueue_job(function_name, *args, **kwargs)
42-
logger.info("Enqueued task: %s (job_id=%s)", function_name, job.job_id)
43-
return job.job_id
4427

28+
# Lazily-created shared client; ``redis.from_url`` builds a fresh connection
# pool on every call, so re-creating a client per call would leak pools.
_redis_client: redis.Redis | None = None


async def get_redis_client() -> redis.Redis:
    """Return a shared ``redis.asyncio`` client for ``settings.redis_url``.

    The client is created on first use and cached at module level.
    ``decode_responses=True`` keeps the queue API in ``str`` space
    (JSON strings in, JSON strings out).
    """
    global _redis_client
    if _redis_client is None:
        _redis_client = redis.from_url(settings.redis_url, decode_responses=True)
    return _redis_client
4530

46-
async def get_job_status(pool: ArqRedis, job_id: str) -> dict[str, Any] | None:
47-
"""Get status of a job by ID."""
48-
job_result = await pool.job_status(job_id)
49-
if not job_result:
50-
return None
5131

52-
return {
53-
"job_id": job_id,
54-
"status": job_result.job_status.value if job_result.job_status else "unknown",
55-
"result": job_result.result,
56-
"start_time": (job_result.start_time.isoformat() if job_result.start_time else None),
57-
"finish_time": (job_result.finish_time.isoformat() if job_result.finish_time else None),
58-
}
32+
async def enqueue_task(
    client: redis.Redis,
    name: str,
    payload: dict[str, Any],
) -> str:
    """JSON-encode *payload* and append it to the Redis task list.

    ``task_id`` and ``name`` are filled in when the payload does not already
    carry them. The caller's dict is shallow-copied first so enqueueing has
    no side effects on the argument (the previous implementation mutated
    *payload* in place via ``setdefault``).

    Args:
        client: Async Redis client (``decode_responses=True`` expected).
        name: Task function name, stored under the ``"name"`` key.
        payload: Task payload; may already contain ``task_id``/``name``.

    Returns:
        The task id (pre-existing or freshly generated) used for result lookup.
    """
    record = dict(payload)  # do not mutate the caller's dict
    task_id = record.setdefault("task_id", uuid.uuid4().hex)
    record.setdefault("name", name)
    await client.rpush(_queue_key(), json.dumps(record))
    logger.info("Queued redis task %s (%s)", name, task_id)
    return task_id
42+
5943

44+
async def pop_task(client: redis.Redis) -> dict[str, Any] | None:
    """Pop and decode the oldest queued task payload.

    Returns ``None`` when the queue list is empty.
    """
    raw = await client.lpop(_queue_key())
    return json.loads(raw) if raw else None
6049

61-
# ARQ worker functions (to be imported by worker process)
62-
async def example_background_task(ctx: dict, plan_id: str, user: str) -> dict:
63-
"""Example background task function."""
64-
logger.info("Processing background task for plan_id=%s user=%s", plan_id, user)
65-
# Perform async work here
66-
return {"plan_id": plan_id, "status": "processed"}
6750

51+
async def set_task_result(client: redis.Redis, task_id: str, result: dict[str, Any]) -> None:
    """Store *result* (JSON-encoded) under *task_id* in the results hash."""
    encoded = json.dumps(result)
    await client.hset(_result_key(), task_id, encoded)
6853

69-
# Worker class for ARQ
70-
class WorkerSettings:
71-
"""ARQ worker configuration."""
7254

73-
functions = [example_background_task]
74-
redis_settings = RedisSettings(host="localhost", port=6379)
75-
max_jobs = 10
76-
job_timeout = 300 # 5 minutes
55+
async def get_task_result(client: redis.Redis, task_id: str) -> dict[str, Any] | None:
    """Fetch and decode the stored result for *task_id*.

    Returns ``None`` when no result has been recorded yet.
    """
    raw = await client.hget(_result_key(), task_id)
    return None if not raw else json.loads(raw)

agent_pm/storage/tasks.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from ..settings import settings
1616
from ..utils.datetime import utc_now
17-
from .redis import enqueue_task as redis_enqueue_task, get_redis_pool
17+
from .redis import enqueue_task as redis_enqueue_task, get_redis_client
1818

1919
logger = logging.getLogger(__name__)
2020

@@ -187,7 +187,7 @@ async def get_task_queue() -> TaskQueue:
187187
if _task_queue is None:
188188
backend = settings.task_queue_backend
189189
if backend == "redis":
190-
pool = await get_redis_pool()
190+
client = await get_redis_client()
191191

192192
class RedisTaskQueue(TaskQueue):
193193
async def enqueue( # type: ignore[override]
@@ -198,7 +198,14 @@ async def enqueue( # type: ignore[override]
198198
max_retries: int = 3,
199199
**kwargs: Any,
200200
) -> str:
201-
return await redis_enqueue_task(pool, name, *args, **kwargs)
201+
payload = {
202+
"task_id": uuid.uuid4().hex,
203+
"name": name,
204+
"args": args,
205+
"kwargs": kwargs,
206+
"max_retries": max_retries,
207+
}
208+
return await redis_enqueue_task(client, name, payload)
202209

203210
_task_queue = RedisTaskQueue(max_workers=settings.task_queue_workers)
204211
else:

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ dependencies = [
2121
"prometheus-client",
2222
"sqlalchemy>=2.0",
2323
"asyncpg",
24-
"redis>=5.0",
25-
"arq",
24+
"redis[async]>=5.0",
25+
"rq",
2626
"opentelemetry-api",
2727
"opentelemetry-sdk",
2828
"opentelemetry-instrumentation-fastapi",

tests/storage/test_redis.py

Lines changed: 36 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,54 @@
1-
import asyncio
1+
import json
22

33
import pytest
44

55
from agent_pm.storage import redis
66

77

8-
class FakeJobResult:
9-
def __init__(self, job_id: str, status: str):
10-
self.job_status = type("JobStatus", (), {"value": status})()
11-
self.result = {"ok": True}
12-
self.start_time = None
13-
self.finish_time = None
14-
15-
16-
class FakeRedisPool:
8+
class DummyRedis:
    """Minimal in-memory stand-in for the redis.asyncio surface the queue uses.

    Backs ``rpush``/``lpop`` with a plain list and ``hset``/``hget`` with a
    dict; the ``key`` arguments are accepted but ignored.
    """

    def __init__(self):
        self.items: list[str] = []
        self.results: dict[str, str] = {}

    async def rpush(self, key: str, value: str) -> None:
        self.items.append(value)

    async def lpop(self, key: str):
        # FIFO pop; None signals an empty queue, matching redis semantics.
        return self.items.pop(0) if self.items else None

    async def hset(self, key: str, field: str, value: str) -> None:
        self.results[field] = value

    async def hget(self, key: str, field: str):
        return self.results.get(field)
2926

3027

3128
@pytest.mark.asyncio
async def test_enqueue_task_pushes_payload():
    """enqueue_task JSON-encodes the payload onto the queue list and returns an id."""
    fake = DummyRedis()
    payload = {"args": [1], "kwargs": {"foo": "bar"}}

    task_id = await redis.enqueue_task(fake, "do_work", payload)

    assert isinstance(task_id, str)
    body = json.loads(fake.items[0])
    assert body["name"] == "do_work"
    assert body["args"] == [1]
    assert body["kwargs"] == {"foo": "bar"}
4240

4341

4442
@pytest.mark.asyncio
async def test_pop_and_result_roundtrip():
    """A queued task can be popped, and its result stored and read back."""
    fake = DummyRedis()
    task_id = await redis.enqueue_task(fake, "noop", {"args": [], "kwargs": {}})

    task = await redis.pop_task(fake)
    assert task is not None
    assert task["task_id"] == task_id

    await redis.set_task_result(fake, task_id, {"status": "done"})
    assert await redis.get_task_result(fake, task_id) == {"status": "done"}

uv.lock

Lines changed: 3 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)