Skip to content

Commit cf337c2

Browse files
JWittmeyerLennartSchmidtKernandhreljaKern
authored
Removal of GraphQL Data Return Schema (#150)
* Rework of lt query * fix full record * remove edge helpter * org fix * perf: update overview-stats query perf: eliminate payload transform in refinery-ui * perf: add extended embeddings query fix: add float cast on decimal.Decimal data types * fix: prevent_sql_injection for project_id * perf: update extended embeddings query perf: add columns whitelist caching logic * fix: add typing to new fn --------- Co-authored-by: LennartSchmidtKern <[email protected]> Co-authored-by: andhreljaKern <[email protected]>
1 parent 15df3b6 commit cf337c2

File tree

4 files changed

+146
-96
lines changed

4 files changed

+146
-96
lines changed

business_objects/embedding.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,23 @@
1111
from ..util import prevent_sql_injection
1212

1313

14+
ALL_EMBEDDINGS_WHITELIST = {
15+
"id",
16+
"name",
17+
"custom",
18+
"type",
19+
"state",
20+
"progress",
21+
"dimension",
22+
"count",
23+
"platform",
24+
"model",
25+
"filter_attributes",
26+
"attribute_id",
27+
}
28+
EMBEDDINGS_WHITELIST_COLUMNS_STRING = None
29+
30+
1431
def get(project_id: str, embedding_id: str) -> Embedding:
1532
return (
1633
session.query(Embedding)
@@ -104,6 +121,62 @@ def get_all_embeddings_by_project_id(project_id: str) -> List[Embedding]:
104121
return session.query(Embedding).filter(Embedding.project_id == project_id).all()
105122

106123

124+
def get_all_embeddings_by_project_id_extended(project_id: str) -> List[Dict[str, Any]]:
125+
project_id = prevent_sql_injection(project_id, isinstance(project_id, str))
126+
query = __get_all_embeddings_by_project_id_extended_query(project_id)
127+
return general.execute_all(query)
128+
129+
130+
def __get_embedding_whitelist_columns_string() -> str:
131+
global EMBEDDINGS_WHITELIST_COLUMNS_STRING
132+
if EMBEDDINGS_WHITELIST_COLUMNS_STRING is None:
133+
EMBEDDINGS_WHITELIST_COLUMNS_STRING = general.construct_select_columns(
134+
Embedding.__tablename__,
135+
prefix="e",
136+
include_columns=ALL_EMBEDDINGS_WHITELIST,
137+
)
138+
return EMBEDDINGS_WHITELIST_COLUMNS_STRING
139+
140+
141+
def __get_all_embeddings_by_project_id_extended_query(project_id: str) -> str:
142+
return f"""
143+
WITH num_recs AS (
144+
SELECT COUNT(r.*) number_records
145+
FROM record r
146+
WHERE r.project_id = '{project_id}'
147+
), e AS (
148+
SELECT
149+
{__get_embedding_whitelist_columns_string()},
150+
nr.number_records,
151+
(SELECT COUNT(et.*) FROM embedding_tensor et WHERE et.embedding_id = e.id) tensor_count
152+
FROM embedding e, num_recs nr
153+
WHERE e.project_id = '{project_id}'
154+
)
155+
SELECT
156+
{__get_embedding_whitelist_columns_string()},
157+
CASE
158+
WHEN e."type" = '{enums.EmbeddingType.ON_ATTRIBUTE.value}' THEN (
159+
SELECT json_array_length(et."data")
160+
FROM embedding_tensor et
161+
WHERE et.embedding_id = e.id
162+
LIMIT 1)
163+
WHEN e."type" = '{enums.EmbeddingType.ON_TOKEN.value}' THEN (
164+
SELECT json_array_length(et."data"->0)
165+
FROM embedding_tensor et
166+
WHERE et.embedding_id = e.id
167+
LIMIT 1)
168+
END dimension,
169+
CASE
170+
WHEN e."state" = '{enums.EmbeddingState.FINISHED.value}' THEN
171+
1.
172+
WHEN e."state" IN ('{enums.EmbeddingState.INITIALIZING.value}', '{enums.EmbeddingState.WAITING.value}') THEN
173+
0.
174+
ELSE LEAST(0.1 + (e.tensor_count / (e.number_records * 0.9)), 0.99)
175+
END progress
176+
FROM e
177+
"""
178+
179+
107180
def get_finished_embeddings(project_id: str) -> List[Embedding]:
108181
return (
109182
session.query(Embedding)

business_objects/labeling_task.py

Lines changed: 29 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -70,48 +70,47 @@ def get_task_and_label_by_ids_and_type(
7070
def get_labeling_tasks_by_project_id_full(project_id: str) -> Row:
7171
project_id = prevent_sql_injection(project_id, isinstance(project_id, str))
7272
query = f"""
73-
WITH attribute_select AS (
74-
SELECT id, jsonb_build_object('id',id,'name', NAME,'relative_position', relative_position, 'data_type', data_Type) a_data
73+
WITH attribute_select AS (
74+
SELECT id, a.NAME, relative_position
7575
FROM attribute a
7676
WHERE project_id = '{project_id}'
7777
),
78-
label_select AS (
79-
SELECT labeling_Task_id, jsonb_build_object('edges',array_agg(jsonb_build_object('node',jsonb_build_object('id',id,'name', NAME,'color', color, 'hotkey', hotkey)))) l_data
78+
label_select AS (
79+
SELECT labeling_Task_id, array_agg(jsonb_build_object('id',id,'name', NAME,'color', color, 'hotkey', hotkey)) l_data
8080
FROM labeling_task_label ltl
8181
WHERE project_id = '{project_id}'
8282
GROUP BY 1
83-
),
83+
),
8484
is_select AS (
85-
SELECT labeling_task_id, jsonb_build_object('edges',array_agg(jsonb_build_object('node',jsonb_build_object('id',id,'type', type,'return_type', return_type, 'description', description,'name',NAME)))) i_data
85+
SELECT labeling_task_id, array_agg(jsonb_build_object('id',id,'type', type,'return_type', return_type, 'description', description,'name',NAME)) i_data
8686
FROM information_source _is
8787
WHERE project_id = '{project_id}'
8888
GROUP BY 1
8989
)
9090
91-
SELECT
92-
'{project_id}' id,
93-
jsonb_build_object('edges',array_agg(jsonb_build_object('node', lt_data))) labeling_tasks
94-
FROM (
95-
SELECT
96-
jsonb_build_object(
97-
'id',lt.id,
98-
'name', NAME,
99-
'task_target', task_target,
100-
'task_type', task_type,
101-
'attribute',a.a_data,
102-
'labels',COALESCE(l.l_data,jsonb_build_object('edges',ARRAY[]::jsonb[])),
103-
'information_sources',COALESCE(i.i_data,jsonb_build_object('edges',ARRAY[]::jsonb[]))
104-
) lt_data
105-
FROM labeling_task lt
106-
LEFT JOIN attribute_select a
107-
ON lt.attribute_id = a.id
108-
LEFT JOIN label_select l
109-
ON l.labeling_Task_id = lt.id
110-
LEFT JOIN is_select i
111-
ON i.labeling_task_id = lt.id
112-
WHERE project_id = '{project_id}'
113-
) x """
114-
return general.execute_first(query)
91+
SELECT array_agg(
92+
jsonb_build_object(
93+
'id',lt.id,
94+
'name', lt.NAME,
95+
'task_target', task_target,
96+
'task_type', task_type,
97+
'target_id',CASE WHEN lt.task_target = '{enums.LabelingTaskTarget.ON_ATTRIBUTE.value}' THEN a.id::TEXT ELSE '' END,
98+
'target_name',CASE WHEN lt.task_target = '{enums.LabelingTaskTarget.ON_ATTRIBUTE.value}' THEN a.name ELSE 'Full Record' END,
99+
'labels',COALESCE(l.l_data,ARRAY[]::jsonb[]),
100+
'information_sources',COALESCE(i.i_data,ARRAY[]::jsonb[])
101+
) ORDER BY a.relative_position, a.name) lt_data
102+
FROM labeling_task lt
103+
LEFT JOIN attribute_select a
104+
ON lt.attribute_id = a.id
105+
LEFT JOIN label_select l
106+
ON l.labeling_Task_id = lt.id
107+
LEFT JOIN is_select i
108+
ON i.labeling_task_id = lt.id
109+
WHERE project_id = '{project_id}' """
110+
values = general.execute_first(query)
111+
if values and values[0]:
112+
return values[0]
113+
return []
115114

116115

117116
def get_task_name_id_dict(project_id: str) -> Dict[str, str]:

business_objects/organization.py

Lines changed: 41 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,9 @@ def get_organization_overview_stats(
5858
values = general.execute_first(
5959
__get_organization_overview_stats_query(organization_id)
6060
)
61-
if values:
61+
if values and values[0]:
6262
return values[0]
63+
return []
6364

6465

6566
def get_user_count(organization_id: str) -> int:
@@ -69,41 +70,47 @@ def get_user_count(organization_id: str) -> int:
6970
def __get_organization_overview_stats_query(organization_id: str):
7071
return f"""
7172
WITH labeled_records AS (
72-
SELECT project_id, source_type, COUNT(*) source_records
73-
FROM (
74-
SELECT rla.project_id, rla.record_id, rla.source_type
75-
FROM record r
76-
INNER JOIN record_label_association rla
77-
ON r.project_id = rla.project_id AND r.id = rla.record_id AND r.category = '{enums.RecordCategory.SCALE.value}'
78-
INNER JOIN project p
79-
ON rla.project_id = p.id
80-
WHERE p.organization_id = '{organization_id}'
81-
AND rla.source_type IN ('{enums.LabelSource.MANUAL.value}', '{enums.LabelSource.WEAK_SUPERVISION.value}')
82-
GROUP BY rla.project_id, rla.record_id, rla.source_type
83-
) r_reduction
84-
GROUP BY project_id, source_type)
85-
86-
SELECT array_agg(row_to_json(x))
87-
FROM (
88-
SELECT
89-
base.project_id "projectId",
90-
base.base_count "numDataScaleUploaded",
91-
COALESCE(lr_m.source_records,0) "numDataScaleManual",
92-
COALESCE(lr_w.source_records,0) "numDataScaleProgrammatical"
73+
SELECT project_id, source_type, COUNT(*) source_records
9374
FROM (
94-
SELECT r.project_id, COUNT(*) base_count
95-
FROM project p
96-
LEFT JOIN record r
97-
ON r.project_id = p.id
75+
SELECT rla.project_id, rla.record_id, rla.source_type
76+
FROM record r
77+
INNER JOIN record_label_association rla
78+
ON r.project_id = rla.project_id AND r.id = rla.record_id AND r.category = '{enums.RecordCategory.SCALE.value}'
79+
INNER JOIN project p
80+
ON rla.project_id = p.id
9881
WHERE p.organization_id = '{organization_id}'
99-
AND p."status" != '{enums.ProjectStatus.IN_DELETION.value}'
100-
AND r.category = '{enums.RecordCategory.SCALE.value}'
101-
GROUP BY r.project_id
102-
) base
103-
LEFT JOIN labeled_records lr_m
104-
ON base.project_id = lr_m.project_id AND lr_m.source_type = '{enums.LabelSource.MANUAL.value}'
105-
LEFT JOIN labeled_records lr_w
106-
ON base.project_id = lr_w.project_id AND lr_w.source_type = '{enums.LabelSource.WEAK_SUPERVISION.value}' )x
82+
AND rla.source_type IN ('{enums.LabelSource.MANUAL.value}', '{enums.LabelSource.WEAK_SUPERVISION.value}')
83+
GROUP BY rla.project_id, rla.record_id, rla.source_type
84+
) r_reduction
85+
GROUP BY project_id, source_type
86+
) SELECT jsonb_object_agg(x."projectId", row_to_json(x))
87+
FROM (
88+
SELECT
89+
*,
90+
TRIM_SCALE(ROUND(("numDataScaleManual" * 100. / "numDataScaleUploaded")::numeric, 2)) || ' %' "manuallyLabeled",
91+
TRIM_SCALE(ROUND(("numDataScaleProgrammatical" * 100. / "numDataScaleUploaded")::numeric, 2)) || ' %' "weaklySupervised"
92+
FROM (
93+
SELECT
94+
base.project_id "projectId",
95+
base.base_count "numDataScaleUploaded",
96+
COALESCE(lr_m.source_records,0) "numDataScaleManual",
97+
COALESCE(lr_w.source_records,0) "numDataScaleProgrammatical"
98+
FROM (
99+
SELECT r.project_id, COUNT(*) base_count
100+
FROM project p
101+
LEFT JOIN record r
102+
ON r.project_id = p.id
103+
WHERE p.organization_id = '{organization_id}'
104+
AND p."status" != '{enums.ProjectStatus.IN_DELETION.value}'
105+
AND r.category = '{enums.RecordCategory.SCALE.value}'
106+
GROUP BY r.project_id
107+
) base
108+
LEFT JOIN labeled_records lr_m
109+
ON base.project_id = lr_m.project_id AND lr_m.source_type = '{enums.LabelSource.MANUAL.value}'
110+
LEFT JOIN labeled_records lr_w
111+
ON base.project_id = lr_w.project_id AND lr_w.source_type = '{enums.LabelSource.WEAK_SUPERVISION.value}'
112+
) y
113+
) x
107114
"""
108115

109116

util.py

Lines changed: 3 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from collections.abc import Iterable as collections_abc_Iterable
55
from re import sub, match, compile
66
import sqlalchemy
7+
import decimal
78
from uuid import UUID
89
from datetime import datetime
910

@@ -107,38 +108,6 @@ def sql_alchemy_to_dict(
107108
return result
108109

109110

110-
def pack_edges_node(result, name: str, max_lvl: Optional[int] = None):
111-
112-
def convert_value(value, max_lvl: int):
113-
new_lvl = max_lvl - 1 if max_lvl is not None else None
114-
if isinstance(value, list):
115-
return {
116-
"edges": [
117-
{
118-
"node": (
119-
convert_value(item, new_lvl)
120-
if max_lvl is None or max_lvl > 0
121-
else item
122-
)
123-
}
124-
for item in value
125-
]
126-
}
127-
elif isinstance(value, dict):
128-
return {
129-
key: (
130-
convert_value(val, new_lvl)
131-
if max_lvl is None or max_lvl > 0
132-
else val
133-
)
134-
for key, val in value.items()
135-
}
136-
else:
137-
return value
138-
139-
return {"data": {name: convert_value(result, max_lvl)}}
140-
141-
142111
def __sql_alchemy_to_dict(
143112
sql_alchemy_object: Any,
144113
column_whitelist: Optional[Iterable[str]] = None,
@@ -217,6 +186,8 @@ def to_frontend_obj_raw(value: Union[List, Dict]):
217186
def to_json_serializable(x: Any):
218187
if isinstance(x, datetime):
219188
return x.isoformat()
189+
elif isinstance(x, decimal.Decimal):
190+
return float(x)
220191
elif isinstance(x, UUID):
221192
return str(x)
222193
else:

0 commit comments

Comments
 (0)