Skip to content

Commit 3846c00

Browse files
committed
moved all SQL text to services/db/client.py
1 parent 15e4a28 commit 3846c00

File tree

7 files changed

+199
-143
lines changed

7 files changed

+199
-143
lines changed

test2text/pages/reports/report_by_req.py

Lines changed: 9 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -66,36 +66,24 @@ def write_annotations(current_annotations: set[tuple]):
6666
distance_sql = ", vec_distance_L2(embedding, ?) AS distance"
6767
distance_order_sql = "distance ASC, "
6868

69-
where_sql = ""
70-
if where_clauses:
71-
where_sql = f"WHERE {' AND '.join(where_clauses)}"
72-
7369
with st.container(border=True):
7470
st.session_state.update({"req_form_submitting": True})
75-
sql = f"""
76-
SELECT
77-
Requirements.id as req_id,
78-
Requirements.external_id as req_external_id,
79-
Requirements.summary as req_summary
80-
{distance_sql}
81-
FROM
82-
Requirements
83-
{where_sql}
84-
ORDER BY
85-
{distance_order_sql}Requirements.id
86-
"""
87-
data = db.conn.execute(
88-
sql, params + [query_embedding_bytes] if distance_sql else params
71+
data = db.get_ordered_values_from_requirements(
72+
distance_sql,
73+
where_clauses,
74+
distance_order_sql,
75+
params + [query_embedding_bytes] if distance_sql else params,
8976
)
77+
9078
if distance_sql:
9179
requirements_dict = {
9280
f"{req_external_id} {summary[:SUMMARY_LENGTH]}... [smart search d={round_distance(distance)}]": req_id
93-
for (req_id, req_external_id, summary, distance) in data.fetchall()
81+
for (req_id, req_external_id, summary, distance) in data
9482
}
9583
else:
9684
requirements_dict = {
9785
f"{req_external_id} {summary[:SUMMARY_LENGTH]}...": req_id
98-
for (req_id, req_external_id, summary) in data.fetchall()
86+
for (req_id, req_external_id, summary) in data
9987
}
10088

10189
st.subheader("Choose 1 of filtered requirements")
@@ -144,39 +132,8 @@ def write_annotations(current_annotations: set[tuple]):
144132
if filter_limit:
145133
params.append(f"{filter_limit}")
146134

147-
where_sql = ""
148-
if where_clauses:
149-
where_sql = f"WHERE {' AND '.join(where_clauses)}"
135+
rows = db.join_all_tables_by_requirements(where_clauses, params)
150136

151-
sql = f"""
152-
SELECT
153-
Requirements.id as req_id,
154-
Requirements.external_id as req_external_id,
155-
Requirements.summary as req_summary,
156-
Requirements.embedding as req_embedding,
157-
158-
Annotations.id as anno_id,
159-
Annotations.summary as anno_summary,
160-
Annotations.embedding as anno_embedding,
161-
162-
AnnotationsToRequirements.cached_distance as distance,
163-
164-
TestCases.id as case_id,
165-
TestCases.test_script as test_script,
166-
TestCases.test_case as test_case
167-
FROM
168-
Requirements
169-
JOIN AnnotationsToRequirements ON Requirements.id = AnnotationsToRequirements.requirement_id
170-
JOIN Annotations ON Annotations.id = AnnotationsToRequirements.annotation_id
171-
JOIN CasesToAnnos ON Annotations.id = CasesToAnnos.annotation_id
172-
JOIN TestCases ON TestCases.id = CasesToAnnos.case_id
173-
{where_sql}
174-
ORDER BY
175-
Requirements.id, AnnotationsToRequirements.cached_distance, TestCases.id
176-
LIMIT ?
177-
"""
178-
data = db.conn.execute(sql, params)
179-
rows = data.fetchall()
180137
if not rows:
181138
st.error(
182139
"There is no requested data to inspect.\n"

test2text/pages/reports/report_by_tc.py

Lines changed: 6 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -66,26 +66,13 @@ def write_requirements(current_requirements: set[tuple]):
6666
distance_sql = ", vec_distance_L2(embedding, ?) AS distance"
6767
distance_order_sql = "distance ASC, "
6868

69-
where_sql = ""
70-
if where_clauses:
71-
where_sql = f"WHERE {' AND '.join(where_clauses)}"
72-
7369
with st.container(border=True):
7470
st.session_state.update({"tc_form_submitting": True})
75-
sql = f"""
76-
SELECT
77-
TestCases.id as case_id,
78-
TestCases.test_script as test_script,
79-
TestCases.test_case as test_case
80-
{distance_sql}
81-
FROM
82-
TestCases
83-
{where_sql}
84-
ORDER BY
85-
{distance_order_sql}TestCases.id
86-
"""
87-
data = db.conn.execute(
88-
sql, params + [query_embedding_bytes] if distance_sql else params
71+
data = db.get_ordered_values_from_test_cases(
72+
distance_sql,
73+
where_clauses,
74+
distance_order_sql,
75+
params + [query_embedding_bytes] if distance_sql else params,
8976
)
9077
if distance_sql:
9178
tc_dict = {
@@ -136,39 +123,8 @@ def write_requirements(current_requirements: set[tuple]):
136123
if filter_limit:
137124
params.append(f"{filter_limit}")
138125

139-
where_sql = ""
140-
if where_clauses:
141-
where_sql = f"WHERE {' AND '.join(where_clauses)}"
126+
rows = db.join_all_tables_by_test_cases(where_clauses, params)
142127

143-
sql = f"""
144-
SELECT
145-
TestCases.id as case_id,
146-
TestCases.test_script as test_script,
147-
TestCases.test_case as test_case,
148-
149-
Annotations.id as anno_id,
150-
Annotations.summary as anno_summary,
151-
Annotations.embedding as anno_embedding,
152-
153-
AnnotationsToRequirements.cached_distance as distance,
154-
155-
Requirements.id as req_id,
156-
Requirements.external_id as req_external_id,
157-
Requirements.summary as req_summary,
158-
Requirements.embedding as req_embedding
159-
FROM
160-
TestCases
161-
JOIN CasesToAnnos ON TestCases.id = CasesToAnnos.case_id
162-
JOIN Annotations ON Annotations.id = CasesToAnnos.annotation_id
163-
JOIN AnnotationsToRequirements ON Annotations.id = AnnotationsToRequirements.annotation_id
164-
JOIN Requirements ON Requirements.id = AnnotationsToRequirements.requirement_id
165-
{where_sql}
166-
ORDER BY
167-
case_id, distance, req_id
168-
LIMIT ?
169-
"""
170-
data = db.conn.execute(sql, params)
171-
rows = data.fetchall()
172128
if not rows:
173129
st.error(
174130
"There is no requested data to inspect.\n"

test2text/services/db/client.py

Lines changed: 173 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,10 @@ def get_table_names(self):
8888
cursor.close()
8989
return tables
9090

91+
def get_column_values(self, columns: list[str], from_table: str):
    """
    Reads the requested columns from every row of a table.

    The column list and the table name are inserted into the SQL text as-is,
    so they must come from trusted code, never from user input.

    :param columns: names of the columns to select
    :param from_table: name of the table to read from
    :return: list of row tuples holding one value per requested column
    """
    selected = ", ".join(columns)
    query = f"SELECT {selected} FROM {from_table}"
    return self.conn.execute(query).fetchall()
94+
9195
@property
9296
def get_db_full_info(self):
9397
"""
@@ -100,27 +104,25 @@ def get_db_full_info(self):
100104
db_tables_info = {}
101105
table_names = self.get_table_names()
102106
for table_name in table_names:
103-
row_count = self.count_all_entries_in_table(table_name)
107+
row_count = self.count_all_entries(table_name)
104108
db_tables_info.update(
105109
{
106110
table_name: row_count,
107111
}
108112
)
109113
return db_tables_info
110114

111-
def count_all_entries(self, from_table: str) -> int:
    """
    Counts every row stored in the given table.

    :param from_table: name of the table (inserted into the SQL text as-is)
    :return: total number of rows in the table
    """
    # COUNT(*) always yields exactly one single-column row.
    (total,) = self.conn.execute(f"SELECT COUNT(*) FROM {from_table}").fetchone()
    return total
114118

115-
def count_notnull_entries(
    self, columns: list[str], from_table: str
) -> Union[int, None]:
    """
    Counts the rows of a table in which every listed column is NOT NULL.

    Column and table names are inserted into the SQL text as-is, so they must
    come from trusted code, never from user input.

    :param columns: names of the columns that must all be non-null for a row to be counted
    :param from_table: name of the table to count rows in
    :return: number of matching rows
    """
    # Each column needs its own IS NOT NULL test: joining the names with commas
    # ("a, b IS NOT NULL") is not valid SQL as soon as there is more than one column.
    condition = " AND ".join(f"{column} IS NOT NULL" for column in columns)
    count = self.conn.execute(
        f"SELECT COUNT(*) FROM {from_table} WHERE {condition}"
    ).fetchone()[0]
    return count
124126

125127
def has_column(self, column_name: str, table_name: str) -> bool:
126128
"""
@@ -134,3 +136,163 @@ def has_column(self, column_name: str, table_name: str) -> bool:
134136
columns = [row[1] for row in cursor.fetchall()] # row[1] is the column name
135137
cursor.close()
136138
return column_name in columns
139+
140+
def get_null_entries(self, from_table: str) -> list:
    """
    Returns the (id, summary) pairs of all rows in the given table whose embedding is NULL.

    :param from_table: name of the table (inserted into the SQL text as-is)
    :return: list of (id, summary) tuples
    """
    query = f"SELECT id, summary FROM {from_table} WHERE embedding IS NULL"
    return self.conn.execute(query).fetchall()
145+
146+
def get_distances(self) -> list[tuple[int, int, float]]:
    """
    Pairs every annotation with every requirement and computes the distance between
    their embeddings with the SQL function vec_distance_L2.
    Rows whose embedding is NULL on either side are left out.

    :return: (anno_id, req_id, distance) tuples sorted by requirement ID, then by distance
    """
    pairwise_distance_sql = """
        SELECT
            Annotations.id AS anno_id,
            Requirements.id AS req_id,
            vec_distance_L2(Annotations.embedding, Requirements.embedding) AS distance
        FROM Annotations, Requirements
        WHERE Annotations.embedding IS NOT NULL AND Requirements.embedding IS NOT NULL
        ORDER BY req_id, distance
    """
    return self.conn.execute(pairwise_distance_sql).fetchall()
162+
163+
def get_embeddings_from_annotations_to_requirements_table(self) -> list[tuple]:
    """
    Returns the embeddings of all annotations that are referenced in the AnnotationsToRequirements table.
    The embeddings are ordered by annotation ID.

    :return: list of one-element tuples (embedding,)
    """
    # Without an explicit ORDER BY the row order is not guaranteed, so the
    # ordering promised above has to be requested in the query itself.
    cursor = self.conn.execute("""
        SELECT embedding FROM Annotations
        WHERE id IN (
            SELECT DISTINCT annotation_id FROM AnnotationsToRequirements
        )
        ORDER BY id
    """)
    return cursor.fetchall()
175+
176+
def join_all_tables_by_requirements(
    self, where_clauses="", params=None
) -> list[tuple]:
    """
    Join all tables related to requirements based on the provided where clauses and parameters.

    :param where_clauses: SQL conditions that are joined with AND into the WHERE clause;
        any falsy value means "no filter". The conditions are inserted into the SQL text
        as-is, so values belong in params, not in the clause text.
    :param params: values bound to the "?" placeholders: first those of where_clauses,
        then one last value for "LIMIT ?". When params is None or empty, -1 is bound
        to the limit, which SQLite treats as "no upper bound".
    :return: a list of tuples ordered by requirement ID, cached distance and test case ID,
        each containing:
        req_id,
        req_external_id,
        req_summary,
        req_embedding,
        anno_id,
        anno_summary,
        anno_embedding,
        distance,
        case_id,
        test_script,
        test_case
    """
    where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
    # "LIMIT ?" always needs a binding, and sqlite3 rejects None as a parameter
    # container, so the declared default has to be turned into a usable value.
    bindings = params if params else [-1]
    sql = f"""
        SELECT
            Requirements.id as req_id,
            Requirements.external_id as req_external_id,
            Requirements.summary as req_summary,
            Requirements.embedding as req_embedding,

            Annotations.id as anno_id,
            Annotations.summary as anno_summary,
            Annotations.embedding as anno_embedding,

            AnnotationsToRequirements.cached_distance as distance,

            TestCases.id as case_id,
            TestCases.test_script as test_script,
            TestCases.test_case as test_case
        FROM
            Requirements
                JOIN AnnotationsToRequirements ON Requirements.id = AnnotationsToRequirements.requirement_id
                JOIN Annotations ON Annotations.id = AnnotationsToRequirements.annotation_id
                JOIN CasesToAnnos ON Annotations.id = CasesToAnnos.annotation_id
                JOIN TestCases ON TestCases.id = CasesToAnnos.case_id
        {where_sql}
        ORDER BY
            Requirements.id, AnnotationsToRequirements.cached_distance, TestCases.id
        LIMIT ?
    """
    data = self.conn.execute(sql, bindings)
    return data.fetchall()
224+
225+
def get_ordered_values_from_requirements(
    self, distance_sql="", where_clauses="", distance_order_sql="", params=None
) -> list[tuple]:
    """
    Selects (req_id, req_external_id, req_summary) from Requirements, optionally with an
    extra distance column, filtered by the given conditions and ordered by requirement ID.

    All SQL fragments are inserted into the query text as-is, so they must come from
    trusted code; user-supplied values belong in params.

    :param distance_sql: SQL appended to the SELECT list (e.g. ", ... AS distance"); empty for none
    :param where_clauses: conditions joined with AND into the WHERE clause; falsy for no filter
    :param distance_order_sql: SQL placed in front of "Requirements.id" in ORDER BY
        (it has to end with ", "); empty for none
    :param params: values bound to the "?" placeholders in the order they appear in the
        final query: those of distance_sql first, then those of where_clauses.
        None means there is nothing to bind.
    :return: list of row tuples; each row has one more trailing value when distance_sql adds a column
    """
    where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
    # sqlite3 rejects None as a parameter container, so the declared default
    # has to be replaced by an empty sequence before executing.
    bindings = () if params is None else params
    sql = f"""
        SELECT
            Requirements.id as req_id,
            Requirements.external_id as req_external_id,
            Requirements.summary as req_summary
            {distance_sql}
        FROM
            Requirements
        {where_sql}
        ORDER BY
            {distance_order_sql}Requirements.id
    """
    data = self.conn.execute(sql, bindings)
    return data.fetchall()
243+
244+
def get_ordered_values_from_test_cases(
    self, distance_sql="", where_clauses="", distance_order_sql="", params=None
) -> list[tuple]:
    """
    Selects (case_id, test_script, test_case) from TestCases, optionally with an extra
    distance column, filtered by the given conditions and ordered by test case ID.

    All SQL fragments are inserted into the query text as-is, so they must come from
    trusted code; user-supplied values belong in params.

    :param distance_sql: SQL appended to the SELECT list (e.g. ", ... AS distance"); empty for none
    :param where_clauses: conditions joined with AND into the WHERE clause; falsy for no filter
    :param distance_order_sql: SQL placed in front of "TestCases.id" in ORDER BY
        (it has to end with ", "); empty for none
    :param params: values bound to the "?" placeholders in the order they appear in the
        final query: those of distance_sql first, then those of where_clauses.
        None means there is nothing to bind.
    :return: list of row tuples; each row has one more trailing value when distance_sql adds a column
    """
    where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
    # sqlite3 rejects None as a parameter container, so the declared default
    # has to be replaced by an empty sequence before executing.
    bindings = () if params is None else params
    sql = f"""
        SELECT
            TestCases.id as case_id,
            TestCases.test_script as test_script,
            TestCases.test_case as test_case
            {distance_sql}
        FROM
            TestCases
        {where_sql}
        ORDER BY
            {distance_order_sql}TestCases.id
    """
    data = self.conn.execute(sql, bindings)
    return data.fetchall()
262+
263+
def join_all_tables_by_test_cases(
    self, where_clauses="", params=None
) -> list[tuple]:
    """
    Join all tables related to test cases based on the provided where clauses and parameters.

    :param where_clauses: SQL conditions that are joined with AND into the WHERE clause;
        any falsy value means "no filter". The conditions are inserted into the SQL text
        as-is, so values belong in params, not in the clause text.
    :param params: values bound to the "?" placeholders: first those of where_clauses,
        then one last value for "LIMIT ?". When params is None or empty, -1 is bound
        to the limit, which SQLite treats as "no upper bound".
    :return: a list of tuples ordered by test case ID, distance and requirement ID,
        each containing:
        case_id,
        test_script,
        test_case,
        anno_id,
        anno_summary,
        anno_embedding,
        distance,
        req_id,
        req_external_id,
        req_summary,
        req_embedding
    """
    where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
    # "LIMIT ?" always needs a binding, and sqlite3 rejects None as a parameter
    # container, so the declared default has to be turned into a usable value.
    bindings = params if params else [-1]
    sql = f"""
        SELECT
            TestCases.id as case_id,
            TestCases.test_script as test_script,
            TestCases.test_case as test_case,

            Annotations.id as anno_id,
            Annotations.summary as anno_summary,
            Annotations.embedding as anno_embedding,

            AnnotationsToRequirements.cached_distance as distance,

            Requirements.id as req_id,
            Requirements.external_id as req_external_id,
            Requirements.summary as req_summary,
            Requirements.embedding as req_embedding
        FROM
            TestCases
                JOIN CasesToAnnos ON TestCases.id = CasesToAnnos.case_id
                JOIN Annotations ON Annotations.id = CasesToAnnos.annotation_id
                JOIN AnnotationsToRequirements ON Annotations.id = AnnotationsToRequirements.annotation_id
                JOIN Requirements ON Requirements.id = AnnotationsToRequirements.requirement_id
        {where_sql}
        ORDER BY
            case_id, distance, req_id
        LIMIT ?
    """
    data = self.conn.execute(sql, bindings)
    return data.fetchall()

test2text/services/embeddings/annotation_embeddings_controls.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,9 @@ def write_batch(batch: list[tuple[int, str]]):
3030
db.annotations.set_embedding(anno_id, embedding)
3131
db.conn.commit()
3232

33-
annotations = db.conn.execute(f"""
34-
SELECT id, summary FROM Annotations
35-
{"WHERE embedding IS NULL" if not embed_all else ""}
36-
""")
33+
annotations = db.get_null_entries(from_table="Annotations")
3734

38-
for i, (anno_id, summary) in enumerate(annotations.fetchall()):
35+
for i, (anno_id, summary) in enumerate(annotations):
3936
if on_progress:
4037
on_progress((i + 1) / annotations_to_embed)
4138
batch.append((anno_id, summary))

0 commit comments

Comments
 (0)