diff --git a/packages/models-library/src/models_library/conversations.py b/packages/models-library/src/models_library/conversations.py index 5d33a0fcd45..e8e22ebd559 100644 --- a/packages/models-library/src/models_library/conversations.py +++ b/packages/models-library/src/models_library/conversations.py @@ -1,16 +1,20 @@ from datetime import datetime from enum import auto -from typing import TypeAlias +from typing import Annotated, TypeAlias from uuid import UUID from models_library.groups import GroupID from models_library.projects import ProjectID -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, StringConstraints from .products import ProductName from .utils.enums import StrAutoEnum ConversationID: TypeAlias = UUID +ConversationName: TypeAlias = Annotated[ + str, StringConstraints(strip_whitespace=True, min_length=1, max_length=255) +] + ConversationMessageID: TypeAlias = UUID @@ -36,7 +40,7 @@ class ConversationMessageType(StrAutoEnum): class ConversationGetDB(BaseModel): conversation_id: ConversationID product_name: ProductName - name: str + name: ConversationName project_uuid: ProjectID | None user_group_id: GroupID type: ConversationType @@ -63,7 +67,7 @@ class ConversationMessageGetDB(BaseModel): class ConversationPatchDB(BaseModel): - name: str | None = None + name: ConversationName | None = None class ConversationMessagePatchDB(BaseModel): diff --git a/services/web/server/src/simcore_service_webserver/conversations/_conversation_message_service.py b/services/web/server/src/simcore_service_webserver/conversations/_conversation_message_service.py index caeec8b030c..5b4f8397648 100644 --- a/services/web/server/src/simcore_service_webserver/conversations/_conversation_message_service.py +++ b/services/web/server/src/simcore_service_webserver/conversations/_conversation_message_service.py @@ -16,12 +16,10 @@ from models_library.rest_pagination import PageTotalCount from models_library.users import UserID -from ..projects._groups_repository import list_project_groups -from ..users._users_service import get_users_in_group - # Import or define SocketMessageDict from ..users.api import get_user_primary_group_id from . import _conversation_message_repository +from ._conversation_service import _get_recipients from ._socketio import ( notify_conversation_message_created, notify_conversation_message_deleted, @@ -31,16 +29,6 @@ _logger = logging.getLogger(__name__) -async def _get_recipients(app: web.Application, project_id: ProjectID) -> set[UserID]: - groups = await list_project_groups(app, project_id=project_id) - return { - user - for group in groups - if group.read - for user in await get_users_in_group(app, gid=group.gid) - } - - async def create_message( app: web.Application, *, diff --git a/services/web/server/src/simcore_service_webserver/conversations/_conversation_service.py b/services/web/server/src/simcore_service_webserver/conversations/_conversation_service.py index e4541f56c3f..fda9dde006a 100644 --- a/services/web/server/src/simcore_service_webserver/conversations/_conversation_service.py +++ b/services/web/server/src/simcore_service_webserver/conversations/_conversation_service.py @@ -16,12 +16,29 @@ from models_library.rest_pagination import PageTotalCount from models_library.users import UserID +from ..conversations._socketio import ( + notify_conversation_created, + notify_conversation_deleted, + notify_conversation_updated, +) +from ..projects._groups_repository import list_project_groups +from ..users._users_service import get_users_in_group from ..users.api import get_user_primary_group_id from . import _conversation_repository _logger = logging.getLogger(__name__) +async def _get_recipients(app: web.Application, project_id: ProjectID) -> set[UserID]: + groups = await list_project_groups(app, project_id=project_id) + return { + user + for group in groups + if group.read + for user in await get_users_in_group(app, gid=group.gid) + } + + async def create_conversation( app: web.Application, *, @@ -37,7 +54,7 @@ async def create_conversation( _user_group_id = await get_user_primary_group_id(app, user_id=user_id) - return await _conversation_repository.create( + created_conversation = await _conversation_repository.create( app, name=name, project_uuid=project_uuid, @@ -46,6 +63,15 @@ async def create_conversation( product_name=product_name, ) + await notify_conversation_created( + app, + recipients=await _get_recipients(app, project_uuid), + project_id=project_uuid, + conversation=created_conversation, + ) + + return created_conversation + async def get_conversation( app: web.Application, @@ -61,20 +87,33 @@ async def get_conversation( async def update_conversation( app: web.Application, *, + project_id: ProjectID, conversation_id: ConversationID, # Update attributes updates: ConversationPatchDB, ) -> ConversationGetDB: - return await _conversation_repository.update( + updated_conversation = await _conversation_repository.update( app, conversation_id=conversation_id, updates=updates, ) + await notify_conversation_updated( + app, + recipients=await _get_recipients(app, project_id), + project_id=project_id, + conversation=updated_conversation, + ) + + return updated_conversation + async def delete_conversation( app: web.Application, *, + product_name: ProductName, + user_id: UserID, + project_id: ProjectID, conversation_id: ConversationID, ) -> None: await _conversation_repository.delete( @@ -82,6 +121,17 @@ async def delete_conversation( conversation_id=conversation_id, ) + _user_group_id = await get_user_primary_group_id(app, user_id=user_id) + + await notify_conversation_deleted( + app, + recipients=await _get_recipients(app, project_id), + product_name=product_name, + user_group_id=_user_group_id, + project_id=project_id, + conversation_id=conversation_id, + ) + async def list_conversations_for_project( app: web.Application, diff --git a/services/web/server/src/simcore_service_webserver/conversations/_socketio.py b/services/web/server/src/simcore_service_webserver/conversations/_socketio.py index 03761ca4961..ad232f639e8 100644 --- a/services/web/server/src/simcore_service_webserver/conversations/_socketio.py +++ b/services/web/server/src/simcore_service_webserver/conversations/_socketio.py @@ -3,12 +3,16 @@ from aiohttp import web from models_library.conversations import ( + ConversationGetDB, ConversationID, ConversationMessageGetDB, ConversationMessageID, ConversationMessageType, + ConversationName, + ConversationType, ) from models_library.groups import GroupID +from models_library.products import ProductName from models_library.projects import ProjectID from models_library.socketio import SocketMessageDict from models_library.users import UserID @@ -20,6 +24,10 @@ _MAX_CONCURRENT_SENDS: Final[int] = 3 +SOCKET_IO_CONVERSATION_CREATED_EVENT: Final[str] = "conversation:created" +SOCKET_IO_CONVERSATION_DELETED_EVENT: Final[str] = "conversation:deleted" +SOCKET_IO_CONVERSATION_UPDATED_EVENT: Final[str] = "conversation:updated" + SOCKET_IO_CONVERSATION_MESSAGE_CREATED_EVENT: Final[str] = ( "conversation:message:created" ) @@ -31,7 +39,34 @@ ) -class BaseConversationMessage(BaseModel): +class BaseEvent(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + from_attributes=True, + alias_generator=AliasGenerator( + serialization_alias=to_camel, + ), + ) + + +class BaseConversationEvent(BaseEvent): + product_name: ProductName + project_id: ProjectID | None + user_group_id: GroupID + conversation_id: ConversationID + type: ConversationType + + +class ConversationCreatedOrUpdatedEvent(BaseConversationEvent): + name: ConversationName + created: datetime.datetime + modified: datetime.datetime + + +class ConversationDeletedEvent(BaseConversationEvent): ... + + +class BaseConversationMessageEvent(BaseEvent): conversation_id: ConversationID message_id: ConversationMessageID user_group_id: GroupID @@ -46,13 +81,13 @@ class BaseConversationMessage(BaseModel): ) -class ConversationMessageCreatedOrUpdated(BaseConversationMessage): +class ConversationMessageCreatedOrUpdatedEvent(BaseConversationMessageEvent): content: str created: datetime.datetime modified: datetime.datetime -class ConversationMessageDeleted(BaseConversationMessage): ... +class ConversationMessageDeletedEvent(BaseConversationMessageEvent): ... async def _send_message_to_recipients( @@ -62,9 +97,7 @@ async def _send_message_to_recipients( ): async for _ in limited_as_completed( ( - send_message_to_user( - app, recipient, notification_message, ignore_queue=True - ) + send_message_to_user(app, recipient, notification_message) for recipient in recipients ), limit=_MAX_CONCURRENT_SENDS, @@ -72,6 +105,71 @@ async def _send_message_to_recipients( ... +async def notify_conversation_created( + app: web.Application, + *, + recipients: set[UserID], + project_id: ProjectID, + conversation: ConversationGetDB, +) -> None: + notification_message = SocketMessageDict( + event_type=SOCKET_IO_CONVERSATION_CREATED_EVENT, + data={ + **ConversationCreatedOrUpdatedEvent( + project_id=project_id, + **conversation.model_dump(), + ).model_dump(mode="json", by_alias=True), + }, + ) + + await _send_message_to_recipients(app, recipients, notification_message) + + +async def notify_conversation_updated( + app: web.Application, + *, + recipients: set[UserID], + project_id: ProjectID, + conversation: ConversationGetDB, +) -> None: + notification_message = SocketMessageDict( + event_type=SOCKET_IO_CONVERSATION_UPDATED_EVENT, + data={ + **ConversationCreatedOrUpdatedEvent( + project_id=project_id, + **conversation.model_dump(), + ).model_dump(mode="json", by_alias=True), + }, + ) + + await _send_message_to_recipients(app, recipients, notification_message) + + +async def notify_conversation_deleted( + app: web.Application, + *, + recipients: set[UserID], + product_name: ProductName, + user_group_id: GroupID, + project_id: ProjectID, + conversation_id: ConversationID, +) -> None: + notification_message = SocketMessageDict( + event_type=SOCKET_IO_CONVERSATION_DELETED_EVENT, + data={ + **ConversationDeletedEvent( + product_name=product_name, + project_id=project_id, + conversation_id=conversation_id, + user_group_id=user_group_id, + type=ConversationType.PROJECT_STATIC, + ).model_dump(mode="json", by_alias=True), + }, + ) + + await _send_message_to_recipients(app, recipients, notification_message) + + async def notify_conversation_message_created( app: web.Application, *, @@ -83,7 +181,7 @@ async def notify_conversation_message_created( event_type=SOCKET_IO_CONVERSATION_MESSAGE_CREATED_EVENT, data={ "projectId": project_id, - **ConversationMessageCreatedOrUpdated( + **ConversationMessageCreatedOrUpdatedEvent( **conversation_message.model_dump() ).model_dump(mode="json", by_alias=True), }, @@ -104,7 +202,7 @@ async def notify_conversation_message_updated( event_type=SOCKET_IO_CONVERSATION_MESSAGE_UPDATED_EVENT, data={ "projectId": project_id, - **ConversationMessageCreatedOrUpdated( + **ConversationMessageCreatedOrUpdatedEvent( **conversation_message.model_dump() ).model_dump(mode="json", by_alias=True), }, @@ -127,7 +225,7 @@ async def notify_conversation_message_deleted( event_type=SOCKET_IO_CONVERSATION_MESSAGE_DELETED_EVENT, data={ "projectId": project_id, - **ConversationMessageDeleted( + **ConversationMessageDeletedEvent( conversation_id=conversation_id, message_id=message_id, user_group_id=user_group_id, diff --git a/services/web/server/src/simcore_service_webserver/projects/_conversations_service.py b/services/web/server/src/simcore_service_webserver/projects/_conversations_service.py index b415f694bcf..48aeae02cc6 100644 --- a/services/web/server/src/simcore_service_webserver/projects/_conversations_service.py +++ b/services/web/server/src/simcore_service_webserver/projects/_conversations_service.py @@ -8,6 +8,7 @@ ConversationMessageID, ConversationMessagePatchDB, ConversationMessageType, + ConversationName, ConversationPatchDB, ConversationType, ) @@ -87,7 +88,7 @@ async def update_project_conversation( project_uuid: ProjectID, conversation_id: ConversationID, # attributes - name: str, + name: ConversationName, ) -> ConversationGetDB: await check_user_project_permission( app, @@ -98,6 +99,7 @@ async def update_project_conversation( ) return await conversations_service.update_conversation( app, + project_id=project_uuid, conversation_id=conversation_id, updates=ConversationPatchDB(name=name), ) @@ -119,7 +121,11 @@ async def delete_project_conversation( permission="read", ) await conversations_service.delete_conversation( - app, conversation_id=conversation_id + app, + product_name=product_name, + project_id=project_uuid, + user_id=user_id, + conversation_id=conversation_id, ) diff --git a/services/web/server/tests/unit/with_dbs/02/test_projects_conversations_handlers.py b/services/web/server/tests/unit/with_dbs/02/test_projects_conversations_handlers.py index 42f1960a9f9..606f5368120 100644 --- a/services/web/server/tests/unit/with_dbs/02/test_projects_conversations_handlers.py +++ b/services/web/server/tests/unit/with_dbs/02/test_projects_conversations_handlers.py @@ -6,12 +6,13 @@ # pylint: disable=too-many-statements -from collections.abc import Callable +from collections.abc import Callable, Iterable from http import HTTPStatus -from unittest.mock import MagicMock +from types import SimpleNamespace import pytest -import simcore_service_webserver.conversations._conversation_message_service +import simcore_service_webserver.conversations._conversation_message_service as conversation_message_service +import simcore_service_webserver.conversations._conversation_service as conversation_service import sqlalchemy as sa from aiohttp.test_utils import TestClient from models_library.api_schemas_webserver.projects_conversations import ( @@ -33,14 +34,18 @@ @pytest.fixture -def mock_notify_function(mocker: MockerFixture) -> Callable[[str], MagicMock]: - def _mock(function_name: str) -> MagicMock: - return mocker.patch.object( - simcore_service_webserver.conversations._conversation_message_service, - function_name, +def mock_functions_factory( + mocker: MockerFixture, +) -> Callable[[Iterable[tuple[object, str]]], SimpleNamespace]: + def _patch(targets_and_names: Iterable[tuple[object, str]]) -> SimpleNamespace: + return SimpleNamespace( + **{ + name: mocker.patch.object(target, name) + for target, name in targets_and_names + } ) - return _mock + return _patch @pytest.mark.parametrize( @@ -81,7 +86,16 @@ async def test_project_conversations_full_workflow( logged_user: UserInfoDict, user_project: ProjectDict, expected: HTTPStatus, + mock_functions_factory: Callable[[Iterable[tuple[object, str]]], SimpleNamespace], ): + mocks = mock_functions_factory( + [ + (conversation_service, "notify_conversation_created"), + (conversation_service, "notify_conversation_updated"), + (conversation_service, "notify_conversation_deleted"), + ] + ) + base_url = client.app.router["list_project_conversations"].url_for( project_id=user_project["uuid"] ) @@ -106,6 +120,12 @@ async def test_project_conversations_full_workflow( assert ConversationRestGet.model_validate(data) _first_conversation_id = data["conversationId"] + assert mocks.notify_conversation_created.call_count == 1 + kwargs = mocks.notify_conversation_created.call_args.kwargs + + assert f"{kwargs['project_id']}" == user_project["uuid"] + assert kwargs["conversation"].name == "My conversation" + # Now we will create second conversation body = {"name": "My conversation", "type": "PROJECT_ANNOTATION"} resp = await client.post(f"{base_url}", json=body) @@ -115,6 +135,12 @@ async def test_project_conversations_full_workflow( ) assert ConversationRestGet.model_validate(data) + assert mocks.notify_conversation_created.call_count == 2 + kwargs = mocks.notify_conversation_created.call_args.kwargs + + assert f"{kwargs['project_id']}" == user_project["uuid"] + assert kwargs["conversation"].name == "My conversation" + # Now we will list all conversations for the project resp = await client.get(f"{base_url}") data, _, meta, links = await assert_status( @@ -145,6 +171,12 @@ async def test_project_conversations_full_workflow( ) assert data["name"] == updated_name + assert mocks.notify_conversation_updated.call_count == 1 + kwargs = mocks.notify_conversation_updated.call_args.kwargs + + assert f"{kwargs['project_id']}" == user_project["uuid"] + assert kwargs["conversation"].name == updated_name + # Now we will delete the first conversation resp = await client.delete(f"{base_url}/{_first_conversation_id}") data, _ = await assert_status( @@ -152,6 +184,11 @@ async def test_project_conversations_full_workflow( status.HTTP_204_NO_CONTENT, ) + assert mocks.notify_conversation_deleted.call_count == 1 + kwargs = mocks.notify_conversation_deleted.call_args.kwargs + + assert f"{kwargs['conversation_id']}" == _first_conversation_id + # Now we will list all conversations for the project resp = await client.get(f"{base_url}") data, _, meta = await assert_status( @@ -178,16 +215,14 @@ async def test_project_conversation_messages_full_workflow( user_project: ProjectDict, expected: HTTPStatus, postgres_db: sa.engine.Engine, - mock_notify_function: Callable[[str], MagicMock], + mock_functions_factory: Callable[[Iterable[tuple[object, str]]], SimpleNamespace], ): - mocked_notify_conversation_message_created = mock_notify_function( - "notify_conversation_message_created" - ) - mocked_notify_conversation_message_updated = mock_notify_function( - "notify_conversation_message_updated" - ) - mocked_notify_conversation_message_deleted = mock_notify_function( - "notify_conversation_message_deleted" + mocks = mock_functions_factory( + [ + (conversation_message_service, "notify_conversation_message_created"), + (conversation_message_service, "notify_conversation_message_updated"), + (conversation_message_service, "notify_conversation_message_deleted"), + ] ) base_project_url = client.app.router["list_project_conversations"].url_for( @@ -217,8 +252,8 @@ async def test_project_conversation_messages_full_workflow( assert ConversationMessageRestGet.model_validate(data) _first_message_id = data["messageId"] - assert mocked_notify_conversation_message_created.call_count == 1 - kwargs = mocked_notify_conversation_message_created.call_args.kwargs + assert mocks.notify_conversation_message_created.call_count == 1 + kwargs = mocks.notify_conversation_message_created.call_args.kwargs assert f"{kwargs['project_id']}" == user_project["uuid"] assert kwargs["conversation_message"].content == "My first message" @@ -233,8 +268,8 @@ async def test_project_conversation_messages_full_workflow( assert ConversationMessageRestGet.model_validate(data) _second_message_id = data["messageId"] - assert mocked_notify_conversation_message_created.call_count == 2 - kwargs = mocked_notify_conversation_message_created.call_args.kwargs + assert mocks.notify_conversation_message_created.call_count == 2 + kwargs = mocks.notify_conversation_message_created.call_args.kwargs assert user_project["uuid"] == f"{kwargs['project_id']}" assert kwargs["conversation_message"].content == "My second message" @@ -265,8 +300,8 @@ async def test_project_conversation_messages_full_workflow( expected, ) - assert mocked_notify_conversation_message_updated.call_count == 1 - kwargs = mocked_notify_conversation_message_updated.call_args.kwargs + assert mocks.notify_conversation_message_updated.call_count == 1 + kwargs = mocks.notify_conversation_message_updated.call_args.kwargs assert user_project["uuid"] == f"{kwargs['project_id']}" assert kwargs["conversation_message"].content == updated_content @@ -301,8 +336,8 @@ async def test_project_conversation_messages_full_workflow( status.HTTP_204_NO_CONTENT, ) - assert mocked_notify_conversation_message_deleted.call_count == 1 - kwargs = mocked_notify_conversation_message_deleted.call_args.kwargs + assert mocks.notify_conversation_message_deleted.call_count == 1 + kwargs = mocks.notify_conversation_message_deleted.call_args.kwargs assert f"{kwargs['project_id']}" == user_project["uuid"] assert f"{kwargs['conversation_id']}" == _conversation_id @@ -399,8 +434,8 @@ async def test_project_conversation_messages_full_workflow( status.HTTP_204_NO_CONTENT, ) - assert mocked_notify_conversation_message_deleted.call_count == 2 - kwargs = mocked_notify_conversation_message_deleted.call_args.kwargs + assert mocks.notify_conversation_message_deleted.call_count == 2 + kwargs = mocks.notify_conversation_message_deleted.call_args.kwargs assert f"{kwargs['project_id']}" == user_project["uuid"] assert f"{kwargs['conversation_id']}" == _conversation_id