Skip to content

Commit c6e82fd

Browse files
committed
[Req report] refactor requirements filtering
1 parent d8b928e commit c6e82fd

File tree

5 files changed

+186
-205
lines changed

5 files changed

+186
-205
lines changed

test2text/pages/reports/report_by_req.py

Lines changed: 144 additions & 178 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,150 @@
11
from itertools import groupby
22
import numpy as np
33
import streamlit as st
4-
from sqlite_vec import serialize_float32
54

65
from test2text.services.utils.math_utils import round_distance
6+
from test2text.services.repositories import requirements
77

88
SUMMARY_LENGTH = 100
99
LABELS_SUMMARY_LENGTH = 15
1010

11+
def display_found_details(data: list):
    """Render the inspection view for each requirement contained in ``data``.

    Rows are grouped with ``itertools.groupby`` on their first four fields
    (requirement id, external id, summary, embedding), so rows of the same
    requirement have to be adjacent in ``data``. Every row is unpacked into
    exactly eleven fields; the last one is the test case name used to bucket
    the annotations.

    For every requirement three columns are drawn: a radio list of its test
    cases, the annotations of the chosen test case, and a 2D or 3D plot of
    the requirement embedding against those annotation embeddings.
    """
    from test2text.services.utils import unpack_float32
    from test2text.services.visualisation.visualize_vectors import (
        minifold_vectors_2d,
        plot_2_sets_in_one_2d,
        minifold_vectors_3d,
        plot_2_sets_in_one_3d,
    )

    def write_annotations(current_annotations: set[tuple]):
        """Write a header line followed by one line per annotation."""
        st.write("id,", "Summary,", "Distance")
        for entry in current_annotations:
            anno_id, anno_summary, _, distance = entry
            st.write(anno_id, anno_summary, round_distance(distance))

    for requirement_key, requirement_rows in groupby(data, lambda row: row[0:4]):
        req_id, req_external_id, req_summary, req_embedding = requirement_key
        st.divider()
        with st.container():
            st.subheader(f" Inspect Requirement {req_external_id}")
            st.write(req_summary)

            # Bucket annotations by test case name; a set drops exact duplicates.
            current_test_cases = {}
            for row in requirement_rows:
                (
                    _,
                    _,
                    _,
                    _,
                    anno_id,
                    anno_summary,
                    anno_embedding,
                    distance,
                    _case_id,
                    _test_script,
                    test_case,
                ) = row
                current_test_cases.setdefault(test_case, set()).add(
                    (anno_id, anno_summary, anno_embedding, distance)
                )

            t_cs, anno, viz = st.columns(3)
            with t_cs:
                with st.container(border=True):
                    st.write("Test Cases")
                    st.info("Test cases of chosen Requirement")
                    st.radio(
                        "Test cases name",
                        current_test_cases.keys(),
                        key="radio_choice",
                    )
                    # Let long test case names wrap inside the narrow column.
                    st.markdown(
                        """
                        <style>
                        .stRadio > div {
                            max-width: 100%;
                            word-break: break-word;
                            white-space: normal;
                        }
                        </style>
                        """,
                        unsafe_allow_html=True,
                    )

            # The radio widget publishes its selection under this session key.
            chosen_case = st.session_state["radio_choice"]
            if chosen_case:
                chosen_annotations = current_test_cases[chosen_case]
                with anno:
                    with st.container(border=True):
                        st.write("Annotations")
                        st.info("List of Annotations for chosen Test case")
                        write_annotations(current_annotations=chosen_annotations)
                with viz:
                    with st.container(border=True):
                        st.write("Visualization")
                        select = st.selectbox(
                            "Choose type of visualization", ["2D", "3D"]
                        )
                        anno_embeddings = [
                            unpack_float32(anno_emb)
                            for _, _, anno_emb, _ in chosen_annotations
                        ]
                        anno_labels = [
                            f"{anno_id}" for anno_id, _, _, _ in chosen_annotations
                        ]
                        requirement_vectors = np.array(
                            [np.array(unpack_float32(req_embedding))]
                        )
                        annotation_vectors = np.array(anno_embeddings)

                        # Both modes share one call shape; only the reducer
                        # and the plotting function differ.
                        if select == "2D":
                            reduce_vectors = minifold_vectors_2d
                            plot_sets = plot_2_sets_in_one_2d
                        else:
                            reduce_vectors = minifold_vectors_3d
                            plot_sets = plot_2_sets_in_one_3d
                        plot_sets(
                            reduce_vectors(requirement_vectors),
                            reduce_vectors(annotation_vectors),
                            "Requirement",
                            "Annotations",
                            first_labels=[f"{req_external_id}"],
                            second_labels=anno_labels,
                        )
140+
11141

12142
def make_a_report():
13143
from test2text.services.db import get_db_client
14144

15-
with get_db_client() as db:
16-
from test2text.services.embeddings.embed import embed_requirement
17-
from test2text.services.utils import unpack_float32
18-
from test2text.services.visualisation.visualize_vectors import (
19-
minifold_vectors_2d,
20-
plot_2_sets_in_one_2d,
21-
minifold_vectors_3d,
22-
plot_2_sets_in_one_3d,
23-
)
24-
145+
with (get_db_client() as db):
25146
st.header("Test2Text Report")
26147

27-
def write_annotations(current_annotations: set[tuple]):
28-
st.write("id,", "Summary,", "Distance")
29-
for anno_id, anno_summary, _, distance in current_annotations:
30-
st.write(anno_id, anno_summary, round_distance(distance))
31-
32148
with st.container(border=True):
33149
st.subheader("Filter requirements")
34150
with st.expander("🔍 Filters"):
@@ -47,62 +163,26 @@ def write_annotations(current_annotations: set[tuple]):
47163
)
48164
st.info("Search using embeddings")
49165

50-
where_clauses = []
51-
params = []
52-
53-
if filter_id.strip():
54-
where_clauses.append("Requirements.id = ?")
55-
params.append(filter_id.strip())
56-
57-
if filter_summary.strip():
58-
where_clauses.append("Requirements.summary LIKE ?")
59-
params.append(f"%{filter_summary.strip()}%")
60-
61-
distance_sql = ""
62-
distance_order_sql = ""
63-
query_embedding_bytes = None
64-
if filter_embedding.strip():
65-
query_embedding = embed_requirement(filter_embedding.strip())
66-
query_embedding_bytes = serialize_float32(query_embedding)
67-
distance_sql = ", vec_distance_L2(embedding, ?) AS distance"
68-
distance_order_sql = "distance ASC, "
69-
70166
with st.container(border=True):
71167
st.session_state.update({"req_form_submitting": True})
72-
data = db.get_ordered_values_from_requirements(
73-
distance_sql,
74-
where_clauses,
75-
distance_order_sql,
76-
params + [query_embedding_bytes] if distance_sql else params,
77-
)
168+
data = requirements.fetch_filtered_requirements(db,
169+
external_id=filter_id,
170+
text_content=filter_summary,
171+
smart_search_query=filter_embedding)
78172

79-
if distance_sql:
80-
requirements_dict = {
81-
f"{req_external_id} {summary[:SUMMARY_LENGTH]}... [smart search d={round_distance(distance)}]": req_id
82-
for (req_id, req_external_id, summary, distance) in data
83-
}
84-
else:
85-
requirements_dict = {
86-
f"{req_external_id} {summary[:SUMMARY_LENGTH]}...": req_id
87-
for (req_id, req_external_id, summary) in data
88-
}
173+
requirements_dict = {
174+
f"{req_external_id} {summary[:SUMMARY_LENGTH]}...": req_id
175+
for (req_id, req_external_id, summary) in data
176+
}
89177

90178
st.subheader("Choose 1 of filtered requirements")
91-
option = st.selectbox(
179+
selected_requirement = st.selectbox(
92180
"Choose a requirement to work with",
93181
requirements_dict.keys(),
94182
key="filter_req_id",
95183
)
96184

97-
if option:
98-
clause = "Requirements.id = ?"
99-
if clause in where_clauses:
100-
idx = where_clauses.index(clause)
101-
params.insert(idx, requirements_dict[option])
102-
else:
103-
where_clauses.append(clause)
104-
params.append(requirements_dict[option])
105-
185+
if selected_requirement:
106186
st.subheader("Filter Test cases")
107187

108188
with st.expander("🔍 Filters"):
@@ -140,123 +220,9 @@ def write_annotations(current_annotations: set[tuple]):
140220
"There is no requested data to inspect.\n"
141221
"Please check filters, completeness of the data or upload new annotations and requirements."
142222
)
143-
return None
144-
145-
for (
146-
req_id,
147-
req_external_id,
148-
req_summary,
149-
req_embedding,
150-
), group in groupby(rows, lambda x: x[0:4]):
151-
st.divider()
152-
with st.container():
153-
st.subheader(f" Inspect Requirement {req_external_id}")
154-
st.write(req_summary)
155-
current_test_cases = dict()
156-
for (
157-
_,
158-
_,
159-
_,
160-
_,
161-
anno_id,
162-
anno_summary,
163-
anno_embedding,
164-
distance,
165-
case_id,
166-
test_script,
167-
test_case,
168-
) in group:
169-
current_annotation = current_test_cases.get(
170-
test_case, set()
171-
)
172-
current_test_cases.update({test_case: current_annotation})
173-
current_test_cases[test_case].add(
174-
(anno_id, anno_summary, anno_embedding, distance)
175-
)
176-
177-
t_cs, anno, viz = st.columns(3)
178-
with t_cs:
179-
with st.container(border=True):
180-
st.write("Test Cases")
181-
st.info("Test cases of chosen Requirement")
182-
st.radio(
183-
"Test cases name",
184-
current_test_cases.keys(),
185-
key="radio_choice",
186-
)
187-
st.markdown(
188-
"""
189-
<style>
190-
.stRadio > div {
191-
max-width: 100%;
192-
word-break: break-word;
193-
white-space: normal;
194-
}
195-
</style>
196-
""",
197-
unsafe_allow_html=True,
198-
)
223+
else:
224+
display_found_details(rows)
199225

200-
if st.session_state["radio_choice"]:
201-
with anno:
202-
with st.container(border=True):
203-
st.write("Annotations")
204-
st.info(
205-
"List of Annotations for chosen Test case"
206-
)
207-
write_annotations(
208-
current_annotations=current_test_cases[
209-
st.session_state["radio_choice"]
210-
]
211-
)
212-
with viz:
213-
with st.container(border=True):
214-
st.write("Visualization")
215-
select = st.selectbox(
216-
"Choose type of visualization", ["2D", "3D"]
217-
)
218-
anno_embeddings = [
219-
unpack_float32(anno_emb)
220-
for _, _, anno_emb, _ in current_test_cases[
221-
st.session_state["radio_choice"]
222-
]
223-
]
224-
anno_labels = [
225-
f"{anno_id}"
226-
for anno_id, _, _, _ in current_test_cases[
227-
st.session_state["radio_choice"]
228-
]
229-
]
230-
requirement_vectors = np.array(
231-
[np.array(unpack_float32(req_embedding))]
232-
)
233-
annotation_vectors = np.array(anno_embeddings)
234-
if select == "2D":
235-
plot_2_sets_in_one_2d(
236-
minifold_vectors_2d(
237-
requirement_vectors
238-
),
239-
minifold_vectors_2d(annotation_vectors),
240-
"Requirement",
241-
"Annotations",
242-
first_labels=[f"{req_external_id}"],
243-
second_labels=anno_labels,
244-
)
245-
else:
246-
reqs_vectors_3d = minifold_vectors_3d(
247-
requirement_vectors
248-
)
249-
anno_vectors_3d = minifold_vectors_3d(
250-
annotation_vectors
251-
)
252-
plot_2_sets_in_one_3d(
253-
reqs_vectors_3d,
254-
anno_vectors_3d,
255-
"Requirement",
256-
"Annotations",
257-
first_labels=[f"{req_external_id}"],
258-
second_labels=anno_labels,
259-
)
260226

261227

262228
if __name__ == "__main__":

test2text/services/db/client.py

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -233,33 +233,6 @@ def join_all_tables_by_requirements(
233233
data = self.conn.execute(sql, params)
234234
return data.fetchall()
235235

236-
def get_ordered_values_from_requirements(
    self, distance_sql="", where_clauses="", distance_order_sql="", params=None
) -> list[tuple]:
    """
    Select rows from the Requirements table, filtered and ordered as requested.

    :param distance_sql: optional SQL fragment appended to the column list
        (for example a leading-comma distance expression aliased as a column).
    :param where_clauses: iterable of SQL conditions combined with ``AND``;
        a single condition given as a plain string is accepted as well.
    :param distance_order_sql: optional SQL fragment placed in front of
        ``Requirements.id`` in the ``ORDER BY`` clause (include the trailing
        comma, e.g. ``"distance ASC, "``).
    :param params: values bound to the ``?`` placeholders of the fragments;
        ``None`` means "no parameters".
    :return: list of tuples ``(req_id, req_external_id, req_summary)`` followed
        by whatever extra column ``distance_sql`` adds.

    NOTE: the SQL fragments are interpolated into the statement verbatim, so
    they must come from trusted code; user-supplied values belong in ``params``.
    """
    # A bare string would otherwise be joined character by character.
    if isinstance(where_clauses, str):
        where_clauses = [where_clauses] if where_clauses else []
    where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
    sql = f"""
        SELECT
            Requirements.id as req_id,
            Requirements.external_id as req_external_id,
            Requirements.summary as req_summary
            {distance_sql}
        FROM
            Requirements
        {where_sql}
        ORDER BY
            {distance_order_sql}Requirements.id
    """
    # DB-API drivers such as sqlite3 reject None as the parameter container,
    # so the default has to be turned into an empty sequence.
    data = self.conn.execute(sql, () if params is None else params)
    return data.fetchall()
262-
263236
def get_ordered_values_from_test_cases(
264237
self, distance_sql="", where_clauses="", distance_order_sql="", params=None
265238
) -> list[tuple]:

test2text/services/repositories/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)