Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from ...notifications import project_logs
from ...products import products_web
from ...products.models import Product
from ...resource_manager.user_sessions import managed_resource
from ...resource_manager.user_sessions import PROJECT_ID_KEY, managed_resource
from ...security.decorators import permission_required
from ...socketio.server import get_socket_server
from ...users import users_service
Expand Down Expand Up @@ -91,9 +91,10 @@ async def open_project(request: web.Request) -> web.Response:
),
)

await projects_wallets_service.check_project_financial_status(
await projects_wallets_service.check_project_financial_status_and_wallet_access(
request.app,
project_id=path_params.project_id,
user_id=req_ctx.user_id,
product_name=req_ctx.product_name,
)

Expand Down Expand Up @@ -220,7 +221,17 @@ async def close_project(request: web.Request) -> web.Response:
X_SIMCORE_USER_AGENT, UNDEFINED_DEFAULT_SIMCORE_USER_AGENT_VALUE
),
)
await project_logs.unsubscribe(request.app, path_params.project_id)

with managed_resource(
req_ctx.user_id, client_session_id, request.app
) as user_session:
all_user_sessions_with_project = await user_session.find_users_of_resource(
request.app, key=PROJECT_ID_KEY, value=f"{path_params.project_id}"
)
# Only unsubscribe from logs if there is no other occurrence of the open project
if len(all_user_sessions_with_project) == 0:
await project_logs.unsubscribe(request.app, path_params.project_id)

return web.json_response(status=status.HTTP_204_NO_CONTENT)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,23 @@ async def get_project_wallet(app, project_id: ProjectID) -> WalletGet | None:
return wallet


async def check_project_financial_status(
app, *, project_id: ProjectID, product_name: ProductName
async def check_project_financial_status_and_wallet_access(
app, *, project_id: ProjectID, user_id: UserID, product_name: ProductName
):
db: ProjectDBAPI = ProjectDBAPI.get_from_app_context(app)

current_project_wallet = await db.get_project_wallet(project_uuid=project_id)
rpc_client = get_rabbitmq_rpc_client(app)

if current_project_wallet:
# ensure the wallet can be used by the user
await wallets_service.get_wallet_by_user(
app,
user_id=user_id,
wallet_id=current_project_wallet.wallet_id,
product_name=product_name,
)

# Do not allow to open project if the project connected wallet is in DEBT!
project_wallet_credits_in_debt = (
await credit_transactions.get_project_wallet_total_credits(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from ._wallets_service import (
check_project_financial_status,
check_project_financial_status_and_wallet_access,
connect_wallet_to_project,
get_project_wallet,
)

__all__: tuple[str, ...] = (
"check_project_financial_status",
"check_project_financial_status_and_wallet_access",
"connect_wallet_to_project",
"get_project_wallet",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from collections.abc import Iterator
from contextlib import contextmanager
from dataclasses import dataclass
from functools import cached_property
from typing import Final

from aiohttp import web
Expand Down Expand Up @@ -64,7 +63,7 @@ class UserSessionResourcesRegistry:
def _registry(self) -> RedisResourceRegistry:
return get_registry(self.app)

@cached_property
@property
def resource_key(self) -> UserSession:
return UserSession(
user_id=self.user_id,
Expand Down Expand Up @@ -148,8 +147,11 @@ async def find_all_resources_of_user(self, key: str) -> list[str]:
msg=f"{self.user_id=} finding all {key} from registry",
extra=get_log_record_extra(user_id=self.user_id),
):
return await get_registry(self.app).find_resources(
UserSession(user_id=self.user_id, client_session_id="*"), key
return await self._registry.find_resources(
UserSession(
user_id=self.user_id, client_session_id="*"
), # <-- this one checks for all user tabs
key,
)

async def find(self, resource_name: str) -> list[str]:
Expand All @@ -161,7 +163,10 @@ async def find(self, resource_name: str) -> list[str]:
extra=get_log_record_extra(user_id=self.user_id),
)

return await self._registry.find_resources(self.resource_key, resource_name)
return await self._registry.find_resources(
self.resource_key,
resource_name, # <-- when initialized with specific tab (client_session_id), checks only that tab otherwise all tabs
)

async def add(self, key: str, value: str) -> None:
_logger.debug(
Expand Down
Loading