1- from itertools import groupby
21import numpy as np
32import streamlit as st
4- from sqlite_vec import serialize_float32
53
64from test2text .services .utils .math_utils import round_distance
5+ from test2text .services .repositories import (
6+ test_cases as tc_repo ,
7+ requirements as req_repo ,
8+ annotations as an_repo ,
9+ )
710
811
912SUMMARY_LENGTH = 100
@@ -13,7 +16,6 @@ def make_a_tc_report():
1316 from test2text .services .db import get_db_client
1417
1518 with get_db_client () as db :
16- from test2text .services .embeddings .embed import embed_requirement
1719 from test2text .services .utils import unpack_float32
1820 from test2text .services .visualisation .visualize_vectors import (
1921 minifold_vectors_2d ,
@@ -24,17 +26,6 @@ def make_a_tc_report():
2426
2527 st .header ("Test2Text Report" )
2628
27- def write_requirements (current_requirements : set [tuple ]):
28- st .write ("External id," , "Summary," , "Distance" )
29- for (
30- _ ,
31- req_external_id ,
32- req_summary ,
33- _ ,
34- distance ,
35- ) in current_requirements :
36- st .write (req_external_id , req_summary , round_distance (distance ))
37-
3829 with st .container (border = True ):
3930 st .subheader ("Filter test cases" )
4031 with st .expander ("🔍 Filters" ):
@@ -50,47 +41,25 @@ def write_requirements(current_requirements: set[tuple]):
5041 )
5142 st .info ("Search using embeddings" )
5243
53- where_clauses = []
54- params = []
55-
56- if filter_summary .strip ():
57- where_clauses .append ("Testcases.test_case LIKE ?" )
58- params .append (f"%{ filter_summary .strip ()} %" )
59-
60- distance_sql = ""
61- distance_order_sql = ""
62- query_embedding_bytes = None
63- if filter_embedding .strip ():
64- query_embedding = embed_requirement (filter_embedding .strip ())
65- query_embedding_bytes = serialize_float32 (query_embedding )
66- distance_sql = ", vec_distance_L2(embedding, ?) AS distance"
67- distance_order_sql = "distance ASC, "
68-
6944 with st .container (border = True ):
7045 st .session_state .update ({"tc_form_submitting" : True })
71- data = db .get_ordered_values_from_test_cases (
72- distance_sql ,
73- where_clauses ,
74- distance_order_sql ,
75- params + [query_embedding_bytes ] if distance_sql else params ,
46+ test_cases = tc_repo .fetch_filtered_test_cases (
47+ db , text_content = filter_summary , smart_search_query = filter_embedding
7648 )
77- if distance_sql :
78- tc_dict = {
79- f"{ test_case } [smart search d={ round_distance (distance )} ]" : tc_id
80- for (tc_id , _ , test_case , distance ) in data
81- }
82- else :
83- tc_dict = {test_case : tc_id for (tc_id , _ , test_case ) in data }
49+ test_cases = {
50+ tc_id : (test_script , test_case )
51+ for tc_id , test_script , test_case in test_cases
52+ }
8453
8554 st .subheader ("Choose ONE of filtered test cases" )
86- option = st .selectbox (
87- "Choose a requirement to work with" , tc_dict .keys (), key = "filter_tc_id"
55+ selected_test_case = st .selectbox (
56+ "Choose a requirement to work with" ,
57+ test_cases .keys (),
58+ key = "filter_tc_id" ,
59+ format_func = lambda x : test_cases [x ][1 ],
8860 )
8961
90- if option :
91- where_clauses .append ("Testcases.id = ?" )
92- params .append (tc_dict [option ])
93-
62+ if selected_test_case :
9463 st .subheader ("Filter Requirements" )
9564
9665 with st .expander ("🔍 Filters" ):
@@ -114,81 +83,39 @@ def write_requirements(current_requirements: set[tuple]):
11483 )
11584 st .info ("Limit of selected requirements" )
11685
117- if filter_radius :
118- where_clauses .append ("distance <= ?" )
119- params .append (f"{ filter_radius } " )
120-
121- if filter_limit :
122- params .append (f"{ filter_limit } " )
123-
124- rows = db .join_all_tables_by_test_cases (where_clauses , params )
86+ annotations = an_repo .fetch_annotations_by_test_case (
87+ db , selected_test_case
88+ )
89+ annotations_dict = {
90+ anno_id : (anno_summary , anno_embedding )
91+ for anno_id , anno_summary , anno_embedding in annotations
92+ }
12593
126- if not rows :
94+ if not annotations_dict :
12795 st .error (
12896 "There is no requested data to inspect.\n "
12997 "Please check filters, completeness of the data or upload new annotations and requirements."
13098 )
131- return None
132-
133- for (tc_id , test_script , test_case ), group in groupby (
134- rows , lambda x : x [0 :3 ]
135- ):
99+ else :
136100 st .divider ()
137101 with st .container ():
138- st .subheader (f"Inspect #{ tc_id } Test case '{ test_case } '" )
139- st .write (f"From test script { test_script } " )
140- current_annotations = dict ()
141- for (
142- _ ,
143- _ ,
144- _ ,
145- anno_id ,
146- anno_summary ,
147- anno_embedding ,
148- distance ,
149- req_id ,
150- req_external_id ,
151- req_summary ,
152- req_embedding ,
153- ) in group :
154- current_annotation = (anno_id , anno_summary , anno_embedding )
155- current_reqs = current_annotations .get (
156- current_annotation , set ()
157- )
158- current_annotations .update (
159- {current_annotation : current_reqs }
160- )
161- current_annotations [current_annotation ].add (
162- (
163- req_id ,
164- req_external_id ,
165- req_summary ,
166- req_embedding ,
167- distance ,
168- )
169- )
102+ st .subheader (
103+ f"Inspect #{ selected_test_case } Test case '{ test_cases [selected_test_case ][1 ]} '"
104+ )
105+ st .write (
106+ f"From test script { test_cases [selected_test_case ][0 ]} "
107+ )
170108
171109 t_cs , anno , viz = st .columns (3 )
172110 with t_cs :
173111 with st .container (border = True ):
174112 st .write ("Annotations" )
175113 st .info ("Annotations linked to chosen Test case" )
176- reqs_by_anno = {
177- f"#{ anno_id } { anno_summary } " : (
178- anno_id ,
179- anno_summary ,
180- anno_embedding ,
181- )
182- for (
183- anno_id ,
184- anno_summary ,
185- anno_embedding ,
186- ) in current_annotations .keys ()
187- }
188- radio_choice = st .radio (
114+ chosen_annotation = st .radio (
189115 "Annotation's id + summary" ,
190- reqs_by_anno .keys (),
191- key = "radio_choice" ,
116+ annotations_dict .keys (),
117+ key = "chosen_annotation" ,
118+ format_func = lambda x : f"[{ x } ] { annotations_dict [x ][0 ][:SUMMARY_LENGTH ]} " ,
192119 )
193120 st .markdown (
194121 """
@@ -203,18 +130,42 @@ def write_requirements(current_requirements: set[tuple]):
203130 unsafe_allow_html = True ,
204131 )
205132
206- if radio_choice :
133+ if chosen_annotation :
134+ requirements = (
135+ req_repo .fetch_requirements_by_annotation (
136+ db ,
137+ annotation_id = chosen_annotation ,
138+ radius = filter_radius ,
139+ limit = filter_limit ,
140+ )
141+ )
142+ reqs_dict = {
143+ req_id : (
144+ req_external_id ,
145+ req_summary ,
146+ req_emb ,
147+ distance ,
148+ )
149+ for req_id , req_external_id , req_summary , req_emb , distance in requirements
150+ }
207151 with anno :
208152 with st .container (border = True ):
209153 st .write ("Requirements" )
210154 st .info (
211155 "Found Requirements for chosen annotation"
212156 )
213- write_requirements (
214- current_annotations [
215- reqs_by_anno [radio_choice ]
216- ]
217- )
157+ st .write ("External id," , "Summary," , "Distance" )
158+ for (
159+ req_external_id ,
160+ req_summary ,
161+ _ ,
162+ distance ,
163+ ) in reqs_dict .values ():
164+ st .write (
165+ req_external_id ,
166+ req_summary ,
167+ round_distance (distance ),
168+ )
218169 with viz :
219170 with st .container (border = True ):
220171 st .write ("Visualization" )
@@ -223,18 +174,27 @@ def write_requirements(current_requirements: set[tuple]):
223174 )
224175 req_embeddings = [
225176 unpack_float32 (req_emb )
226- for _ , _ , _ , req_emb , _ in current_annotations [
227- reqs_by_anno [radio_choice ]
228- ]
177+ for _ , _ , req_emb , _ in reqs_dict .values ()
229178 ]
230179 req_labels = [
231- f"{ ext_id } "
232- for _ , ext_id , req_sum , _ , _ in current_annotations [
233- reqs_by_anno [radio_choice ]
234- ]
180+ req_ext_id or req_id
181+ for req_id , (
182+ req_ext_id ,
183+ _ ,
184+ _ ,
185+ _ ,
186+ ) in reqs_dict .items ()
235187 ]
236188 annotation_vectors = np .array (
237- [np .array (unpack_float32 (anno_embedding ))]
189+ [
190+ np .array (
191+ unpack_float32 (
192+ annotations_dict [
193+ chosen_annotation
194+ ][1 ]
195+ )
196+ )
197+ ]
238198 )
239199 requirement_vectors = np .array (req_embeddings )
240200 if select == "2D" :
@@ -245,7 +205,7 @@ def write_requirements(current_requirements: set[tuple]):
245205 ),
246206 first_title = "Annotation" ,
247207 second_title = "Requirements" ,
248- first_labels = radio_choice ,
208+ first_labels = chosen_annotation ,
249209 second_labels = req_labels ,
250210 )
251211 else :
@@ -260,7 +220,7 @@ def write_requirements(current_requirements: set[tuple]):
260220 reqs_vectors_3d ,
261221 first_title = "Annotation" ,
262222 second_title = "Requirements" ,
263- first_labels = radio_choice ,
223+ first_labels = chosen_annotation ,
264224 second_labels = req_labels ,
265225 )
266226
0 commit comments