88from datetime import datetime , timedelta
99from typing import Any
1010
11+ from contextlib import asynccontextmanager
12+
1113from fastapi import Depends , FastAPI , HTTPException , WebSocket , WebSocketDisconnect
1214from fastapi .responses import PlainTextResponse
1315from pydantic import BaseModel , model_validator
6870from agent_pm .storage .tasks import TaskStatus , get_task_queue
6971from agent_pm .tools import registry
7072
73+ @asynccontextmanager
74+ async def lifespan (_app : FastAPI ):
75+ global _task_queue
76+ _task_queue = await get_task_queue ()
77+ await _task_queue .start ()
78+ logger .info ("Agent PM service started" )
79+ try :
80+ yield
81+ finally :
82+ if _task_queue :
83+ await _task_queue .stop ()
84+ logger .info ("Agent PM service stopped" )
85+
86+
87+ lifespan_app = FastAPI (title = "Agent PM" , version = "0.1.0" , lifespan = lifespan )
88+
7189if settings .log_format == "json" :
7290 configure_structured_logging ()
7391else :
7492 configure_logging (settings .trace_dir )
7593
7694logger = logging .getLogger (__name__ )
77- app = FastAPI (title = "Agent PM" , version = "0.1.0" )
7895_jira_lock = asyncio .Lock ()
7996_task_queue = None
8097
81- plugin_registry .attach_app (app )
98+ plugin_registry .attach_app (lifespan_app )
8299
83100
84101class FollowupUpdate (BaseModel ):
85102 status : str
86103
87104
88- @app .on_event ("startup" )
89- async def startup_event ():
90- global _task_queue
91- _task_queue = await get_task_queue ()
92- await _task_queue .start ()
93- logger .info ("Agent PM service started" )
94-
95-
96- @app .on_event ("shutdown" )
97- async def shutdown_event ():
98- if _task_queue :
99- await _task_queue .stop ()
100- logger .info ("Agent PM service stopped" )
105+ app = lifespan_app
101106
102107
103108async def ensure_project_allowed (plan : TicketPlan ) -> TicketPlan :
@@ -435,10 +440,9 @@ async def metrics() -> PlainTextResponse:
435440@app .get ("/tasks" )
436441async def list_tasks (status : str | None = None , limit : int = 50 , _admin_key : AdminKeyDep = None ) -> dict [str , Any ]:
437442 """List all tasks with optional status filter."""
438- if not _task_queue :
439- raise HTTPException (status_code = 503 , detail = "Task queue not initialized" )
443+ task_queue = await get_task_queue ()
440444 task_status = TaskStatus (status ) if status else None
441- tasks = await _task_queue .list_tasks (status = task_status , limit = limit )
445+ tasks = await task_queue .list_tasks (status = task_status , limit = limit )
442446 return {
443447 "tasks" : [
444448 {
@@ -463,9 +467,8 @@ async def list_dead_letter(
463467 error_type : str | None = None ,
464468 _admin_key : AdminKeyDep = None ,
465469) -> dict [str , Any ]:
466- if not _task_queue :
467- raise HTTPException (status_code = 503 , detail = "Task queue not initialized" )
468- items , total = await _task_queue .list_dead_letters (
470+ task_queue = await get_task_queue ()
471+ items , total = await task_queue .list_dead_letters (
469472 limit = limit , offset = offset , workflow_id = workflow_id , error_type = error_type
470473 )
471474 return {
@@ -487,48 +490,43 @@ async def list_dead_letter(
487490
488491@app .get ("/tasks/dead-letter/{task_id}" )
489492async def get_dead_letter (task_id : str , _admin_key : AdminKeyDep = None ) -> dict [str , Any ]:
490- if not _task_queue :
491- raise HTTPException (status_code = 503 , detail = "Task queue not initialized" )
492- item = await _task_queue .get_dead_letter (task_id )
493+ task_queue = await get_task_queue ()
494+ item = await task_queue .get_dead_letter (task_id )
493495 if not item :
494496 raise HTTPException (status_code = 404 , detail = "Dead-letter task not found" )
495497 return item
496498
497499
498500@app .delete ("/tasks/dead-letter/{task_id}" )
499501async def delete_dead_letter (task_id : str , _admin_key : AdminKeyDep = None ) -> dict [str , Any ]:
500- if not _task_queue :
501- raise HTTPException (status_code = 503 , detail = "Task queue not initialized" )
502- await _task_queue .delete_dead_letter (task_id )
502+ task_queue = await get_task_queue ()
503+ await task_queue .delete_dead_letter (task_id )
503504 return {"task_id" : task_id , "status" : "deleted" }
504505
505506
506507@app .delete ("/tasks/dead-letter" )
507508async def purge_dead_letters (
508509 older_than_minutes : int | None = None , _admin_key : AdminKeyDep = None
509510) -> dict [str , int ]:
510- if not _task_queue :
511- raise HTTPException (status_code = 503 , detail = "Task queue not initialized" )
511+ task_queue = await get_task_queue ()
512512 if older_than_minutes is None :
513- deleted = await _task_queue .purge_dead_letters ()
513+ deleted = await task_queue .purge_dead_letters ()
514514 else :
515- deleted = await _task_queue .purge_dead_letters_older_than (timedelta (minutes = older_than_minutes ))
515+ deleted = await task_queue .purge_dead_letters_older_than (timedelta (minutes = older_than_minutes ))
516516 return {"deleted" : deleted }
517517
518518
519519@app .get ("/tasks/workers" )
520520async def worker_status (_admin_key : AdminKeyDep = None ) -> dict [str , Any ]:
521- if not _task_queue :
522- raise HTTPException (status_code = 503 , detail = "Task queue not initialized" )
523- return {"workers" : await _task_queue .worker_heartbeats ()}
521+ task_queue = await get_task_queue ()
522+ return {"workers" : await task_queue .worker_heartbeats ()}
524523
525524
526525@app .get ("/tasks/{task_id}" )
527526async def get_task (task_id : str , _admin_key : AdminKeyDep = None ) -> dict [str , Any ]:
528527 """Get task status by ID."""
529- if not _task_queue :
530- raise HTTPException (status_code = 503 , detail = "Task queue not initialized" )
531- task = await _task_queue .get_task (task_id )
528+ task_queue = await get_task_queue ()
529+ task = await task_queue .get_task (task_id )
532530 if not task :
533531 raise HTTPException (status_code = 404 , detail = "Task not found" )
534532 return {
@@ -545,9 +543,8 @@ async def get_task(task_id: str, _admin_key: AdminKeyDep = None) -> dict[str, An
545543
546544@app .post ("/tasks/dead-letter/{task_id}/requeue" )
547545async def requeue_dead_letter (task_id : str , _admin_key : AdminKeyDep = None ) -> dict [str , Any ]:
548- if not _task_queue :
549- raise HTTPException (status_code = 503 , detail = "Task queue not initialized" )
550- payload = await _task_queue .requeue_dead_letter (task_id )
546+ task_queue = await get_task_queue ()
547+ payload = await task_queue .requeue_dead_letter (task_id )
551548 if payload is None :
552549 raise HTTPException (status_code = 404 , detail = "Dead-letter task not found" )
553550 return {"task_id" : payload .get ("task_id" , task_id ), "status" : "requeued" }
0 commit comments