Skip to content

Commit 605e062

Browse files
committed
refactor(app): move task queue startup to lifespan
1 parent ec034d0 commit 605e062

File tree

2 files changed

+91
-92
lines changed

2 files changed

+91
-92
lines changed

app.py

Lines changed: 37 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from datetime import datetime, timedelta
99
from typing import Any
1010

11+
from contextlib import asynccontextmanager
12+
1113
from fastapi import Depends, FastAPI, HTTPException, WebSocket, WebSocketDisconnect
1214
from fastapi.responses import PlainTextResponse
1315
from pydantic import BaseModel, model_validator
@@ -68,36 +70,39 @@
6870
from agent_pm.storage.tasks import TaskStatus, get_task_queue
6971
from agent_pm.tools import registry
7072

73+
@asynccontextmanager
74+
async def lifespan(_app: FastAPI):
75+
global _task_queue
76+
_task_queue = await get_task_queue()
77+
await _task_queue.start()
78+
logger.info("Agent PM service started")
79+
try:
80+
yield
81+
finally:
82+
if _task_queue:
83+
await _task_queue.stop()
84+
logger.info("Agent PM service stopped")
85+
86+
87+
lifespan_app = FastAPI(title="Agent PM", version="0.1.0", lifespan=lifespan)
88+
7189
if settings.log_format == "json":
7290
configure_structured_logging()
7391
else:
7492
configure_logging(settings.trace_dir)
7593

7694
logger = logging.getLogger(__name__)
77-
app = FastAPI(title="Agent PM", version="0.1.0")
7895
_jira_lock = asyncio.Lock()
7996
_task_queue = None
8097

81-
plugin_registry.attach_app(app)
98+
plugin_registry.attach_app(lifespan_app)
8299

83100

84101
class FollowupUpdate(BaseModel):
85102
status: str
86103

87104

88-
@app.on_event("startup")
89-
async def startup_event():
90-
global _task_queue
91-
_task_queue = await get_task_queue()
92-
await _task_queue.start()
93-
logger.info("Agent PM service started")
94-
95-
96-
@app.on_event("shutdown")
97-
async def shutdown_event():
98-
if _task_queue:
99-
await _task_queue.stop()
100-
logger.info("Agent PM service stopped")
105+
app = lifespan_app
101106

102107

103108
async def ensure_project_allowed(plan: TicketPlan) -> TicketPlan:
@@ -435,10 +440,9 @@ async def metrics() -> PlainTextResponse:
435440
@app.get("/tasks")
436441
async def list_tasks(status: str | None = None, limit: int = 50, _admin_key: AdminKeyDep = None) -> dict[str, Any]:
437442
"""List all tasks with optional status filter."""
438-
if not _task_queue:
439-
raise HTTPException(status_code=503, detail="Task queue not initialized")
443+
task_queue = await get_task_queue()
440444
task_status = TaskStatus(status) if status else None
441-
tasks = await _task_queue.list_tasks(status=task_status, limit=limit)
445+
tasks = await task_queue.list_tasks(status=task_status, limit=limit)
442446
return {
443447
"tasks": [
444448
{
@@ -463,9 +467,8 @@ async def list_dead_letter(
463467
error_type: str | None = None,
464468
_admin_key: AdminKeyDep = None,
465469
) -> dict[str, Any]:
466-
if not _task_queue:
467-
raise HTTPException(status_code=503, detail="Task queue not initialized")
468-
items, total = await _task_queue.list_dead_letters(
470+
task_queue = await get_task_queue()
471+
items, total = await task_queue.list_dead_letters(
469472
limit=limit, offset=offset, workflow_id=workflow_id, error_type=error_type
470473
)
471474
return {
@@ -487,48 +490,43 @@ async def list_dead_letter(
487490

488491
@app.get("/tasks/dead-letter/{task_id}")
489492
async def get_dead_letter(task_id: str, _admin_key: AdminKeyDep = None) -> dict[str, Any]:
490-
if not _task_queue:
491-
raise HTTPException(status_code=503, detail="Task queue not initialized")
492-
item = await _task_queue.get_dead_letter(task_id)
493+
task_queue = await get_task_queue()
494+
item = await task_queue.get_dead_letter(task_id)
493495
if not item:
494496
raise HTTPException(status_code=404, detail="Dead-letter task not found")
495497
return item
496498

497499

498500
@app.delete("/tasks/dead-letter/{task_id}")
499501
async def delete_dead_letter(task_id: str, _admin_key: AdminKeyDep = None) -> dict[str, Any]:
500-
if not _task_queue:
501-
raise HTTPException(status_code=503, detail="Task queue not initialized")
502-
await _task_queue.delete_dead_letter(task_id)
502+
task_queue = await get_task_queue()
503+
await task_queue.delete_dead_letter(task_id)
503504
return {"task_id": task_id, "status": "deleted"}
504505

505506

506507
@app.delete("/tasks/dead-letter")
507508
async def purge_dead_letters(
508509
older_than_minutes: int | None = None, _admin_key: AdminKeyDep = None
509510
) -> dict[str, int]:
510-
if not _task_queue:
511-
raise HTTPException(status_code=503, detail="Task queue not initialized")
511+
task_queue = await get_task_queue()
512512
if older_than_minutes is None:
513-
deleted = await _task_queue.purge_dead_letters()
513+
deleted = await task_queue.purge_dead_letters()
514514
else:
515-
deleted = await _task_queue.purge_dead_letters_older_than(timedelta(minutes=older_than_minutes))
515+
deleted = await task_queue.purge_dead_letters_older_than(timedelta(minutes=older_than_minutes))
516516
return {"deleted": deleted}
517517

518518

519519
@app.get("/tasks/workers")
520520
async def worker_status(_admin_key: AdminKeyDep = None) -> dict[str, Any]:
521-
if not _task_queue:
522-
raise HTTPException(status_code=503, detail="Task queue not initialized")
523-
return {"workers": await _task_queue.worker_heartbeats()}
521+
task_queue = await get_task_queue()
522+
return {"workers": await task_queue.worker_heartbeats()}
524523

525524

526525
@app.get("/tasks/{task_id}")
527526
async def get_task(task_id: str, _admin_key: AdminKeyDep = None) -> dict[str, Any]:
528527
"""Get task status by ID."""
529-
if not _task_queue:
530-
raise HTTPException(status_code=503, detail="Task queue not initialized")
531-
task = await _task_queue.get_task(task_id)
528+
task_queue = await get_task_queue()
529+
task = await task_queue.get_task(task_id)
532530
if not task:
533531
raise HTTPException(status_code=404, detail="Task not found")
534532
return {
@@ -545,9 +543,8 @@ async def get_task(task_id: str, _admin_key: AdminKeyDep = None) -> dict[str, An
545543

546544
@app.post("/tasks/dead-letter/{task_id}/requeue")
547545
async def requeue_dead_letter(task_id: str, _admin_key: AdminKeyDep = None) -> dict[str, Any]:
548-
if not _task_queue:
549-
raise HTTPException(status_code=503, detail="Task queue not initialized")
550-
payload = await _task_queue.requeue_dead_letter(task_id)
546+
task_queue = await get_task_queue()
547+
payload = await task_queue.requeue_dead_letter(task_id)
551548
if payload is None:
552549
raise HTTPException(status_code=404, detail="Dead-letter task not found")
553550
return {"task_id": payload.get("task_id", task_id), "status": "requeued"}

tests/test_tasks_api.py

Lines changed: 54 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@
1313
async def _create_client():
1414
transport = ASGITransport(app=app)
1515
client = AsyncClient(transport=transport, base_url="http://test")
16-
await app_module.startup_event()
16+
await app.router.startup()
1717

1818
async def cleanup():
19+
await app.router.shutdown()
1920
await client.aclose()
20-
await app_module.shutdown_event()
2121

2222
client.cleanup = cleanup # type: ignore[attr-defined]
2323
return client
@@ -30,16 +30,14 @@ async def test_tasks_admin_endpoints_with_memory_backend(monkeypatch):
3030

3131
try:
3232
client = await _create_client()
33-
try:
34-
response = await client.get("/tasks/dead-letter")
35-
assert response.status_code == 200
36-
assert response.json()["dead_letter"] == []
37-
38-
worker_resp = await client.get("/tasks/workers")
39-
assert worker_resp.status_code == 200
40-
assert worker_resp.json() == {"workers": {}}
41-
finally:
42-
await client.cleanup()
33+
response = await client.get("/tasks/dead-letter")
34+
assert response.status_code == 200
35+
assert response.json()["dead_letter"] == []
36+
37+
worker_resp = await client.get("/tasks/workers")
38+
assert worker_resp.status_code == 200
39+
assert worker_resp.json() == {"workers": {}}
40+
await client.cleanup()
4341
finally:
4442
app.dependency_overrides.clear()
4543

@@ -116,48 +114,52 @@ async def test_tasks_admin_endpoints_surface_queue_data(monkeypatch):
116114

117115
try:
118116
client = await _create_client()
117+
original_backend = settings.task_queue_backend
118+
stub = StubQueue()
119+
original_get_queue = app_module.get_task_queue
120+
original_queue = app_module._task_queue
121+
122+
async def fake_get_queue():
123+
return stub
124+
125+
monkeypatch.setattr(app_module, "get_task_queue", fake_get_queue)
126+
settings.task_queue_backend = "memory"
119127
try:
120-
original_backend = settings.task_queue_backend
121-
original_queue = app_module._task_queue
122-
stub = StubQueue()
123-
app_module._task_queue = stub
124-
settings.task_queue_backend = "memory"
125-
try:
126-
dead_resp = await client.get(
127-
"/tasks/dead-letter", params={"limit": 5, "offset": 0, "workflow_id": "plan-123"}
128-
)
129-
assert dead_resp.status_code == 200
130-
assert dead_resp.json()["dead_letter"][0]["task_id"] == "dead-1"
131-
assert stub.limit == 5
132-
assert stub.workflow_id == "plan-123"
133-
134-
del_resp = await client.delete("/tasks/dead-letter/dead-1")
135-
assert del_resp.status_code == 200
136-
assert stub.deleted == "dead-1"
137-
138-
detail_resp = await client.get("/tasks/dead-letter/dead-1")
139-
assert detail_resp.status_code == 200
140-
assert detail_resp.json()["task_id"] == "dead-1"
141-
142-
worker_resp = await client.get("/tasks/workers")
143-
assert worker_resp.status_code == 200
144-
assert worker_resp.json() == {"workers": {"worker:1": {"status": "ok"}}}
145-
146-
requeue_resp = await client.post("/tasks/dead-letter/dead-1/requeue")
147-
assert requeue_resp.status_code == 200
148-
assert stub.requeued == "dead-1"
149-
assert requeue_resp.json()["status"] == "requeued"
150-
151-
purge_resp = await client.delete("/tasks/dead-letter")
152-
assert purge_resp.status_code == 200
153-
assert purge_resp.json()["deleted"] == 1
154-
155-
purge_resp_age = await client.delete("/tasks/dead-letter", params={"older_than_minutes": 10})
156-
assert purge_resp_age.status_code == 200
157-
finally:
158-
settings.task_queue_backend = original_backend
159-
app_module._task_queue = original_queue
128+
dead_resp = await client.get(
129+
"/tasks/dead-letter", params={"limit": 5, "offset": 0, "workflow_id": "plan-123"}
130+
)
131+
assert dead_resp.status_code == 200
132+
assert dead_resp.json()["dead_letter"][0]["task_id"] == "dead-1"
133+
assert stub.limit == 5
134+
assert stub.workflow_id == "plan-123"
135+
136+
del_resp = await client.delete("/tasks/dead-letter/dead-1")
137+
assert del_resp.status_code == 200
138+
assert stub.deleted == "dead-1"
139+
140+
detail_resp = await client.get("/tasks/dead-letter/dead-1")
141+
assert detail_resp.status_code == 200
142+
assert detail_resp.json()["task_id"] == "dead-1"
143+
144+
worker_resp = await client.get("/tasks/workers")
145+
assert worker_resp.status_code == 200
146+
assert worker_resp.json() == {"workers": {"worker:1": {"status": "ok"}}}
147+
148+
requeue_resp = await client.post("/tasks/dead-letter/dead-1/requeue")
149+
assert requeue_resp.status_code == 200
150+
assert stub.requeued == "dead-1"
151+
assert requeue_resp.json()["status"] == "requeued"
152+
153+
purge_resp = await client.delete("/tasks/dead-letter")
154+
assert purge_resp.status_code == 200
155+
assert purge_resp.json()["deleted"] == 1
156+
157+
purge_resp_age = await client.delete("/tasks/dead-letter", params={"older_than_minutes": 10})
158+
assert purge_resp_age.status_code == 200
160159
finally:
160+
settings.task_queue_backend = original_backend
161+
app_module._task_queue = original_queue
162+
monkeypatch.setattr(app_module, "get_task_queue", original_get_queue)
161163
await client.cleanup()
162164
finally:
163165
app.dependency_overrides.clear()

0 commit comments

Comments
 (0)