Skip to content
25 changes: 19 additions & 6 deletions packages/service-library/src/servicelib/rabbitmq/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from uuid import uuid4

import aio_pika
from common_library.logging.logging_errors import create_troubleshooting_log_kwargs
from pydantic import NonNegativeInt

from ..logging_utils import log_catch, log_context
Expand Down Expand Up @@ -52,7 +53,7 @@ def _get_x_death_count(message: aio_pika.abc.AbstractIncomingMessage) -> int:
return count


async def _safe_nack(
async def _nack_message(
message_handler: MessageHandler,
max_retries_upon_error: int,
message: aio_pika.abc.AbstractIncomingMessage,
Expand All @@ -72,7 +73,7 @@ async def _safe_nack(
# NOTE: puts message to the Dead Letter Exchange
await message.nack(requeue=False)
else:
_logger.exception(
_logger.error(
"Handler '%s' is giving up on message '%s' with body '%s'",
message_handler,
message,
Expand All @@ -93,13 +94,24 @@ async def _on_message(
msg=f"Received message from {message.exchange=}, {message.routing_key=}",
):
if not await message_handler(message.body):
await _safe_nack(message_handler, max_retries_upon_error, message)
except Exception: # pylint: disable=broad-exception-caught
await _nack_message(
message_handler, max_retries_upon_error, message
)
except Exception as exc:
_logger.exception(
"Exception raised when handling message. TIP: review your code"
**create_troubleshooting_log_kwargs(
"Unhandled exception raised in message handler or when nacking message",
error=exc,
error_context={
"message_id": message.message_id,
"message_body": message.body,
"message_handler": message_handler.__name__,
},
tip="This could indicate an error in the message handler, please check the message handler code",
)
)
with log_catch(_logger, reraise=False):
await _safe_nack(message_handler, max_retries_upon_error, message)
await _nack_message(message_handler, max_retries_upon_error, message)


@dataclass
Expand Down Expand Up @@ -144,6 +156,7 @@ async def close(self) -> None:
async def _get_channel(self) -> aio_pika.abc.AbstractChannel:
assert self._connection_pool # nosec
async with self._connection_pool.acquire() as connection:
assert isinstance(connection, aio_pika.RobustConnection) # nosec
channel: aio_pika.abc.AbstractChannel = await connection.channel()
channel.close_callbacks.add(self._channel_close_callback)
return channel
Expand Down
47 changes: 32 additions & 15 deletions packages/service-library/src/servicelib/rabbitmq/_client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import aio_pika
import aiormq
from common_library.logging.logging_errors import create_troubleshooting_log_kwargs
from settings_library.rabbit import RabbitSettings

from ..logging_utils import log_catch
Expand All @@ -29,33 +30,49 @@ def _connection_close_callback(
exc: BaseException | None,
) -> None:
if exc:
if isinstance(exc, asyncio.CancelledError):
_logger.info("Rabbit connection cancelled")
elif isinstance(exc, aiormq.exceptions.ConnectionClosed):
_logger.info("Rabbit connection closed: %s", exc)
if isinstance(
exc, asyncio.CancelledError | aiormq.exceptions.ConnectionClosed
):
_logger.info(
**create_troubleshooting_log_kwargs(
"RabbitMQ connection closed",
error=exc,
error_context={"sender": sender},
)
)
else:
_logger.error(
"Rabbit connection closed with exception from %s:%s",
type(exc),
exc,
**create_troubleshooting_log_kwargs(
"RabbitMQ connection closed with unexpected error",
error=exc,
error_context={"sender": sender},
)
)
self._healthy_state = False

def _channel_close_callback(
self,
sender: Any, # pylint: disable=unused-argument # noqa: ARG002
sender: Any,
exc: BaseException | None,
) -> None:
if exc:
if isinstance(exc, asyncio.CancelledError):
_logger.info("Rabbit channel cancelled")
elif isinstance(exc, aiormq.exceptions.ChannelClosed):
_logger.info("Rabbit channel closed")
if isinstance(
exc, asyncio.CancelledError | aiormq.exceptions.ChannelClosed
):
_logger.info(
**create_troubleshooting_log_kwargs(
"RabbitMQ channel closed",
error=exc,
error_context={"sender": sender},
)
)
else:
_logger.error(
"Rabbit channel closed with exception from %s:%s",
type(exc),
exc,
**create_troubleshooting_log_kwargs(
"RabbitMQ channel closed with unexpected error",
error=exc,
error_context={"sender": sender},
)
)
self._healthy_state = False

Expand Down
3 changes: 3 additions & 0 deletions packages/service-library/src/servicelib/redis/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

from pydantic import NonNegativeInt

DEFAULT_EXPECTED_LOCK_OVERALL_TIME: Final[datetime.timedelta] = datetime.timedelta(
seconds=30
)
DEFAULT_LOCK_TTL: Final[datetime.timedelta] = datetime.timedelta(seconds=10)
DEFAULT_SOCKET_TIMEOUT: Final[datetime.timedelta] = datetime.timedelta(seconds=30)

Expand Down
16 changes: 15 additions & 1 deletion packages/service-library/src/servicelib/redis/_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from ..background_task import periodic
from ._client import RedisClientSDK
from ._constants import DEFAULT_LOCK_TTL
from ._constants import DEFAULT_EXPECTED_LOCK_OVERALL_TIME, DEFAULT_LOCK_TTL
from ._errors import CouldNotAcquireLockError, LockLostError
from ._utils import auto_extend_lock

Expand Down Expand Up @@ -95,6 +95,7 @@ async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
):
raise CouldNotAcquireLockError(lock=lock)

lock_acquisition_time = arrow.utcnow()
try:
async with asyncio.TaskGroup() as tg:
started_event = asyncio.Event()
Expand Down Expand Up @@ -157,6 +158,19 @@ async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
"Look for synchronous code that prevents refreshing the lock or asyncio loop overload.",
)
)
finally:
lock_release_time = arrow.utcnow()
locking_time = lock_release_time - lock_acquisition_time
if locking_time > DEFAULT_EXPECTED_LOCK_OVERALL_TIME:
_logger.warning(
"Lock `%s'for %s was held for %s which is longer than the expected (%s). "
"TIP: consider reducing the locking time by optimizing the code inside "
"the critical section or increasing the default locking time",
redis_lock_key,
coro.__name__,
locking_time,
DEFAULT_EXPECTED_LOCK_OVERALL_TIME,
)

return _wrapper

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from typing import Any, ParamSpec, TypeVar

import arrow
from common_library.async_tools import cancel_wait_task
from common_library.logging.logging_errors import create_troubleshooting_log_kwargs

from ..background_task import periodic
from ._client import RedisClientSDK
from ._constants import (
DEFAULT_EXPECTED_LOCK_OVERALL_TIME,
DEFAULT_SEMAPHORE_TTL,
DEFAULT_SOCKET_TIMEOUT,
)
Expand Down Expand Up @@ -42,6 +44,7 @@ async def _managed_semaphore_execution(
if not await semaphore.acquire():
raise SemaphoreAcquisitionError(name=semaphore_key, capacity=semaphore.capacity)

lock_acquisition_time = arrow.utcnow()
try:
# NOTE: Use TaskGroup for proper exception propagation, this ensures that in case of error the context manager will be properly exited
# and the semaphore released.
Expand Down Expand Up @@ -100,6 +103,18 @@ async def _periodic_renewer() -> None:
"Look for synchronous code that prevents refreshing the semaphore or asyncio loop overload.",
)
)
finally:
lock_release_time = arrow.utcnow()
locking_time = lock_release_time - lock_acquisition_time
if locking_time > DEFAULT_EXPECTED_LOCK_OVERALL_TIME:
_logger.warning(
"Semaphore '%s' was held for %s which is longer than expected (%s). "
"TIP: consider reducing the locking time by optimizing the code inside "
"the critical section or increasing the default locking time",
semaphore_key,
locking_time,
DEFAULT_EXPECTED_LOCK_OVERALL_TIME,
)


def _create_semaphore(
Expand Down
8 changes: 4 additions & 4 deletions packages/service-library/tests/rabbitmq/test_rabbitmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ async def _always_returning_fail(_: Any) -> bool:


@pytest.mark.parametrize("topics", _TOPICS)
@pytest.mark.no_cleanup_check_rabbitmq_server_has_no_errors()
@pytest.mark.no_cleanup_check_rabbitmq_server_has_no_errors
async def test_publish_with_no_registered_subscriber(
on_message_spy: mock.Mock,
create_rabbitmq_client: Callable[[str], RabbitMQClient],
Expand Down Expand Up @@ -476,7 +476,7 @@ def _raise_once_then_true(*args, **kwargs):

@pytest.fixture
async def ensure_queue_deletion(
create_rabbitmq_client: Callable[[str], RabbitMQClient]
create_rabbitmq_client: Callable[[str], RabbitMQClient],
) -> AsyncIterator[Callable[[QueueName], None]]:
created_queues = set()

Expand Down Expand Up @@ -723,7 +723,7 @@ async def test_rabbit_adding_topics_to_a_fanout_exchange(
await _assert_message_received(mocked_message_parser, 0)


@pytest.mark.no_cleanup_check_rabbitmq_server_has_no_errors()
@pytest.mark.no_cleanup_check_rabbitmq_server_has_no_errors
async def test_rabbit_not_using_the_same_exchange_type_raises(
create_rabbitmq_client: Callable[[str], RabbitMQClient],
random_exchange_name: Callable[[], str],
Expand All @@ -738,7 +738,7 @@ async def test_rabbit_not_using_the_same_exchange_type_raises(
await client.subscribe(exchange_name, mocked_message_parser, topics=[])


@pytest.mark.no_cleanup_check_rabbitmq_server_has_no_errors()
@pytest.mark.no_cleanup_check_rabbitmq_server_has_no_errors
async def test_unsubscribe_consumer(
create_rabbitmq_client: Callable[[str], RabbitMQClient],
random_exchange_name: Callable[[], str],
Expand Down
Loading