1- from itertools import groupby
21import numpy as np
32import streamlit as st
4- from sqlite_vec import serialize_float32
53
64from test2text .services .utils .math_utils import round_distance
5+ from test2text .services .repositories import (
6+ requirements as requirements_repo ,
7+ test_cases as test_cases_repo ,
8+ annotations as annotations_repo ,
9+ )
710
811SUMMARY_LENGTH = 100
912LABELS_SUMMARY_LENGTH = 15
@@ -13,22 +16,8 @@ def make_a_report():
1316 from test2text .services .db import get_db_client
1417
1518 with get_db_client () as db :
16- from test2text .services .embeddings .embed import embed_requirement
17- from test2text .services .utils import unpack_float32
18- from test2text .services .visualisation .visualize_vectors import (
19- minifold_vectors_2d ,
20- plot_2_sets_in_one_2d ,
21- minifold_vectors_3d ,
22- plot_2_sets_in_one_3d ,
23- )
24-
2519 st .header ("Test2Text Report" )
2620
27- def write_annotations (current_annotations : set [tuple ]):
28- st .write ("id," , "Summary," , "Distance" )
29- for anno_id , anno_summary , _ , distance in current_annotations :
30- st .write (anno_id , anno_summary , round_distance (distance ))
31-
3221 with st .container (border = True ):
3322 st .subheader ("Filter requirements" )
3423 with st .expander ("🔍 Filters" ):
@@ -47,62 +36,30 @@ def write_annotations(current_annotations: set[tuple]):
4736 )
4837 st .info ("Search using embeddings" )
4938
50- where_clauses = []
51- params = []
52-
53- if filter_id .strip ():
54- where_clauses .append ("Requirements.id = ?" )
55- params .append (filter_id .strip ())
56-
57- if filter_summary .strip ():
58- where_clauses .append ("Requirements.summary LIKE ?" )
59- params .append (f"%{ filter_summary .strip ()} %" )
60-
61- distance_sql = ""
62- distance_order_sql = ""
63- query_embedding_bytes = None
64- if filter_embedding .strip ():
65- query_embedding = embed_requirement (filter_embedding .strip ())
66- query_embedding_bytes = serialize_float32 (query_embedding )
67- distance_sql = ", vec_distance_L2(embedding, ?) AS distance"
68- distance_order_sql = "distance ASC, "
69-
7039 with st .container (border = True ):
7140 st .session_state .update ({"req_form_submitting" : True })
72- data = db . get_ordered_values_from_requirements (
73- distance_sql ,
74- where_clauses ,
75- distance_order_sql ,
76- params + [ query_embedding_bytes ] if distance_sql else params ,
41+ requirements = requirements_repo . fetch_filtered_requirements (
42+ db ,
43+ external_id = filter_id ,
44+ text_content = filter_summary ,
45+ smart_search_query = filter_embedding ,
7746 )
7847
79- if distance_sql :
80- requirements_dict = {
81- f"{ req_external_id } { summary [:SUMMARY_LENGTH ]} ... [smart search d={ round_distance (distance )} ]" : req_id
82- for (req_id , req_external_id , summary , distance ) in data
83- }
84- else :
85- requirements_dict = {
86- f"{ req_external_id } { summary [:SUMMARY_LENGTH ]} ..." : req_id
87- for (req_id , req_external_id , summary ) in data
88- }
48+ requirements = {
49+ req_id : (req_external_id , summary )
50+ for (req_id , req_external_id , summary ) in requirements
51+ }
8952
9053 st .subheader ("Choose 1 of filtered requirements" )
91- option = st .selectbox (
54+ selected_requirement = st .selectbox (
9255 "Choose a requirement to work with" ,
93- requirements_dict .keys (),
56+ requirements .keys (),
9457 key = "filter_req_id" ,
58+ format_func = lambda x : f"{ requirements [x ][0 ]} { requirements [x ][1 ][:SUMMARY_LENGTH ]} ..." ,
9559 )
9660
97- if option :
98- clause = "Requirements.id = ?"
99- if clause in where_clauses :
100- idx = where_clauses .index (clause )
101- params .insert (idx , requirements_dict [option ])
102- else :
103- where_clauses .append (clause )
104- params .append (requirements_dict [option ])
105-
61+ if selected_requirement :
62+ requirement = db .requirements .get_by_id_raw (selected_requirement )
10663 st .subheader ("Filter Test cases" )
10764
10865 with st .expander ("🔍 Filters" ):
@@ -126,53 +83,34 @@ def write_annotations(current_annotations: set[tuple]):
12683 )
12784 st .info ("Limit of selected test cases" )
12885
129- if filter_radius :
130- where_clauses .append ("distance <= ?" )
131- params .append (f"{ filter_radius } " )
132-
133- if filter_limit :
134- params .append (f"{ filter_limit } " )
135-
136- rows = db .join_all_tables_by_requirements (where_clauses , params )
86+ test_cases = test_cases_repo .fetch_test_cases_by_requirement (
87+ db , selected_requirement , filter_radius , filter_limit
88+ )
89+ test_cases = {
90+ tc_id : (test_script , test_case )
91+ for (tc_id , test_script , test_case ) in test_cases
92+ }
13793
138- if not rows :
94+ if not test_cases :
13995 st .error (
14096 "There is no requested data to inspect.\n "
14197 "Please check filters, completeness of the data or upload new annotations and requirements."
14298 )
143- return None
99+ else :
100+ from test2text .services .utils import unpack_float32
101+ from test2text .services .visualisation .visualize_vectors import (
102+ minifold_vectors_2d ,
103+ plot_2_sets_in_one_2d ,
104+ minifold_vectors_3d ,
105+ plot_2_sets_in_one_3d ,
106+ )
144107
145- for (
146- req_id ,
147- req_external_id ,
148- req_summary ,
149- req_embedding ,
150- ), group in groupby (rows , lambda x : x [0 :4 ]):
151108 st .divider ()
152109 with st .container ():
153- st .subheader (f" Inspect Requirement { req_external_id } " )
154- st .write (req_summary )
155- current_test_cases = dict ()
156- for (
157- _ ,
158- _ ,
159- _ ,
160- _ ,
161- anno_id ,
162- anno_summary ,
163- anno_embedding ,
164- distance ,
165- case_id ,
166- test_script ,
167- test_case ,
168- ) in group :
169- current_annotation = current_test_cases .get (
170- test_case , set ()
171- )
172- current_test_cases .update ({test_case : current_annotation })
173- current_test_cases [test_case ].add (
174- (anno_id , anno_summary , anno_embedding , distance )
175- )
110+ st .subheader (
111+ f" Inspect Requirement { requirements [selected_requirement ][0 ]} "
112+ )
113+ st .write (requirements [selected_requirement ][1 ])
176114
177115 t_cs , anno , viz = st .columns (3 )
178116 with t_cs :
@@ -181,8 +119,9 @@ def write_annotations(current_annotations: set[tuple]):
181119 st .info ("Test cases of chosen Requirement" )
182120 st .radio (
183121 "Test cases name" ,
184- current_test_cases .keys (),
185- key = "radio_choice" ,
122+ test_cases .keys (),
123+ key = "chosen_test_case" ,
124+ format_func = lambda tc_id : test_cases [tc_id ][1 ],
186125 )
187126 st .markdown (
188127 """
@@ -197,18 +136,30 @@ def write_annotations(current_annotations: set[tuple]):
197136 unsafe_allow_html = True ,
198137 )
199138
200- if st .session_state ["radio_choice" ]:
139+ if st .session_state ["chosen_test_case" ]:
140+ annotations = annotations_repo .fetch_annotations_by_test_case_with_distance_to_requirement (
141+ db ,
142+ st .session_state ["chosen_test_case" ],
143+ requirement [3 ], # embedding
144+ )
201145 with anno :
202146 with st .container (border = True ):
203147 st .write ("Annotations" )
204148 st .info (
205149 "List of Annotations for chosen Test case"
206150 )
207- write_annotations (
208- current_annotations = current_test_cases [
209- st .session_state ["radio_choice" ]
210- ]
211- )
151+ st .write ("id," , "Summary," , "Distance" )
152+ for (
153+ anno_id ,
154+ anno_summary ,
155+ _ ,
156+ distance ,
157+ ) in annotations :
158+ st .write (
159+ anno_id ,
160+ anno_summary ,
161+ round_distance (distance ),
162+ )
212163 with viz :
213164 with st .container (border = True ):
214165 st .write ("Visualization" )
@@ -217,18 +168,14 @@ def write_annotations(current_annotations: set[tuple]):
217168 )
218169 anno_embeddings = [
219170 unpack_float32 (anno_emb )
220- for _ , _ , anno_emb , _ in current_test_cases [
221- st .session_state ["radio_choice" ]
222- ]
171+ for _ , _ , anno_emb , _ in annotations
223172 ]
224173 anno_labels = [
225174 f"{ anno_id } "
226- for anno_id , _ , _ , _ in current_test_cases [
227- st .session_state ["radio_choice" ]
228- ]
175+ for anno_id , _ , _ , _ in annotations
229176 ]
230177 requirement_vectors = np .array (
231- [np .array (unpack_float32 (req_embedding ))]
178+ [np .array (unpack_float32 (requirement [ 3 ] ))]
232179 )
233180 annotation_vectors = np .array (anno_embeddings )
234181 if select == "2D" :
@@ -239,7 +186,7 @@ def write_annotations(current_annotations: set[tuple]):
239186 minifold_vectors_2d (annotation_vectors ),
240187 "Requirement" ,
241188 "Annotations" ,
242- first_labels = [f"{ req_external_id } " ],
189+ first_labels = [f"{ requirement [ 1 ] } " ],
243190 second_labels = anno_labels ,
244191 )
245192 else :
@@ -254,7 +201,7 @@ def write_annotations(current_annotations: set[tuple]):
254201 anno_vectors_3d ,
255202 "Requirement" ,
256203 "Annotations" ,
257- first_labels = [f"{ req_external_id } " ],
204+ first_labels = [f"{ requirement [ 1 ] } " ],
258205 second_labels = anno_labels ,
259206 )
260207
0 commit comments