diff --git a/services/web/server/src/simcore_service_webserver/notifications/_rabbitmq_exclusive_queue_consumers.py b/services/web/server/src/simcore_service_webserver/notifications/_rabbitmq_exclusive_queue_consumers.py index 7c3925b8c50f..45c45bdfed1f 100644 --- a/services/web/server/src/simcore_service_webserver/notifications/_rabbitmq_exclusive_queue_consumers.py +++ b/services/web/server/src/simcore_service_webserver/notifications/_rabbitmq_exclusive_queue_consumers.py @@ -1,4 +1,6 @@ +import asyncio import logging +from collections import defaultdict from collections.abc import AsyncIterator, Generator from typing import Final @@ -36,6 +38,8 @@ _logger = logging.getLogger(__name__) _APP_RABBITMQ_CONSUMERS_KEY: Final[str] = f"{__name__}.rabbit_consumers" +APP_WALLET_SUBSCRIPTIONS_KEY: Final[str] = "wallet_subscriptions" +APP_WALLET_SUBSCRIPTION_LOCK_KEY: Final[str] = "wallet_subscription_lock" async def _convert_to_node_update_event( @@ -192,6 +196,12 @@ async def on_cleanup_ctx_rabbitmq_consumers( app[_APP_RABBITMQ_CONSUMERS_KEY] = await subscribe_to_rabbitmq( app, _EXCHANGE_TO_PARSER_CONFIG ) + + app[APP_WALLET_SUBSCRIPTIONS_KEY] = defaultdict( + int + ) # wallet_id -> subscriber count + app[APP_WALLET_SUBSCRIPTION_LOCK_KEY] = asyncio.Lock() # Ensures exclusive access + yield # cleanup diff --git a/services/web/server/src/simcore_service_webserver/notifications/wallet_osparc_credits.py b/services/web/server/src/simcore_service_webserver/notifications/wallet_osparc_credits.py index f66293bf3775..2f314bd93330 100644 --- a/services/web/server/src/simcore_service_webserver/notifications/wallet_osparc_credits.py +++ b/services/web/server/src/simcore_service_webserver/notifications/wallet_osparc_credits.py @@ -7,28 +7,37 @@ from servicelib.rabbitmq import RabbitMQClient from ..rabbitmq import get_rabbitmq_client +from ._rabbitmq_exclusive_queue_consumers import ( + APP_WALLET_SUBSCRIPTION_LOCK_KEY, + APP_WALLET_SUBSCRIPTIONS_KEY, +) _logger = logging.getLogger(__name__) -_SUBSCRIBABLE_EXCHANGES = [ - WalletCreditsMessage, -] - - async def subscribe(app: web.Application, wallet_id: WalletID) -> None: - rabbit_client: RabbitMQClient = get_rabbitmq_client(app) - for exchange in _SUBSCRIBABLE_EXCHANGES: - exchange_name = exchange.get_channel_name() - await rabbit_client.add_topics(exchange_name, topics=[f"{wallet_id}"]) + async with app[APP_WALLET_SUBSCRIPTION_LOCK_KEY]: + counter = app[APP_WALLET_SUBSCRIPTIONS_KEY][wallet_id] + app[APP_WALLET_SUBSCRIPTIONS_KEY][wallet_id] += 1 + + if counter == 0: # First subscriber + rabbit_client: RabbitMQClient = get_rabbitmq_client(app) + await rabbit_client.add_topics( + WalletCreditsMessage.get_channel_name(), topics=[f"{wallet_id}"] + ) async def unsubscribe(app: web.Application, wallet_id: WalletID) -> None: - rabbit_client: RabbitMQClient = get_rabbitmq_client(app) - for exchange in _SUBSCRIBABLE_EXCHANGES: - exchange_name = exchange.get_channel_name() - with log_catch(_logger, reraise=False): - # NOTE: in case something bad happenned with the connection to the RabbitMQ server - # such as a network disconnection. this call can fail. - await rabbit_client.remove_topics(exchange_name, topics=[f"{wallet_id}"]) + + async with app[APP_WALLET_SUBSCRIPTION_LOCK_KEY]: + counter = app[APP_WALLET_SUBSCRIPTIONS_KEY].get(wallet_id, 0) + if counter > 0: + app[APP_WALLET_SUBSCRIPTIONS_KEY][wallet_id] -= 1 + + if counter == 1: # Last subscriber + rabbit_client: RabbitMQClient = get_rabbitmq_client(app) + with log_catch(_logger, reraise=False): + await rabbit_client.remove_topics( + WalletCreditsMessage.get_channel_name(), topics=[f"{wallet_id}"] + ) diff --git a/services/web/server/src/simcore_service_webserver/projects/_projects_service.py b/services/web/server/src/simcore_service_webserver/projects/_projects_service.py index 08d99259eb2c..7797ba4bf0ef 100644 --- a/services/web/server/src/simcore_service_webserver/projects/_projects_service.py +++ b/services/web/server/src/simcore_service_webserver/projects/_projects_service.py @@ -1114,7 +1114,7 @@ async def patch_project_node( project_id=project_id, user_id=user_id, product_name=product_name, - permission="write", # NOTE: MD: before only read was sufficient, double check this + permission="write", ) # 2. If patching service key or version make sure it's valid diff --git a/services/web/server/tests/unit/isolated/notifications/test_wallet_osparc_credits.py b/services/web/server/tests/unit/isolated/notifications/test_wallet_osparc_credits.py new file mode 100644 index 000000000000..227bfbe484cf --- /dev/null +++ b/services/web/server/tests/unit/isolated/notifications/test_wallet_osparc_credits.py @@ -0,0 +1,70 @@ +# pylint: disable=redefined-outer-name +# pylint: disable=unused-argument +# pylint: disable=unused-import +import asyncio +from unittest.mock import AsyncMock, patch + +import pytest +from models_library.wallets import WalletID +from simcore_service_webserver.notifications import wallet_osparc_credits + + +@pytest.fixture +def app_with_wallets(): + app = { + "wallet_subscription_lock": asyncio.Lock(), + "wallet_subscriptions": {}, + } + return app + + +@pytest.fixture +def wallet_id(): + return WalletID(1) + + +async def test_subscribe_first_and_second(app_with_wallets, wallet_id): + app = app_with_wallets + app["wallet_subscriptions"][wallet_id] = 0 + mock_rabbit = AsyncMock() + with patch( + "simcore_service_webserver.notifications.wallet_osparc_credits.get_rabbitmq_client", + return_value=mock_rabbit, + ): + await wallet_osparc_credits.subscribe(app, wallet_id) + mock_rabbit.add_topics.assert_awaited_once() + # Second subscribe should not call add_topics again + await wallet_osparc_credits.subscribe(app, wallet_id) + assert mock_rabbit.add_topics.await_count == 1 + assert app["wallet_subscriptions"][wallet_id] == 2 + + +async def test_unsubscribe_last_and_not_last(app_with_wallets, wallet_id): + app = app_with_wallets + app["wallet_subscriptions"][wallet_id] = 2 + mock_rabbit = AsyncMock() + with patch( + "simcore_service_webserver.notifications.wallet_osparc_credits.get_rabbitmq_client", + return_value=mock_rabbit, + ): + # Not last unsubscribe + await wallet_osparc_credits.unsubscribe(app, wallet_id) + mock_rabbit.remove_topics.assert_not_awaited() + assert app["wallet_subscriptions"][wallet_id] == 1 + # Last unsubscribe + await wallet_osparc_credits.unsubscribe(app, wallet_id) + mock_rabbit.remove_topics.assert_awaited_once() + assert app["wallet_subscriptions"][wallet_id] == 0 + + +async def test_unsubscribe_when_not_subscribed(app_with_wallets, wallet_id): + app = app_with_wallets + # wallet_id not present + mock_rabbit = AsyncMock() + with patch( + "simcore_service_webserver.notifications.wallet_osparc_credits.get_rabbitmq_client", + return_value=mock_rabbit, + ): + await wallet_osparc_credits.unsubscribe(app, wallet_id) + mock_rabbit.remove_topics.assert_not_awaited() + assert app["wallet_subscriptions"].get(wallet_id, 0) == 0