Skip to content

Commit 15df3b6

Browse files
Rework Factories (#149)
* add rename mappine * blacklist * PR * recursive renaming * PR comments * PR comments * PR * empty conversations
1 parent 29ec945 commit 15df3b6

File tree

3 files changed

+102
-7
lines changed

3 files changed

+102
-7
lines changed

cognition_objects/message.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,34 @@ def get_all_by_conversation_id(
2121
)
2222

2323

24+
def get_all_by_conversation_ids(
25+
project_id: str, conversation_ids: List[str]
26+
) -> Dict[str, List[CognitionMessage]]:
27+
28+
project_id = prevent_sql_injection(project_id, isinstance(project_id, str))
29+
conversation_ids = [
30+
prevent_sql_injection(conversation_id, isinstance(conversation_id, str))
31+
for conversation_id in conversation_ids
32+
]
33+
conversation_where = (
34+
" AND conversation_id IN ('" + "','".join(conversation_ids) + "')"
35+
)
36+
query = f"""
37+
SELECT jsonb_object_agg(conversation_id, messages)
38+
FROM (
39+
SELECT m.conversation_id, array_agg(row_to_json(m) ORDER BY created_at ASC) AS messages
40+
FROM cognition.message m
41+
WHERE project_id = '{project_id}'{conversation_where}
42+
GROUP BY conversation_id
43+
) x
44+
"""
45+
46+
message_info = general.execute_first(query)
47+
if message_info and message_info[0]:
48+
return message_info[0]
49+
return {}
50+
51+
2452
def get_last_by_conversation_id(
2553
project_id: str, conversation_id: str
2654
) -> CognitionMessage:

cognition_objects/pipeline_log.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from ..models import CognitionPipelineLogs, CognitionMessage
55
from datetime import datetime
66
from .. import enums
7+
from ..util import prevent_sql_injection
78

89

910
def get_all_by_message_id(
@@ -220,3 +221,33 @@ def update_to_new_diff_structure(
220221

221222
if with_commit:
222223
general.commit()
224+
225+
226+
def get_error_and_time_elapsed_by_conversation_ids(
227+
project_id: str,
228+
conversation_ids: List[str],
229+
) -> Dict[str, CognitionPipelineLogs]:
230+
if not conversation_ids:
231+
return {}
232+
project_id = prevent_sql_injection(project_id, isinstance(project_id, str))
233+
conversation_ids = [
234+
prevent_sql_injection(conversation_id, isinstance(conversation_id, str))
235+
for conversation_id in conversation_ids
236+
]
237+
conversation_where = (
238+
" AND m.conversation_id IN ('" + "','".join(conversation_ids) + "')"
239+
)
240+
query = f"""
241+
SELECT jsonb_object_agg(id, jsonb_build_object('logs_have_error', has_error, 'time_logs_elapsed', time_elapsed))
242+
FROM (
243+
SELECT m.id, SUM(pl.has_error::INT) > 0 has_error,SUM(pl.time_elapsed) time_elapsed
244+
FROM cognition.message m
245+
INNER JOIN cognition.pipeline_logs pl
246+
ON m.project_id = pl.project_id AND m.id = pl.message_id
247+
WHERE m.project_Id = '{project_id}'{conversation_where}
248+
GROUP BY m.id
249+
)x"""
250+
conversation_info = general.execute_first(query)
251+
if conversation_info and conversation_info[0]:
252+
return conversation_info[0]
253+
return {}

util.py

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,12 @@ def sql_alchemy_to_dict(
9696
sql_alchemy_object: Any,
9797
for_frontend: bool = False,
9898
column_whitelist: Optional[Iterable[str]] = None,
99+
column_blacklist: Optional[Iterable[str]] = None,
100+
column_rename_map: Optional[Dict[str, str]] = None,
99101
):
100-
result = __sql_alchemy_to_dict(sql_alchemy_object, column_whitelist)
102+
result = __sql_alchemy_to_dict(
103+
sql_alchemy_object, column_whitelist, column_blacklist, column_rename_map
104+
)
101105
if for_frontend:
102106
return to_frontend_obj(result)
103107
return result
@@ -136,26 +140,58 @@ def convert_value(value, max_lvl: int):
136140

137141

138142
def __sql_alchemy_to_dict(
139-
sql_alchemy_object: Any, column_whitelist: Optional[Iterable[str]] = None
143+
sql_alchemy_object: Any,
144+
column_whitelist: Optional[Iterable[str]] = None,
145+
column_blacklist: Optional[Iterable[str]] = None,
146+
column_rename_map: Optional[Dict[str, str]] = None,
140147
):
148+
def rename_columns(data: Any) -> Any:
149+
if column_rename_map:
150+
if isinstance(data, dict):
151+
data = {
152+
column_rename_map.get(k, k): rename_columns(v)
153+
for k, v in data.items()
154+
}
155+
elif isinstance(data, list):
156+
data = [rename_columns(item) for item in data]
157+
return data
158+
141159
if isinstance(sql_alchemy_object, list):
142160
# list is for all() queries
143-
return [__sql_alchemy_to_dict(x, column_whitelist) for x in sql_alchemy_object]
161+
return [
162+
__sql_alchemy_to_dict(
163+
x, column_whitelist, column_blacklist, column_rename_map
164+
)
165+
for x in sql_alchemy_object
166+
]
144167

145168
elif isinstance(sql_alchemy_object, Row):
146169
# basic SELECT .. FROM query)
147170
# _mapping is a RowMapping object that is not serializable but dict like
148-
return {
171+
result = {
149172
k: v
150173
for k, v in dict(sql_alchemy_object._mapping).items()
151-
if not column_whitelist or k in column_whitelist
174+
if (not column_whitelist or k in column_whitelist)
175+
and (not column_blacklist or k not in column_blacklist)
152176
}
177+
return rename_columns(result)
153178
elif isinstance(sql_alchemy_object, Base):
154-
return {
179+
result = {
155180
c.name: getattr(sql_alchemy_object, c.name)
156181
for c in sql_alchemy_object.__table__.columns
157-
if not column_whitelist or c.name in column_whitelist
182+
if (not column_whitelist or c.name in column_whitelist)
183+
and (not column_blacklist or c.name not in column_blacklist)
184+
}
185+
return rename_columns(result)
186+
elif isinstance(sql_alchemy_object, dict):
187+
result = {
188+
k: v
189+
for k, v in sql_alchemy_object.items()
190+
if (not column_whitelist or k in column_whitelist)
191+
and (not column_blacklist or k not in column_blacklist)
158192
}
193+
return rename_columns(result)
194+
159195
else:
160196
return sql_alchemy_object
161197

0 commit comments

Comments
 (0)