88"""
99
1010import argparse
11+ import asyncio
1112import logging
1213from contextlib import asynccontextmanager
1314from dataclasses import dataclass
@@ -44,6 +45,7 @@ class Args:
4445 check_collision : bool = False
4546
4647 autostart : bool = True
48+ timeout_health_check : float | None = None
4749
4850 wake_up_on_start : bool = True
4951 goto_sleep_on_stop : bool = True
@@ -54,7 +56,7 @@ class Args:
5456 localhost_only : bool | None = None
5557
5658
57- def create_app (args : Args ) -> FastAPI :
59+ def create_app (args : Args , health_check_event : asyncio . Event | None = None ) -> FastAPI :
5860 """Create and configure the FastAPI application."""
5961 localhost_only = (
6062 args .localhost_only
@@ -108,6 +110,14 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
108110
109111 app .include_router (router )
110112
113+ if health_check_event is not None :
114+
115+ @app .post ("/health-check" )
116+ async def health_check () -> dict [str , str ]:
117+ """Health check endpoint to reset the health check timer."""
118+ health_check_event .set ()
119+ return {"status" : "ok" }
120+
111121 app .add_middleware (
112122 CORSMiddleware ,
113123 allow_origins = ["*" ], # or restrict to your HF domain
@@ -142,8 +152,36 @@ def run_app(args: Args) -> None:
142152 """Run the FastAPI app with Uvicorn."""
143153 logging .basicConfig (level = logging .INFO )
144154
145- app = create_app (args )
146- uvicorn .run (app , host = args .fastapi_host , port = args .fastapi_port )
155+ health_check_event = asyncio .Event ()
156+ app = create_app (args , health_check_event )
157+
158+ config = uvicorn .Config (app , host = args .fastapi_host , port = args .fastapi_port )
159+ server = uvicorn .Server (config )
160+
161+ async def health_check_timeout (timeout_seconds : float ) -> None :
162+ while True :
163+ try :
164+ await asyncio .wait_for (
165+ health_check_event .wait (),
166+ timeout = timeout_seconds ,
167+ )
168+ health_check_event .clear ()
169+ except asyncio .TimeoutError :
170+ logging .warning ("Health check timeout reached, stopping app." )
171+ server .should_exit = True
172+ break
173+
174+ loop = asyncio .get_event_loop ()
175+ if args .timeout_health_check is not None :
176+ loop .create_task (health_check_timeout (args .timeout_health_check ))
177+
178+ try :
179+ loop .run_until_complete (server .serve ())
180+ except KeyboardInterrupt :
181+ logging .info ("Received Ctrl-C, shutting down gracefully." )
182+ finally :
183+ # Optional: additional cleanup here
184+ pass
147185
148186
149187def main () -> None :
@@ -198,6 +236,12 @@ def main() -> None:
198236 dest = "autostart" ,
199237 help = "Do not automatically start the daemon on launch (default: False)." ,
200238 )
239+ parser .add_argument (
240+ "--timeout-health-check" ,
241+ type = float ,
242+ default = None ,
243+ help = "Set the health check timeout in seconds (default: None)." ,
244+ )
201245 parser .add_argument (
202246 "--wake-up-on-start" ,
203247 action = "store_true" ,
0 commit comments