Skip to content

Commit 02cb69c

Browse files
committed
✨ Refactor lifespan utility functions and enhance error handling in FastAPI
1 parent 035f475 commit 02cb69c

File tree

6 files changed

+63
-22
lines changed

6 files changed

+63
-22
lines changed

packages/service-library/src/servicelib/fastapi/lifespan_utils.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,25 +20,37 @@ class LifespanAlreadyCalledError(LifespanError):
2020
msg_template = "The lifespan '{lifespan_name}' has already been called."
2121

2222

23+
class LifespanExpectedCalledError(LifespanError):
24+
msg_template = "The lifespan '{lifespan_name}' was not called. Ensure it is properly configured and invoked."
25+
26+
2327
_CALLED_LIFESPANS_KEY: Final[str] = "_CALLED_LIFESPANS"
2428

2529

2630
def is_lifespan_called(state: State, lifespan_name: str) -> bool:
31+
assert not isinstance( # nosec
32+
state, FastAPI
33+
), "TIP: lifespan func has (app, state) positional arguments"
34+
2735
called_lifespans = state.get(_CALLED_LIFESPANS_KEY, set())
2836
return lifespan_name in called_lifespans
2937

3038

31-
def record_lifespan_called_once(state: State, lifespan_name: str) -> State:
39+
def mark_lifespace_called(state: State, lifespan_name: str) -> State:
3240
"""Validates if a lifespan has already been called and records it in the state.
3341
Raises LifespanAlreadyCalledError if the lifespan has already been called.
3442
"""
35-
assert not isinstance( # nosec
36-
state, FastAPI
37-
), "TIP: lifespan func has (app, state) positional arguments"
38-
3943
if is_lifespan_called(state, lifespan_name):
4044
raise LifespanAlreadyCalledError(lifespan_name=lifespan_name)
4145

4246
called_lifespans = state.get(_CALLED_LIFESPANS_KEY, set())
4347
called_lifespans.add(lifespan_name)
4448
return {_CALLED_LIFESPANS_KEY: called_lifespans}
49+
50+
51+
def ensure_lifespan_called(state: State, lifespan_name: str) -> None:
52+
"""Ensures that a lifespan has been called.
53+
Raises LifespanNotCalledError if the lifespan has not been called.
54+
"""
55+
if not is_lifespan_called(state, lifespan_name):
56+
raise LifespanExpectedCalledError(lifespan_name=lifespan_name)

packages/service-library/src/servicelib/fastapi/postgres_lifespan.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from sqlalchemy.ext.asyncio import AsyncEngine
1111

1212
from ..db_asyncpg_utils import create_async_engine_and_database_ready
13-
from .lifespan_utils import LifespanOnStartupError, record_lifespan_called_once
13+
from .lifespan_utils import LifespanOnStartupError, mark_lifespace_called
1414

1515
_logger = logging.getLogger(__name__)
1616

@@ -33,7 +33,7 @@ async def postgres_database_lifespan(_: FastAPI, state: State) -> AsyncIterator[
3333
with log_context(_logger, logging.INFO, f"{__name__}"):
3434

3535
# Mark lifespan as called
36-
called_state = record_lifespan_called_once(state, "postgres_database_lifespan")
36+
called_state = mark_lifespace_called(state, "postgres_database_lifespan")
3737

3838
settings = state[PostgresLifespanState.POSTGRES_SETTINGS]
3939

packages/service-library/src/servicelib/fastapi/rabbitmq_lifespan.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88
from servicelib.rabbitmq import wait_till_rabbitmq_responsive
99
from settings_library.rabbit import RabbitSettings
1010

11-
from .lifespan_utils import LifespanOnStartupError, record_lifespan_called_once
11+
from .lifespan_utils import (
12+
LifespanOnStartupError,
13+
mark_lifespace_called,
14+
)
1215

1316
_logger = logging.getLogger(__name__)
1417

@@ -33,7 +36,7 @@ async def rabbitmq_connectivity_lifespan(
3336
with log_context(_logger, logging.INFO, _lifespan_name):
3437

3538
# Check if lifespan has already been called
36-
called_state = record_lifespan_called_once(state, _lifespan_name)
39+
called_state = mark_lifespace_called(state, _lifespan_name)
3740

3841
# Validate input state
3942
try:

packages/service-library/src/servicelib/fastapi/redis_lifespan.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from settings_library.redis import RedisDatabase, RedisSettings
1111

1212
from ..redis import RedisClientSDK
13-
from .lifespan_utils import LifespanOnStartupError, record_lifespan_called_once
13+
from .lifespan_utils import LifespanOnStartupError, mark_lifespace_called
1414

1515
_logger = logging.getLogger(__name__)
1616

@@ -34,7 +34,7 @@ async def redis_database_lifespan(_: FastAPI, state: State) -> AsyncIterator[Sta
3434
with log_context(_logger, logging.INFO, f"{__name__}"):
3535

3636
# Check if lifespan has already been called
37-
called_state = record_lifespan_called_once(state, "redis_database_lifespan")
37+
called_state = mark_lifespace_called(state, "redis_database_lifespan")
3838

3939
# Validate input state
4040
try:

packages/service-library/tests/fastapi/test_lifespan_utils.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717
from pytest_simcore.helpers.logging_tools import log_context
1818
from servicelib.fastapi.lifespan_utils import (
1919
LifespanAlreadyCalledError,
20+
LifespanExpectedCalledError,
2021
LifespanOnShutdownError,
2122
LifespanOnStartupError,
22-
record_lifespan_called_once,
23+
ensure_lifespan_called,
24+
mark_lifespace_called,
2325
)
2426

2527

@@ -259,18 +261,21 @@ async def test_app_lifespan_with_error_on_shutdown(
259261

260262

261263
async def test_lifespan_called_more_than_once(is_pdb_enabled: bool):
262-
state = {}
263-
264264
app_lifespan = LifespanManager()
265265

266266
@app_lifespan.add
267267
async def _one(_, state: State) -> AsyncIterator[State]:
268-
called_state = record_lifespan_called_once(state, "test_lifespan_one")
268+
called_state = mark_lifespace_called(state, "test_lifespan_one")
269269
yield {"other": 0, **called_state}
270270

271271
@app_lifespan.add
272272
async def _two(_, state: State) -> AsyncIterator[State]:
273-
called_state = record_lifespan_called_once(state, "test_lifespan_two")
273+
ensure_lifespan_called(state, "test_lifespan_one")
274+
275+
with pytest.raises(LifespanExpectedCalledError):
276+
ensure_lifespan_called(state, "test_lifespan_three")
277+
278+
called_state = mark_lifespace_called(state, "test_lifespan_two")
274279
yield {"something": 0, **called_state}
275280

276281
app_lifespan.add(_one) # added "by mistake"

packages/service-library/tests/fastapi/test_rabbitmq_lifespan.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,14 @@ def mock_rabbitmq_connection(mocker: MockerFixture) -> MockType:
3737

3838
@pytest.fixture
3939
def mock_rabbitmq_rpc_client_class(mocker: MockerFixture) -> MockType:
40-
return mocker.patch.object(
41-
servicelib.rabbitmq._client_rpc,
42-
"RabbitMQRPCClient",
43-
return_value=mocker.AsyncMock(),
40+
mock_rpc_client_instance = mocker.AsyncMock()
41+
mocker.patch.object(
42+
servicelib.rabbitmq._client_rpc.RabbitMQRPCClient,
43+
"create",
44+
return_value=mock_rpc_client_instance,
4445
)
46+
mock_rpc_client_instance.close = mocker.AsyncMock()
47+
return mock_rpc_client_instance
4548

4649

4750
@pytest.fixture
@@ -64,25 +67,40 @@ class AppSettings(BaseApplicationSettings):
6467
..., json_schema_extra={"auto_default_from_env": True}
6568
)
6669

70+
# setup settings
6771
async def my_app_settings(app: FastAPI) -> AsyncIterator[State]:
6872
app.state.settings = AppSettings.create_from_envs()
6973

7074
yield RabbitMQLifespanState(
7175
RABBIT_SETTINGS=app.state.settings.RABBITMQ,
7276
).model_dump()
7377

78+
# setup rpc-server using rabbitmq_rpc_client_context (yes, a "rpc_server" is built with an RabbitMQRpcClient)
7479
async def my_app_rpc_server(app: FastAPI, state: State) -> AsyncIterator[State]:
80+
assert "RABBIT_CONNECTIVITY_LIFESPAN_NAME" in state
7581

7682
async with rabbitmq_rpc_client_context(
7783
"rpc_server", app.state.settings.RABBITMQ
7884
) as rpc_server:
7985
app.state.rpc_server = rpc_server
8086
yield {}
8187

88+
# setup rpc-client using rabbitmq_rpc_client_context
89+
async def my_app_rpc_client(app: FastAPI, state: State) -> AsyncIterator[State]:
90+
91+
assert "RABBIT_CONNECTIVITY_LIFESPAN_NAME" in state
92+
93+
async with rabbitmq_rpc_client_context(
94+
"rpc_client", app.state.settings.RABBITMQ
95+
) as rpc_client:
96+
app.state.rpc_client = rpc_client
97+
yield {}
98+
8299
app_lifespan = LifespanManager()
83100
app_lifespan.add(my_app_settings)
84101
app_lifespan.add(rabbitmq_connectivity_lifespan)
85102
app_lifespan.add(my_app_rpc_server)
103+
app_lifespan.add(my_app_rpc_client)
86104

87105
assert not mock_rabbitmq_connection.called
88106
assert not mock_rabbitmq_rpc_client_class.called
@@ -103,17 +121,20 @@ async def test_lifespan_rabbitmq_in_an_app(
103121
app,
104122
startup_timeout=None if is_pdb_enabled else 10,
105123
shutdown_timeout=None if is_pdb_enabled else 10,
106-
) as asgi_manager:
124+
):
125+
107126
# Verify that RabbitMQ responsiveness was checked
108127
mock_rabbitmq_connection.assert_called_once_with(
109128
app.state.settings.RABBITMQ.dsn
110129
)
111130

112131
# Verify that RabbitMQ settings are in the lifespan manager state
113132
assert app.state.settings.RABBITMQ
133+
assert app.state.rpc_server
134+
assert app.state.rpc_client
114135

115136
# No explicit shutdown logic for RabbitMQ in this case
116-
assert mock_rabbitmq_rpc_client_class.called
137+
assert mock_rabbitmq_rpc_client_class.close.called
117138

118139

119140
async def test_lifespan_rabbitmq_with_invalid_settings(

0 commit comments

Comments
 (0)