Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions packages/models-library/src/models_library/conversations.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
from datetime import datetime
from enum import auto
from typing import TypeAlias
from typing import Annotated, TypeAlias
from uuid import UUID

from models_library.groups import GroupID
from models_library.projects import ProjectID
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, StringConstraints

from .products import ProductName
from .utils.enums import StrAutoEnum

ConversationID: TypeAlias = UUID
ConversationName: TypeAlias = Annotated[
str, StringConstraints(strip_whitespace=True, min_length=1, max_length=255)
]

ConversationMessageID: TypeAlias = UUID


Expand All @@ -36,7 +40,7 @@ class ConversationMessageType(StrAutoEnum):
class ConversationGetDB(BaseModel):
conversation_id: ConversationID
product_name: ProductName
name: str
name: ConversationName
project_uuid: ProjectID | None
user_group_id: GroupID
type: ConversationType
Expand All @@ -63,7 +67,7 @@ class ConversationMessageGetDB(BaseModel):


class ConversationPatchDB(BaseModel):
name: str | None = None
name: ConversationName | None = None


class ConversationMessagePatchDB(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,10 @@
from models_library.rest_pagination import PageTotalCount
from models_library.users import UserID

from ..projects._groups_repository import list_project_groups
from ..users._users_service import get_users_in_group

# Import or define SocketMessageDict
from ..users.api import get_user_primary_group_id
from . import _conversation_message_repository
from ._conversation_service import _get_recipients
from ._socketio import (
notify_conversation_message_created,
notify_conversation_message_deleted,
Expand All @@ -31,16 +29,6 @@
_logger = logging.getLogger(__name__)


async def _get_recipients(app: web.Application, project_id: ProjectID) -> set[UserID]:
groups = await list_project_groups(app, project_id=project_id)
return {
user
for group in groups
if group.read
for user in await get_users_in_group(app, gid=group.gid)
}


async def create_message(
app: web.Application,
*,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,29 @@
from models_library.rest_pagination import PageTotalCount
from models_library.users import UserID

from ..conversations._socketio import (
notify_conversation_created,
notify_conversation_deleted,
notify_conversation_updated,
)
from ..projects._groups_repository import list_project_groups
from ..users._users_service import get_users_in_group
from ..users.api import get_user_primary_group_id
from . import _conversation_repository

_logger = logging.getLogger(__name__)


async def _get_recipients(app: web.Application, project_id: ProjectID) -> set[UserID]:
groups = await list_project_groups(app, project_id=project_id)
return {
user
for group in groups
if group.read
for user in await get_users_in_group(app, gid=group.gid)
}


async def create_conversation(
app: web.Application,
*,
Expand All @@ -37,7 +54,7 @@ async def create_conversation(

_user_group_id = await get_user_primary_group_id(app, user_id=user_id)

return await _conversation_repository.create(
created_conversation = await _conversation_repository.create(
app,
name=name,
project_uuid=project_uuid,
Expand All @@ -46,6 +63,15 @@ async def create_conversation(
product_name=product_name,
)

await notify_conversation_created(
app,
recipients=await _get_recipients(app, project_uuid),
project_id=project_uuid,
conversation=created_conversation,
)

return created_conversation


async def get_conversation(
app: web.Application,
Expand All @@ -61,27 +87,51 @@ async def get_conversation(
async def update_conversation(
app: web.Application,
*,
project_id: ProjectID,
conversation_id: ConversationID,
# Update attributes
updates: ConversationPatchDB,
) -> ConversationGetDB:
return await _conversation_repository.update(
updated_conversation = await _conversation_repository.update(
app,
conversation_id=conversation_id,
updates=updates,
)

await notify_conversation_updated(
app,
recipients=await _get_recipients(app, project_id),
project_id=project_id,
conversation=updated_conversation,
)

return updated_conversation


async def delete_conversation(
app: web.Application,
*,
product_name: ProductName,
project_id: ProjectID,
user_id: UserID,
conversation_id: ConversationID,
) -> None:
await _conversation_repository.delete(
app,
conversation_id=conversation_id,
)

_user_group_id = await get_user_primary_group_id(app, user_id=user_id)

await notify_conversation_deleted(
app,
recipients=await _get_recipients(app, project_id),
product_name=product_name,
user_group_id=_user_group_id,
project_id=project_id,
conversation_id=conversation_id,
)


async def list_conversations_for_project(
app: web.Application,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@

from aiohttp import web
from models_library.conversations import (
ConversationGetDB,
ConversationID,
ConversationMessageGetDB,
ConversationMessageID,
ConversationMessageType,
ConversationName,
ConversationType,
)
from models_library.groups import GroupID
from models_library.products import ProductName
from models_library.projects import ProjectID
from models_library.socketio import SocketMessageDict
from models_library.users import UserID
Expand All @@ -20,6 +24,10 @@

_MAX_CONCURRENT_SENDS: Final[int] = 3

SOCKET_IO_CONVERSATION_CREATED_EVENT: Final[str] = "conversation:created"
SOCKET_IO_CONVERSATION_DELETED_EVENT: Final[str] = "conversation:deleted"
SOCKET_IO_CONVERSATION_UPDATED_EVENT: Final[str] = "conversation:updated"

SOCKET_IO_CONVERSATION_MESSAGE_CREATED_EVENT: Final[str] = (
"conversation:message:created"
)
Expand All @@ -31,7 +39,34 @@
)


class BaseConversationMessage(BaseModel):
class BaseEvent(BaseModel):
model_config = ConfigDict(
populate_by_name=True,
from_attributes=True,
alias_generator=AliasGenerator(
serialization_alias=to_camel,
),
)


class BaseConversationEvent(BaseEvent):
product_name: ProductName
project_id: ProjectID | None
user_group_id: GroupID
conversation_id: ConversationID
type: ConversationType


class ConversationCreatedOrUpdatedEvent(BaseConversationEvent):
name: ConversationName
created: datetime.datetime
modified: datetime.datetime


class ConversationDeletedEvent(BaseConversationEvent): ...


class BaseConversationMessageEvent(BaseEvent):
conversation_id: ConversationID
message_id: ConversationMessageID
user_group_id: GroupID
Expand All @@ -46,13 +81,13 @@ class BaseConversationMessage(BaseModel):
)


class ConversationMessageCreatedOrUpdated(BaseConversationMessage):
class ConversationMessageCreatedOrUpdatedEvent(BaseConversationMessageEvent):
content: str
created: datetime.datetime
modified: datetime.datetime


class ConversationMessageDeleted(BaseConversationMessage): ...
class ConversationMessageDeletedEvent(BaseConversationMessageEvent): ...


async def _send_message_to_recipients(
Expand All @@ -62,16 +97,79 @@ async def _send_message_to_recipients(
):
async for _ in limited_as_completed(
(
send_message_to_user(
app, recipient, notification_message, ignore_queue=True
)
send_message_to_user(app, recipient, notification_message)
for recipient in recipients
),
limit=_MAX_CONCURRENT_SENDS,
):
...


async def notify_conversation_created(
app: web.Application,
*,
recipients: set[UserID],
project_id: ProjectID,
conversation: ConversationGetDB,
) -> None:
notification_message = SocketMessageDict(
event_type=SOCKET_IO_CONVERSATION_CREATED_EVENT,
data={
**ConversationCreatedOrUpdatedEvent(
project_id=project_id,
**conversation.model_dump(),
).model_dump(mode="json", by_alias=True),
},
)

await _send_message_to_recipients(app, recipients, notification_message)


async def notify_conversation_updated(
app: web.Application,
*,
recipients: set[UserID],
project_id: ProjectID,
conversation: ConversationGetDB,
) -> None:
notification_message = SocketMessageDict(
event_type=SOCKET_IO_CONVERSATION_UPDATED_EVENT,
data={
**ConversationCreatedOrUpdatedEvent(
project_id=project_id,
**conversation.model_dump(),
).model_dump(mode="json", by_alias=True),
},
)

await _send_message_to_recipients(app, recipients, notification_message)


async def notify_conversation_deleted(
app: web.Application,
*,
recipients: set[UserID],
product_name: ProductName,
project_id: ProjectID,
user_group_id: GroupID,
conversation_id: ConversationID,
) -> None:
notification_message = SocketMessageDict(
event_type=SOCKET_IO_CONVERSATION_DELETED_EVENT,
data={
**ConversationDeletedEvent(
product_name=product_name,
project_id=project_id,
conversation_id=conversation_id,
user_group_id=user_group_id,
type=ConversationType.PROJECT_STATIC,
).model_dump(mode="json", by_alias=True),
},
)

await _send_message_to_recipients(app, recipients, notification_message)


async def notify_conversation_message_created(
app: web.Application,
*,
Expand All @@ -83,7 +181,7 @@ async def notify_conversation_message_created(
event_type=SOCKET_IO_CONVERSATION_MESSAGE_CREATED_EVENT,
data={
"projectId": project_id,
**ConversationMessageCreatedOrUpdated(
**ConversationMessageCreatedOrUpdatedEvent(
**conversation_message.model_dump()
).model_dump(mode="json", by_alias=True),
},
Expand All @@ -104,7 +202,7 @@ async def notify_conversation_message_updated(
event_type=SOCKET_IO_CONVERSATION_MESSAGE_UPDATED_EVENT,
data={
"projectId": project_id,
**ConversationMessageCreatedOrUpdated(
**ConversationMessageCreatedOrUpdatedEvent(
**conversation_message.model_dump()
).model_dump(mode="json", by_alias=True),
},
Expand All @@ -127,7 +225,7 @@ async def notify_conversation_message_deleted(
event_type=SOCKET_IO_CONVERSATION_MESSAGE_DELETED_EVENT,
data={
"projectId": project_id,
**ConversationMessageDeleted(
**ConversationMessageDeletedEvent(
conversation_id=conversation_id,
message_id=message_id,
user_group_id=user_group_id,
Expand Down
Loading
Loading