Skip to content

Commit f9ef28a

Browse files
Conversation tagging (#181)
* Model & first methods * Adds more functions * Dummy for merge fix * Lookup function * Missing tagged chats * Admin queries extension * Cached user fix * check state fix --------- Co-authored-by: Lina <[email protected]>
1 parent a169a62 commit f9ef28a

File tree

7 files changed

+378
-2
lines changed

7 files changed

+378
-2
lines changed

business_objects/user.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ def get(user_id: str) -> User:
1616

1717
def get_user_cached_if_not_admin(user_id: str) -> Optional[User]:
1818
user = get_user_cached(user_id)
19+
if not user:
20+
# cache is None, but user is automatically created so we recollect to be sure
21+
return get(user_id)
1922
if (user.email or "").endswith("@kern.ai") and user.verified:
2023
# for admins this could result in two db requests shortly after each other
2124
# but it's better than having the jumping users without the correct org id

cognition_objects/conversation.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55
from ..cognition_objects import message
66
from ..business_objects import general
77
from ..session import session
8-
from ..models import CognitionConversation, CognitionMessage
8+
from ..models import (
9+
CognitionConversation,
10+
CognitionMessage,
11+
CognitionConversationTagAssociation,
12+
)
913
from ..util import prevent_sql_injection
1014
from sqlalchemy.sql.expression import Subquery
1115
from sqlalchemy import or_
@@ -200,6 +204,28 @@ def get_all_paginated_by_project_id(
200204
return total_count, num_pages, paginated_result
201205

202206

207+
def get_missing_tagged_conversations(
    project_id: str, user_id: str, tag_id: str, not_needed_conversations: List[str]
) -> List[CognitionConversation]:
    """Return the user's conversations in a project that carry the given tag.

    Conversations whose id is listed in `not_needed_conversations` are left
    out of the result.
    """
    # a conversation is linked to its tags through the association table
    join_condition = (
        CognitionConversationTagAssociation.conversation_id
        == CognitionConversation.id
    )
    tagged_query = session.query(CognitionConversation).join(
        CognitionConversationTagAssociation, join_condition
    )
    tagged_query = tagged_query.filter(
        CognitionConversationTagAssociation.tag_id == tag_id,
        CognitionConversation.id.notin_(not_needed_conversations),
        CognitionConversation.project_id == project_id,
        CognitionConversation.created_by == user_id,
    )
    return tagged_query.all()
227+
228+
203229
def __get_conversation_ids_by_filter(
204230
project_id: str,
205231
user_id: Optional[str] = None,
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
from typing import Dict, List, Optional, Any
2+
3+
from ..business_objects import general
4+
from ..session import session
5+
from ..models import CognitionConversationTag, CognitionConversationTagAssociation
6+
from ..util import sql_alchemy_to_dict, prevent_sql_injection
7+
from sqlalchemy.orm.attributes import flag_modified
8+
from sqlalchemy.types import Boolean
9+
from sqlalchemy import or_
10+
11+
12+
# columns passed as column_blacklist when an association row is converted to a dict
BLACKLIST_CONVERSATION_TAG_ASSOCIATION = {"id", "conversation_id"}
13+
14+
15+
def get(tag_id: str) -> CognitionConversationTag:
    """Return the first conversation tag with the given id (None if no row matches)."""
    tag_query = session.query(CognitionConversationTag).filter(
        CognitionConversationTag.id == tag_id,
    )
    return tag_query.first()
23+
24+
25+
def get_all_by_user(user_id: str) -> List[CognitionConversationTag]:
    """Return every conversation tag created by the given user."""
    owned_tags = session.query(CognitionConversationTag).filter(
        CognitionConversationTag.created_by == user_id,
    )
    return owned_tags.all()
33+
34+
35+
def get_all_relevant(user_id: str, project_id: str):
    """Return the user's tags that apply to the given project.

    A tag applies when its config marks it as global or when its
    `use_for_projects` config entry contains `project_id`.
    """
    tag_config = CognitionConversationTag.config
    # global_tag is boolean true
    is_global = tag_config["global_tag"].astext.cast(Boolean)
    # use_for_projects contains project_id
    enabled_for_project = tag_config["use_for_projects"].contains([project_id])
    owned_by_user = CognitionConversationTag.created_by == user_id

    relevant = session.query(CognitionConversationTag).filter(
        owned_by_user,
        or_(is_global, enabled_for_project),
    )
    return relevant.all()
51+
52+
53+
def create(
    user_id: str,
    name: str,
    config: Dict[str, Any],
    with_commit: bool = True,
) -> CognitionConversationTag:
    """Build a new conversation tag for `user_id` and hand it to `general.add`.

    `with_commit` is forwarded unchanged to `general.add`.
    """
    new_tag: CognitionConversationTag = CognitionConversationTag(
        created_by=user_id, name=name, config=config
    )
    general.add(new_tag, with_commit)
    return new_tag
66+
67+
68+
def update(
    user_id: str,
    tag_id: str,
    name: Optional[str] = None,
    config: Optional[Dict[str, Any]] = None,
    with_commit: bool = True,
) -> Optional[CognitionConversationTag]:
    """Update the name and/or config of a tag owned by `user_id`.

    `config` is merged key by key into the stored config: a value of None
    removes the key, any other value overwrites it. An empty or missing
    `config` leaves the stored config untouched.

    Returns:
        The updated tag, or None when no tag with `tag_id` exists.

    Raises:
        ValueError: if the tag was created by a different user.
    """
    tag_entity = get(tag_id)
    if tag_entity is None:
        return None
    # compare both sides as strings so a UUID object and its str form match
    if str(tag_entity.created_by) != str(user_id):
        raise ValueError("You are not allowed to update this tag.")
    if name is not None:
        tag_entity.name = name
    if config:
        # tolerate a tag whose config was never populated
        if tag_entity.config is None:
            tag_entity.config = {}
        for key, value in config.items():
            if value is None:
                tag_entity.config.pop(key, None)
            else:
                tag_entity.config[key] = value
        # mark the column dirty once, after all in-place changes (removals included)
        flag_modified(tag_entity, "config")
    general.flush_or_commit(with_commit)
    return tag_entity
91+
92+
93+
def delete(tag_id: str, with_commit: bool = True) -> None:
    """Delete the conversation tag with the given id, then flush or commit."""
    matching_tag = session.query(CognitionConversationTag).filter(
        CognitionConversationTag.id == tag_id,
    )
    matching_tag.delete()
    general.flush_or_commit(with_commit)
98+
99+
100+
def delete_many(tag_ids: List[str], with_commit: bool = True) -> None:
    """Bulk-delete all conversation tags whose id is in `tag_ids`, then flush or commit."""
    matching_tags = session.query(CognitionConversationTag).filter(
        CognitionConversationTag.id.in_(tag_ids),
    )
    # bulk delete without syncing objects already loaded in the session
    matching_tags.delete(synchronize_session=False)
    general.flush_or_commit(with_commit)
105+
106+
107+
def create_association(
    conversation_id: str,
    tag_id: str,
    with_commit: bool = True,
) -> None:
    """Link a tag to a conversation by adding a new association row."""
    link = CognitionConversationTagAssociation(
        conversation_id=conversation_id, tag_id=tag_id
    )
    general.add(link, with_commit)
117+
118+
119+
def delete_association(
    conversation_id: str,
    tag_id: str,
    with_commit: bool = True,
) -> None:
    """Remove the association rows linking `tag_id` to `conversation_id`, then flush or commit."""
    links = session.query(CognitionConversationTagAssociation).filter(
        CognitionConversationTagAssociation.conversation_id == conversation_id,
        CognitionConversationTagAssociation.tag_id == tag_id,
    )
    # bulk delete without syncing objects already loaded in the session
    links.delete(synchronize_session=False)
    general.flush_or_commit(with_commit)
129+
130+
131+
def get_lookup_by_conversation_ids(
    conversation_ids: List[str],
) -> Dict[str, List[Dict[str, Any]]]:
    """Group tag associations by conversation id.

    Returns a dict keyed by the stringified conversation id; each value is the
    list of that conversation's associations converted via `sql_alchemy_to_dict`
    with the blacklisted columns removed. Conversations without any association
    do not appear in the result.
    """
    rows = (
        session.query(CognitionConversationTagAssociation)
        .filter(
            CognitionConversationTagAssociation.conversation_id.in_(conversation_ids)
        )
        .all()
    )

    lookup: Dict[str, List[Dict[str, Any]]] = {}
    for row in rows:
        conversation_key = str(row.conversation_id)
        as_dict = sql_alchemy_to_dict(
            row, column_blacklist=BLACKLIST_CONVERSATION_TAG_ASSOCIATION
        )
        lookup.setdefault(conversation_key, []).append(as_dict)
    return lookup
152+
153+
154+
def get_tag_counts(project_id: str, user_id: str) -> Dict[str, int]:
    """Count the user's conversations in a project per tag id.

    Conversations without any tag association are counted under the key
    '<untagged>'. Returns an empty dict when the query yields nothing.
    """
    safe_project_id = prevent_sql_injection(project_id, isinstance(project_id, str))
    safe_user_id = prevent_sql_injection(user_id, isinstance(user_id, str))

    # LEFT JOIN keeps untagged conversations; their NULL tag id is mapped by COALESCE
    sql = f"""
    SELECT json_object_agg(tid,t_count)
    FROM (
        SELECT COALESCE(cta.tag_id::TEXT,'<untagged>') tid, COUNT(*) t_count
        FROM cognition.conversation C
        LEFT JOIN cognition.conversation_tag_association cta
            ON c.id = cta.conversation_id
        WHERE c.project_id = '{safe_project_id}' AND c.created_by = '{safe_user_id}'
        group BY cta.tag_id
    ) x """

    row = general.execute_first(sql)
    if not row or not row[0]:
        return {}
    return row[0]

cognition_objects/file_extraction.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,10 @@ def update(
9191
with_commit: bool = True,
9292
) -> FileExtraction:
9393
file_extraction = get_by_id(org_id, file_extraction_id)
94-
if file_extraction.state == enums.FileCachingState.CANCELED.value:
94+
if (
95+
not file_extraction
96+
or file_extraction.state == enums.FileCachingState.CANCELED.value
97+
):
9598
return
9699
if minio_path is not None:
97100
file_extraction.minio_path = minio_path

enums.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,10 @@ class Tablenames(Enum):
168168
INTEGRATION_PDF = "pdf"
169169
INTEGRATION_SHAREPOINT = "sharepoint"
170170
STEP_TEMPLATES = "step_templates" # templates for strategy steps
171+
CONVERSATION_TAG = "conversation_tag" # config of tags used in conversations
172+
CONVERSATION_TAG_ASSOCIATION = (
173+
"conversation_tag_association" # association between conversation and tags
174+
)
171175

172176
def snake_case_to_pascal_case(self):
173177
# the type name (written in PascalCase) of a table is needed to create backrefs
@@ -917,6 +921,11 @@ class AdminQueries(Enum):
917921
FOLDER_MACRO_EXECUTION_SUMMARY = (
918922
"FOLDER_MACRO_EXECUTION_SUMMARY" # parameter options: organization_id
919923
)
924+
CREATED_TAGS_PER_ORG = (
925+
"CREATED_TAGS_PER_ORG" # parameter options: organization_id, without_kern_email
926+
)
927+
CONVERSATIONS_PER_TAG = "CONVERSATIONS_PER_TAG" # parameter options: organization_id, without_kern_email, distinct_conversations
928+
MULTITAGGED_CONVERSATIONS = "MULTITAGGED_CONVERSATIONS" # parameter options: organization_id, without_kern_email
920929

921930

922931
class CognitionIntegrationType(Enum):

global_objects/admin_queries.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,137 @@ def get_result_admin_query(
4141
return __get_macro_executions(**parameters, as_query=as_query)
4242
elif query == enums.AdminQueries.FOLDER_MACRO_EXECUTION_SUMMARY:
4343
return __get_folder_macro_execution_summary(**parameters, as_query=as_query)
44+
elif query == enums.AdminQueries.CREATED_TAGS_PER_ORG:
45+
return __get_created_tags_per_org(**parameters, as_query=as_query)
46+
elif query == enums.AdminQueries.CONVERSATIONS_PER_TAG:
47+
return __get_conversations_per_tag(**parameters, as_query=as_query)
48+
elif query == enums.AdminQueries.MULTITAGGED_CONVERSATIONS:
49+
return __get_multitagged_conversations(**parameters, as_query=as_query)
50+
4451
return []
4552

4653

54+
def __get_multitagged_conversations(
    organization_id: str = "",
    without_kern_email: bool = False,
    as_query: bool = False,
):
    """Count, per organization and project, the conversations with more than one tag.

    organization_id: when set, only projects of that organization are considered.
    without_kern_email: when True, conversations created by users whose email
        ends with '@kern.ai' are excluded.
    as_query: when True, the SQL text is returned instead of being executed.
    """
    project_restriction = ""
    if organization_id:
        safe_org_id = prevent_sql_injection(
            organization_id, isinstance(organization_id, str)
        )
        project_restriction = f""" INNER JOIN cognition.project p
            ON c.project_id = p.id AND p.organization_id = '{safe_org_id}'"""

    user_restriction = ""
    if without_kern_email:
        user_restriction = """
        INNER JOIN PUBLIC.user u
            ON c.created_by = u.id AND u.email NOT LIKE '%@kern.ai'"""

    # the inner query keeps only (project, conversation) pairs with 2+ associations
    sql = f"""
    SELECT o.name organization_name, p.name project_name, COUNT(*) conv_with_gr_1_tag
    FROM (
        SELECT c.project_id, conversation_id
        FROM cognition.conversation C
        {user_restriction}
        INNER JOIN cognition.conversation_tag_association cta
            ON c.id = cta.conversation_id
        {project_restriction}
        group BY 1, 2
        HAVING COUNT(*) > 1
    ) x
    INNER JOIN cognition.project p
        ON x.project_id = p.id
    INNER JOIN organization o
        ON p.organization_id = o.id
    group BY 1,2
    """
    return sql if as_query else general.execute_all(sql)
95+
96+
97+
def __get_conversations_per_tag(
    organization_id: str = "",
    without_kern_email: bool = False,
    distinct_conversations: bool = False,
    as_query: bool = False,
):
    """Count tag associations per organization and project.

    organization_id: when set, the result is limited to that organization.
    without_kern_email: when True, conversations created by users whose email
        ends with '@kern.ai' are excluded.
    distinct_conversations: when True, each conversation is counted once
        instead of once per association.
    as_query: when True, the SQL text is returned instead of being executed.
    """
    organization_filter = ""
    if organization_id:
        safe_org_id = prevent_sql_injection(
            organization_id, isinstance(organization_id, str)
        )
        organization_filter = f""" WHERE o.id = '{safe_org_id}'"""

    user_restriction = ""
    if without_kern_email:
        user_restriction = """
        INNER JOIN PUBLIC.user u
            ON c.created_by = u.id AND u.email NOT LIKE '%@kern.ai'"""

    counted_expression = "DISTINCT c.id" if distinct_conversations else "*"

    sql = f"""
    SELECT o.name organization_name, p.name project_name, COUNT({counted_expression}) tags_created
    FROM cognition.conversation_tag_association cta
    INNER JOIN cognition.conversation c
        ON c.id = cta.conversation_id
    {user_restriction}
    INNER JOIN cognition.project p
        ON c.project_id = p.id
    INNER JOIN organization o
        ON p.organization_id = o.id
    {organization_filter}
    GROUP BY 1,2
    """

    return sql if as_query else general.execute_all(sql)
138+
139+
140+
def __get_created_tags_per_org(
    organization_id: str = "",
    without_kern_email: bool = False,
    as_query: bool = False,
):
    """Count the conversation tags created by each organization's users.

    organization_id: when set, the result is limited to that organization.
    without_kern_email: when True, users whose email ends with '@kern.ai'
        are excluded.
    as_query: when True, the SQL text is returned instead of being executed.
    """
    organization_filter = ""
    if organization_id:
        safe_org_id = prevent_sql_injection(
            organization_id, isinstance(organization_id, str)
        )
        organization_filter = f""" WHERE o.id = '{safe_org_id}'"""

    # appended to the user join condition, hence the leading AND
    email_condition = (
        """AND u.email NOT LIKE '%@kern.ai'""" if without_kern_email else ""
    )

    sql = f"""
    SELECT
        o.name organization_name, COUNT(*)
    FROM organization o
    INNER JOIN PUBLIC.user u
        ON o.id = u.organization_id {email_condition}
    INNER JOIN cognition.conversation_tag ct
        ON u.id = ct.created_by
    {organization_filter}
    GROUP BY 1
    ORDER BY 1
    """

    return sql if as_query else general.execute_all(sql)
173+
174+
47175
def __get_folder_macro_execution_summary(
48176
slices: int = 7, # how many chunks are relevant
49177
organization_id: str = "",

0 commit comments

Comments
 (0)