Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,14 @@
from servicelib.utils import logged_gather

from ...notifications import project_logs
from ...resource_manager.user_sessions import PROJECT_ID_KEY, managed_resource
from .._projects_service import retrieve_and_notify_project_locked_state
from ...resource_manager.user_sessions import (
PROJECT_ID_KEY,
managed_resource,
)
from .._projects_service import (
conditionally_unsubscribe_from_project_logs,
retrieve_and_notify_project_locked_state,
)

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -67,15 +73,6 @@ async def _on_user_disconnected(

assert len(projects) <= 1, "At the moment, at most one project per session" # nosec

with log_context(
_logger,
logging.DEBUG,
msg=f"user disconnects and unsubscribes from following {projects=}",
):
await logged_gather(
*[project_logs.unsubscribe(app, ProjectID(prj)) for prj in projects]
)

await logged_gather(
*[
retrieve_and_notify_project_locked_state(
Expand All @@ -85,6 +82,11 @@ async def _on_user_disconnected(
]
)

for _project_id in projects: # At the moment, only 1 is expected
await conditionally_unsubscribe_from_project_logs(
app, ProjectID(_project_id), user_id
)


def setup_project_observer_events(app: web.Application) -> None:
setup_observer_registry(app)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from ...users import users_service
from ...utils_aiohttp import envelope_json_response, get_api_base_url
from .. import _projects_service, projects_wallets_service
from .._projects_service import conditionally_unsubscribe_from_project_logs
from ..exceptions import ProjectStartsTooManyDynamicNodesError
from ._rest_exceptions import handle_plugin_requests_exceptions
from ._rest_schemas import AuthenticatedRequestContext, ProjectPathParams
Expand Down Expand Up @@ -91,9 +92,10 @@ async def open_project(request: web.Request) -> web.Response:
),
)

await projects_wallets_service.check_project_financial_status(
await projects_wallets_service.check_project_financial_status_and_wallet_access(
request.app,
project_id=path_params.project_id,
user_id=req_ctx.user_id,
product_name=req_ctx.product_name,
)

Expand Down Expand Up @@ -220,7 +222,11 @@ async def close_project(request: web.Request) -> web.Response:
X_SIMCORE_USER_AGENT, UNDEFINED_DEFAULT_SIMCORE_USER_AGENT_VALUE
),
)
await project_logs.unsubscribe(request.app, path_params.project_id)

await conditionally_unsubscribe_from_project_logs(
request.app, path_params.project_id, req_ctx.user_id
)

return web.json_response(status=status.HTTP_204_NO_CONTENT)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,15 @@
from ..director_v2 import director_v2_service
from ..dynamic_scheduler import api as dynamic_scheduler_service
from ..models import ClientSessionID
from ..notifications import project_logs
from ..products import products_web
from ..rabbitmq import get_rabbitmq_rpc_client
from ..redis import get_redis_lock_manager_client_sdk
from ..resource_manager.models import UserSession
from ..resource_manager.registry import get_registry
from ..resource_manager.user_sessions import (
PROJECT_ID_KEY,
SOCKET_ID_FIELDNAME,
managed_resource,
)
from ..resource_usage import service as rut_api
Expand Down Expand Up @@ -184,6 +187,43 @@
_logger = logging.getLogger(__name__)


async def conditionally_unsubscribe_from_project_logs(
app: web.Application, project_id: ProjectID, user_id: UserID
) -> None:
"""
Unsubscribes from project logs only if no active socket connections remain for the project.

This function checks for actual socket connections rather than just user sessions,
ensuring logs are only unsubscribed when truly no users are connected.

Args:
app: The web application instance
project_id: The project ID to check
user_id: Optional user ID to use for the resource session (defaults to 0 if None)
"""
redis_resource_registry = get_registry(app)
with managed_resource(user_id, None, app) as user_session:
all_user_sessions_with_project = await user_session.find_users_of_resource(
app, key=PROJECT_ID_KEY, value=f"{project_id}"
)

# Check for each user session if it has an active socket_id
actually_used_sockets_on_project = 0
for user_session_key in all_user_sessions_with_project:
output = await redis_resource_registry.find_resources(
key=user_session_key, resource_name=SOCKET_ID_FIELDNAME
)
if output:
actually_used_sockets_on_project += 1

# Only unsubscribe from logs if there are no active socket connections to the project.
# NOTE: With multiple webserver replicas, this ensures we don't unsubscribe until
# the last socket is closed, though another replica may still maintain an active
# subscription even if no users are connected to it.
if actually_used_sockets_on_project == 0:
await project_logs.unsubscribe(app, project_id)


async def patch_project_and_notify_users(
app: web.Application,
*,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,23 @@ async def get_project_wallet(app, project_id: ProjectID) -> WalletGet | None:
return wallet


async def check_project_financial_status(
app, *, project_id: ProjectID, product_name: ProductName
async def check_project_financial_status_and_wallet_access(
app, *, project_id: ProjectID, user_id: UserID, product_name: ProductName
):
db: ProjectDBAPI = ProjectDBAPI.get_from_app_context(app)

current_project_wallet = await db.get_project_wallet(project_uuid=project_id)
rpc_client = get_rabbitmq_rpc_client(app)

if current_project_wallet:
# ensure the wallet can be used by the user
await wallets_service.get_wallet_by_user(
app,
user_id=user_id,
wallet_id=current_project_wallet.wallet_id,
product_name=product_name,
)

# Do not allow to open project if the project connected wallet is in DEBT!
project_wallet_credits_in_debt = (
await credit_transactions.get_project_wallet_total_credits(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from ._wallets_service import (
check_project_financial_status,
check_project_financial_status_and_wallet_access,
connect_wallet_to_project,
get_project_wallet,
)

__all__: tuple[str, ...] = (
"check_project_financial_status",
"check_project_financial_status_and_wallet_access",
"connect_wallet_to_project",
"get_project_wallet",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from collections.abc import Iterator
from contextlib import contextmanager
from dataclasses import dataclass
from functools import cached_property
from typing import Final

from aiohttp import web
Expand All @@ -20,10 +19,10 @@
_logger = logging.getLogger(__name__)


_SOCKET_ID_FIELDNAME: Final[str] = "socket_id"
SOCKET_ID_FIELDNAME: Final[str] = "socket_id"
PROJECT_ID_KEY: Final[str] = "project_id"

assert _SOCKET_ID_FIELDNAME in ResourcesDict.__annotations__ # nosec
assert SOCKET_ID_FIELDNAME in ResourcesDict.__annotations__ # nosec
assert PROJECT_ID_KEY in ResourcesDict.__annotations__ # nosec


Expand Down Expand Up @@ -64,7 +63,7 @@ class UserSessionResourcesRegistry:
def _registry(self) -> RedisResourceRegistry:
return get_registry(self.app)

@cached_property
@property
def resource_key(self) -> UserSession:
return UserSession(
user_id=self.user_id,
Expand All @@ -81,7 +80,7 @@ async def set_socket_id(self, socket_id: str) -> None:
)

await self._registry.set_resource(
self.resource_key, (_SOCKET_ID_FIELDNAME, socket_id)
self.resource_key, (SOCKET_ID_FIELDNAME, socket_id)
)
# NOTE: hearthbeat is not emulated in tests, make sure that with very small GC intervals
# the resources do not expire; this value is usually in the order of minutes
Expand Down Expand Up @@ -113,7 +112,7 @@ async def remove_socket_id(self) -> None:
extra=get_log_record_extra(user_id=self.user_id),
)

await self._registry.remove_resource(self.resource_key, _SOCKET_ID_FIELDNAME)
await self._registry.remove_resource(self.resource_key, SOCKET_ID_FIELDNAME)
await self._registry.set_key_alive(
self.resource_key,
expiration_time=_get_service_deletion_timeout(self.app),
Expand All @@ -132,13 +131,13 @@ async def find_socket_ids(self) -> list[str]:
"user %s/tab %s finding %s from registry...",
self.user_id,
self.client_session_id,
_SOCKET_ID_FIELDNAME,
SOCKET_ID_FIELDNAME,
extra=get_log_record_extra(user_id=self.user_id),
)

return await self._registry.find_resources(
UserSession(user_id=self.user_id, client_session_id="*"),
_SOCKET_ID_FIELDNAME,
SOCKET_ID_FIELDNAME,
)

async def find_all_resources_of_user(self, key: str) -> list[str]:
Expand All @@ -148,8 +147,11 @@ async def find_all_resources_of_user(self, key: str) -> list[str]:
msg=f"{self.user_id=} finding all {key} from registry",
extra=get_log_record_extra(user_id=self.user_id),
):
return await get_registry(self.app).find_resources(
UserSession(user_id=self.user_id, client_session_id="*"), key
return await self._registry.find_resources(
UserSession(
user_id=self.user_id, client_session_id="*"
), # <-- this one checks for all user tabs
key,
)

async def find(self, resource_name: str) -> list[str]:
Expand All @@ -161,7 +163,10 @@ async def find(self, resource_name: str) -> list[str]:
extra=get_log_record_extra(user_id=self.user_id),
)

return await self._registry.find_resources(self.resource_key, resource_name)
return await self._registry.find_resources(
self.resource_key,
resource_name, # <-- when initialized with specific tab (client_session_id), checks only that tab otherwise all tabs
)

async def add(self, key: str, value: str) -> None:
_logger.debug(
Expand Down
Loading