Skip to content

Commit 1b3cff2

Browse files
authored
Wrap Existing FastAPI Lifespan Handlers (#191)
Make the DBOS FastAPI lifespan handler wrap whatever handler already exists (if one does) instead of overriding it.
1 parent 9998cd8 commit 1b3cff2

File tree

3 files changed

+59
-20
lines changed

3 files changed

+59
-20
lines changed

dbos/_fastapi.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import uuid
2-
from typing import Any, Callable, cast
2+
from typing import Any, Callable, MutableMapping, cast
33

44
from fastapi import FastAPI
55
from fastapi import Request as FastAPIRequest
66
from fastapi.responses import JSONResponse
7-
from starlette.types import ASGIApp, Message, Receive, Scope, Send
7+
from starlette.types import ASGIApp, Receive, Scope, Send
88

99
from . import DBOS
1010
from ._context import (
@@ -61,15 +61,16 @@ def __init__(self, app: ASGIApp, dbos: DBOS):
6161

6262
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
6363
if scope["type"] == "lifespan":
64-
while True:
65-
message = await receive()
66-
if message["type"] == "lifespan.startup":
64+
65+
async def wrapped_send(message: MutableMapping[str, Any]) -> None:
66+
if message["type"] == "lifespan.startup.complete":
6767
self.dbos._launch()
68-
await send({"type": "lifespan.startup.complete"})
69-
elif message["type"] == "lifespan.shutdown":
68+
elif message["type"] == "lifespan.shutdown.complete":
7069
self.dbos._destroy()
71-
await send({"type": "lifespan.shutdown.complete"})
72-
break
70+
await send(message)
71+
72+
# Call the original app with our wrapped functions
73+
await self.app(scope, receive, wrapped_send)
7374
else:
7475
await self.app(scope, receive, send)
7576

tests/conftest.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -118,15 +118,7 @@ def dbos_fastapi(
118118
) -> Generator[Tuple[DBOS, FastAPI], Any, None]:
119119
DBOS.destroy()
120120
app = FastAPI()
121-
122-
# ignore the on_event deprecation warnings
123-
with warnings.catch_warnings():
124-
warnings.filterwarnings(
125-
"ignore",
126-
category=DeprecationWarning,
127-
message=r"\s*on_event is deprecated, use lifespan event handlers instead\.",
128-
)
129-
dbos = DBOS(fastapi=app, config=config)
121+
dbos = DBOS(fastapi=app, config=config)
130122

131123
# This is for test convenience.
132124
# Usually fastapi itself does launch, but we are not completing the fastapi lifecycle

tests/test_fastapi.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
1+
import asyncio
12
import logging
23
import uuid
3-
from typing import Tuple
4+
from contextlib import asynccontextmanager
5+
from typing import Any, Tuple
46

7+
import httpx
58
import pytest
69
import sqlalchemy as sa
10+
import uvicorn
711
from fastapi import FastAPI
812
from fastapi.testclient import TestClient
913

1014
# Public API
11-
from dbos import DBOS
15+
from dbos import DBOS, ConfigFile
1216

1317
# Private API because this is a unit test
1418
from dbos._context import assert_current_dbos_context
@@ -159,6 +163,48 @@ def test_endpoint(var1: str, var2: str) -> dict[str, str]:
159163
assert workflow_handles[0].get_result() == ("a", wfuuid)
160164

161165

166+
@pytest.mark.asyncio
167+
async def test_custom_lifespan(
168+
config: ConfigFile, cleanup_test_databases: None
169+
) -> None:
170+
resource = None
171+
port = 8000
172+
173+
@asynccontextmanager
174+
async def lifespan(app: FastAPI) -> Any:
175+
nonlocal resource
176+
resource = 1
177+
yield
178+
resource = None
179+
180+
app = FastAPI(lifespan=lifespan)
181+
182+
DBOS.destroy()
183+
DBOS(fastapi=app, config=config)
184+
185+
@app.get("/")
186+
@DBOS.workflow()
187+
async def resource_workflow() -> Any:
188+
return {"resource": resource}
189+
190+
uvicorn_config = uvicorn.Config(
191+
app=app, host="127.0.0.1", port=port, log_level="error"
192+
)
193+
server = uvicorn.Server(config=uvicorn_config)
194+
195+
# Run server in background task
196+
server_task = asyncio.create_task(server.serve())
197+
await asyncio.sleep(0.2) # Give server time to start
198+
199+
async with httpx.AsyncClient() as client:
200+
r = await client.get(f"http://127.0.0.1:{port}")
201+
assert r.json()["resource"] == 1
202+
203+
server.should_exit = True
204+
await server_task
205+
assert resource is None
206+
207+
162208
def test_stacked_decorators_wf(dbos_fastapi: Tuple[DBOS, FastAPI]) -> None:
163209
dbos, app = dbos_fastapi
164210
client = TestClient(app)

0 commit comments

Comments
 (0)