Skip to content
16 changes: 15 additions & 1 deletion business_objects/user.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from datetime import datetime
from . import general, organization, team_member
from .. import User, enums
from .. import User, enums, Team, TeamMember, TeamResource
from ..session import session
from typing import List, Optional
from sqlalchemy import sql
Expand Down Expand Up @@ -52,6 +52,20 @@ def get_all(
return query.all()


def get_all_team_members_by_project(project_id: str) -> List[User]:
query = (
session.query(User)
.join(TeamMember, TeamMember.user_id == User.id)
.join(Team, Team.id == TeamMember.team_id)
.join(TeamResource, TeamResource.team_id == Team.id)
.filter(TeamResource.resource_id == project_id)
.filter(
TeamResource.resource_type == enums.TeamResourceType.COGNITION_PROJECT.value
)
)
return query.all()


def get_count_assigned() -> int:
return session.query(User.id).filter(User.organization_id != None).count()

Expand Down
8 changes: 8 additions & 0 deletions cognition_objects/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ def get(project_id: str, conversation_id: str) -> CognitionConversation:
)


def get_by_id(conversation_id: str) -> CognitionConversation:
return (
session.query(CognitionConversation)
.filter(CognitionConversation.id == conversation_id)
.first()
)


def exists(project_id: str, conversation_id: str) -> bool:
return (
session.query(CognitionConversation)
Expand Down
68 changes: 68 additions & 0 deletions cognition_objects/conversation_global_shares.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from operator import or_
from typing import List, Optional
from ..business_objects import general
from ..session import session
from ..models import CognitionConversation, ConversationGlobalShare
from submodules.model.util import sql_alchemy_to_dict


def get(conversation_global_share_id: str) -> Optional[ConversationGlobalShare]:
return (
session.query(ConversationGlobalShare)
.filter(ConversationGlobalShare.id == conversation_global_share_id)
.first()
)


def get_by_conversation(conversation_id: str) -> List[ConversationGlobalShare]:
return (
session.query(ConversationGlobalShare)
.filter(ConversationGlobalShare.conversation_id == conversation_id)
.first()
)


def create(
conversation_id: str,
shared_by: str,
with_commit: bool = True,
) -> ConversationGlobalShare:
global_share = ConversationGlobalShare(
conversation_id=conversation_id, shared_by=shared_by
)
general.add(global_share, with_commit)
return global_share


def delete_by_conversation(
conversation_id: str, user_id: str, with_commit: bool = True
):
(
session.query(ConversationGlobalShare)
.filter(
ConversationGlobalShare.conversation_id == conversation_id,
ConversationGlobalShare.shared_by == user_id,
)
.delete()
)
if with_commit:
session.commit()


def get_by_user(project_id: str, user_id: str) -> List[ConversationGlobalShare]:
conversation_global_shares = (
session.query(ConversationGlobalShare, CognitionConversation.header)
.join(
CognitionConversation,
ConversationGlobalShare.conversation_id == CognitionConversation.id,
)
.filter(ConversationGlobalShare.shared_by == user_id)
.filter(CognitionConversation.project_id == project_id)
.all()
)
conversation_global_shares_dict = []
for share_obj, header in conversation_global_shares:
share_dict = sql_alchemy_to_dict(share_obj)
share_dict["conversation_header"] = header
conversation_global_shares_dict.append(share_dict)
return conversation_global_shares_dict
187 changes: 187 additions & 0 deletions cognition_objects/conversation_shares.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
from operator import or_
from typing import List, Optional
from ..business_objects import general
from ..session import session
from ..models import ConversationShare, CognitionConversation
from submodules.model.util import sql_alchemy_to_dict


def get(share_id: str, user_id: str) -> Optional[ConversationShare]:
return (
session.query(ConversationShare)
.filter(ConversationShare.id == share_id)
.filter(
or_(
ConversationShare.shared_by == user_id,
ConversationShare.shared_with == user_id,
)
)
.first()
)


def get_all_by_conversation(conversation_id: str) -> List[ConversationShare]:
return (
session.query(ConversationShare)
.filter(ConversationShare.conversation_id == conversation_id)
.all()
)


def update_by_conversation(
conversation_id: str,
user_id: str,
shared_with: List[str],
can_copy: Optional[bool] = None,
with_commit: bool = True,
) -> List[ConversationShare]:
existing_shares = (
session.query(ConversationShare)
.filter(ConversationShare.conversation_id == conversation_id)
.filter(ConversationShare.shared_by == user_id)
.all()
)

existing_shared_with = {share.shared_with: share for share in existing_shares}
shared_with_set = set(shared_with)

for share in existing_shares:
if share.shared_with not in shared_with_set:
session.delete(share)

for sharing_user_id in shared_with:
if sharing_user_id not in existing_shared_with:
share = ConversationShare(
conversation_id=conversation_id,
shared_with=sharing_user_id,
shared_by=user_id,
can_copy=can_copy if can_copy is not None else False,
)
general.add(share, with_commit=False)
else:
if can_copy is not None:
existing_shared_with[sharing_user_id].can_copy = can_copy

general.flush_or_commit(with_commit)

updated_shares = (
session.query(ConversationShare)
.filter(ConversationShare.conversation_id == conversation_id)
.filter(ConversationShare.shared_by == user_id)
.all()
)
return updated_shares


def get_all_shared_by_or_for_user(
project_id: str, user_id: str, with_header: bool = True
) -> List[ConversationShare]:
conversation_shares = (
session.query(ConversationShare, CognitionConversation.header)
.join(
CognitionConversation,
ConversationShare.conversation_id == CognitionConversation.id,
)
.filter(CognitionConversation.project_id == project_id)
.filter(
or_(
ConversationShare.shared_by == user_id,
ConversationShare.shared_with == user_id,
)
)
.all()
)
conversation_shares_dict = []
for share_obj, header in conversation_shares:
share_dict = sql_alchemy_to_dict(share_obj)
share_dict["conversation_header"] = header
conversation_shares_dict.append(share_dict)

return conversation_shares_dict


def get_all_shared_by_user(user_id: str) -> List[ConversationShare]:
return (
session.query(ConversationShare)
.filter(ConversationShare.shared_by == user_id)
.all()
)


def create(
conversation_id: str,
shared_with: str,
shared_by: str,
can_copy: bool = False,
with_commit: bool = True,
) -> ConversationShare:
share = ConversationShare(
conversation_id=conversation_id,
shared_with=shared_with,
shared_by=shared_by,
can_copy=can_copy,
)
general.add(share, with_commit)
return share


def create_many(
conversation_id: str,
shared_with_user_ids: List[str],
shared_by: str,
can_copy: bool = False,
with_commit: bool = True,
) -> List[ConversationShare]:

shares = []

for user_id in shared_with_user_ids:
share = ConversationShare(
conversation_id=conversation_id,
shared_with=user_id,
shared_by=shared_by,
can_copy=can_copy,
)
general.add(share, with_commit=False)
shares.append(share)

general.flush_or_commit(with_commit)
return shares


def update(
share_id: str,
user_id: str,
can_copy: Optional[bool] = None,
with_commit: bool = True,
) -> Optional[ConversationShare]:
share_entity = get(share_id)
if share_entity is None:
return None

if str(share_entity.shared_by) != user_id:
raise ValueError("You are not allowed to update this sharing context.")

if can_copy is not None:
share_entity.can_copy = can_copy

general.flush_or_commit(with_commit)
return share_entity


def delete_shared_with(
conversation_share_id: str, user_id: str, with_commit: bool = True
) -> None:
session.query(ConversationShare).filter(
ConversationShare.id == conversation_share_id
).filter(ConversationShare.shared_with == user_id).delete()
general.flush_or_commit(with_commit)


def delete_shared_by_by_conversation_id(
conversation_id: str, user_id: str, with_commit: bool = True
) -> None:
session.query(ConversationShare).filter(
ConversationShare.conversation_id == conversation_id
).filter(ConversationShare.shared_by == user_id).delete()
general.flush_or_commit(with_commit)
12 changes: 12 additions & 0 deletions cognition_objects/conversation_tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,18 @@ def delete_association(
general.flush_or_commit(with_commit)


def delete_associations_by_conversation(
conversation_id: str,
with_commit: bool = True,
) -> None:

session.query(CognitionConversationTagAssociation).filter(
CognitionConversationTagAssociation.conversation_id == conversation_id,
).delete(synchronize_session=False)

general.flush_or_commit(with_commit)


def get_lookup_by_conversation_ids(
conversation_ids: List[str],
) -> Dict[str, List[Dict[str, Any]]]:
Expand Down
8 changes: 8 additions & 0 deletions cognition_objects/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,8 @@ def update(
llm_config: Optional[Dict[str, Any]] = None,
tokenizer: Optional[str] = None,
icon: Optional[str] = None,
allow_conversation_sharing_organization: Optional[bool] = None,
allow_conversation_sharing_global: Optional[bool] = None,
with_commit: bool = True,
) -> CognitionProject:
project: CognitionProject = get(project_id)
Expand Down Expand Up @@ -288,6 +290,12 @@ def update(
project.tokenizer = tokenizer
if icon is not None:
project.icon = icon
if allow_conversation_sharing_organization is not None:
project.allow_conversation_sharing_organization = (
allow_conversation_sharing_organization
)
if allow_conversation_sharing_global is not None:
project.allow_conversation_sharing_global = allow_conversation_sharing_global
general.flush_or_commit(with_commit)
return project

Expand Down
2 changes: 2 additions & 0 deletions enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ class Tablenames(Enum):
ADMIN_QUERY_MESSAGE_SUMMARY = "admin_query_message_summary"
RELEASE_NOTIFICATION = "release_notification"
TIMED_EXECUTIONS = "timed_executions"
CONVERSATION_SHARE = "conversation_share"
CONVERSATION_GLOBAL_SHARE = "conversation_global_share"

def snake_case_to_pascal_case(self):
# the type name (written in PascalCase) of a table is needed to create backrefs
Expand Down
Loading