Skip to content

Commit 035f475

Browse files
committed
✨ Add RabbitMQ lifespan management and validation in FastAPI
1 parent 89669b0 commit 035f475

File tree

4 files changed

+210
-2
lines changed

4 files changed

+210
-2
lines changed
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import logging
2+
from collections.abc import AsyncIterator
3+
4+
from fastapi import FastAPI
5+
from fastapi_lifespan_manager import State
6+
from pydantic import BaseModel, ValidationError
7+
from servicelib.logging_utils import log_context
8+
from servicelib.rabbitmq import wait_till_rabbitmq_responsive
9+
from settings_library.rabbit import RabbitSettings
10+
11+
from .lifespan_utils import LifespanOnStartupError, record_lifespan_called_once
12+
13+
_logger = logging.getLogger(__name__)
14+
15+
16+
class RabbitMQConfigurationError(LifespanOnStartupError):
17+
msg_template = "Invalid RabbitMQ config on startup : {validation_error}"
18+
19+
20+
class RabbitMQLifespanState(BaseModel):
21+
RABBIT_SETTINGS: RabbitSettings
22+
23+
24+
async def rabbitmq_connectivity_lifespan(
25+
_: FastAPI, state: State
26+
) -> AsyncIterator[State]:
27+
"""Ensures RabbitMQ connectivity during lifespan.
28+
29+
For creating clients, use additional lifespans like rabbitmq_rpc_client_context.
30+
"""
31+
_lifespan_name = f"{__name__}.{rabbitmq_connectivity_lifespan.__name__}"
32+
33+
with log_context(_logger, logging.INFO, _lifespan_name):
34+
35+
# Check if lifespan has already been called
36+
called_state = record_lifespan_called_once(state, _lifespan_name)
37+
38+
# Validate input state
39+
try:
40+
rabbit_state = RabbitMQLifespanState.model_validate(state)
41+
rabbit_dsn_with_secrets = rabbit_state.RABBIT_SETTINGS.dsn
42+
except ValidationError as exc:
43+
raise RabbitMQConfigurationError(validation_error=exc, state=state) from exc
44+
45+
# Wait for RabbitMQ to be responsive
46+
await wait_till_rabbitmq_responsive(rabbit_dsn_with_secrets)
47+
48+
yield {"RABBIT_CONNECTIVITY_LIFESPAN_NAME": _lifespan_name, **called_state}

packages/service-library/src/servicelib/rabbitmq/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from models_library.rabbitmq_basic_types import RPCNamespace
22

33
from ._client import RabbitMQClient
4-
from ._client_rpc import RabbitMQRPCClient
4+
from ._client_rpc import RabbitMQRPCClient, rabbitmq_rpc_client_context
55
from ._constants import BIND_TO_ALL_TOPICS, RPC_REQUEST_DEFAULT_TIMEOUT_S
66
from ._errors import (
77
RemoteMethodNotRegisteredError,
@@ -28,6 +28,7 @@
2828
"RabbitMQRPCClient",
2929
"RemoteMethodNotRegisteredError",
3030
"is_rabbitmq_responsive",
31+
"rabbitmq_rpc_client_context",
3132
"wait_till_rabbitmq_responsive",
3233
)
3334

packages/service-library/src/servicelib/rabbitmq/_client_rpc.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import asyncio
22
import functools
33
import logging
4-
from collections.abc import Callable
4+
from collections.abc import AsyncIterator, Callable
5+
from contextlib import asynccontextmanager
56
from dataclasses import dataclass
67
from typing import Any
78

@@ -156,3 +157,19 @@ async def unregister_handler(self, handler: Callable[..., Any]) -> None:
156157
raise RPCNotInitializedError
157158

158159
await self._rpc.unregister(handler)
160+
161+
162+
@asynccontextmanager
163+
async def rabbitmq_rpc_client_context(
164+
rpc_client_name: str, settings: RabbitSettings, **kwargs
165+
) -> AsyncIterator[RabbitMQRPCClient]:
166+
"""
167+
Adapter to create and close a RabbitMQRPCClient using an async context manager.
168+
"""
169+
rpc_client = await RabbitMQRPCClient.create(
170+
client_name=rpc_client_name, settings=settings, **kwargs
171+
)
172+
try:
173+
yield rpc_client
174+
finally:
175+
await rpc_client.close()
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
# pylint: disable=protected-access
2+
# pylint: disable=redefined-outer-name
3+
# pylint: disable=too-many-arguments
4+
# pylint: disable=unused-argument
5+
# pylint: disable=unused-variable
6+
7+
from collections.abc import AsyncIterator
8+
9+
import pytest
10+
import servicelib.fastapi.rabbitmq_lifespan
11+
import servicelib.rabbitmq
12+
from asgi_lifespan import LifespanManager as ASGILifespanManager
13+
from fastapi import FastAPI
14+
from fastapi_lifespan_manager import LifespanManager, State
15+
from pydantic import Field
16+
from pytest_mock import MockerFixture, MockType
17+
from pytest_simcore.helpers.monkeypatch_envs import setenvs_from_dict
18+
from pytest_simcore.helpers.typing_env import EnvVarsDict
19+
from servicelib.fastapi.rabbitmq_lifespan import (
20+
RabbitMQConfigurationError,
21+
RabbitMQLifespanState,
22+
rabbitmq_connectivity_lifespan,
23+
)
24+
from servicelib.rabbitmq import rabbitmq_rpc_client_context
25+
from settings_library.application import BaseApplicationSettings
26+
from settings_library.rabbit import RabbitSettings
27+
28+
29+
@pytest.fixture
30+
def mock_rabbitmq_connection(mocker: MockerFixture) -> MockType:
31+
return mocker.patch.object(
32+
servicelib.fastapi.rabbitmq_lifespan,
33+
"wait_till_rabbitmq_responsive",
34+
return_value=mocker.AsyncMock(),
35+
)
36+
37+
38+
@pytest.fixture
39+
def mock_rabbitmq_rpc_client_class(mocker: MockerFixture) -> MockType:
40+
return mocker.patch.object(
41+
servicelib.rabbitmq._client_rpc,
42+
"RabbitMQRPCClient",
43+
return_value=mocker.AsyncMock(),
44+
)
45+
46+
47+
@pytest.fixture
48+
def app_environment(monkeypatch: pytest.MonkeyPatch) -> EnvVarsDict:
49+
return setenvs_from_dict(
50+
monkeypatch, RabbitSettings.model_json_schema()["examples"][0]
51+
)
52+
53+
54+
@pytest.fixture
55+
def app_lifespan(
56+
app_environment: EnvVarsDict,
57+
mock_rabbitmq_connection: MockType,
58+
mock_rabbitmq_rpc_client_class: MockType,
59+
) -> LifespanManager:
60+
assert app_environment
61+
62+
class AppSettings(BaseApplicationSettings):
63+
RABBITMQ: RabbitSettings = Field(
64+
..., json_schema_extra={"auto_default_from_env": True}
65+
)
66+
67+
async def my_app_settings(app: FastAPI) -> AsyncIterator[State]:
68+
app.state.settings = AppSettings.create_from_envs()
69+
70+
yield RabbitMQLifespanState(
71+
RABBIT_SETTINGS=app.state.settings.RABBITMQ,
72+
).model_dump()
73+
74+
async def my_app_rpc_server(app: FastAPI, state: State) -> AsyncIterator[State]:
75+
76+
async with rabbitmq_rpc_client_context(
77+
"rpc_server", app.state.settings.RABBITMQ
78+
) as rpc_server:
79+
app.state.rpc_server = rpc_server
80+
yield {}
81+
82+
app_lifespan = LifespanManager()
83+
app_lifespan.add(my_app_settings)
84+
app_lifespan.add(rabbitmq_connectivity_lifespan)
85+
app_lifespan.add(my_app_rpc_server)
86+
87+
assert not mock_rabbitmq_connection.called
88+
assert not mock_rabbitmq_rpc_client_class.called
89+
90+
return app_lifespan
91+
92+
93+
async def test_lifespan_rabbitmq_in_an_app(
94+
is_pdb_enabled: bool,
95+
app_environment: EnvVarsDict,
96+
mock_rabbitmq_connection: MockType,
97+
mock_rabbitmq_rpc_client_class: MockType,
98+
app_lifespan: LifespanManager,
99+
):
100+
app = FastAPI(lifespan=app_lifespan)
101+
102+
async with ASGILifespanManager(
103+
app,
104+
startup_timeout=None if is_pdb_enabled else 10,
105+
shutdown_timeout=None if is_pdb_enabled else 10,
106+
) as asgi_manager:
107+
# Verify that RabbitMQ responsiveness was checked
108+
mock_rabbitmq_connection.assert_called_once_with(
109+
app.state.settings.RABBITMQ.dsn
110+
)
111+
112+
# Verify that RabbitMQ settings are in the lifespan manager state
113+
assert app.state.settings.RABBITMQ
114+
115+
# No explicit shutdown logic for RabbitMQ in this case
116+
assert mock_rabbitmq_rpc_client_class.called
117+
118+
119+
async def test_lifespan_rabbitmq_with_invalid_settings(
120+
is_pdb_enabled: bool,
121+
):
122+
async def my_app_settings(app: FastAPI) -> AsyncIterator[State]:
123+
yield {"RABBIT_SETTINGS": None}
124+
125+
app_lifespan = LifespanManager()
126+
app_lifespan.add(my_app_settings)
127+
app_lifespan.add(rabbitmq_connectivity_lifespan)
128+
129+
app = FastAPI(lifespan=app_lifespan)
130+
131+
with pytest.raises(RabbitMQConfigurationError, match="Invalid RabbitMQ") as excinfo:
132+
async with ASGILifespanManager(
133+
app,
134+
startup_timeout=None if is_pdb_enabled else 10,
135+
shutdown_timeout=None if is_pdb_enabled else 10,
136+
):
137+
...
138+
139+
exception = excinfo.value
140+
assert isinstance(exception, RabbitMQConfigurationError)
141+
assert exception.validation_error
142+
assert exception.state["RABBIT_SETTINGS"] is None

0 commit comments

Comments
 (0)