Skip to content

Commit 8f4cb3c

Browse files
committed
fixed formatting
1 parent a1ebd2e commit 8f4cb3c

File tree

9 files changed

+72
-32
lines changed

9 files changed

+72
-32
lines changed

convert_trace_annos.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@ def is_empty(value):
1515

1616
def trace_test_cases_to_annos(trace_file_path: Path):
1717
with get_db_client() as db:
18-
1918
insertions = list()
2019
logger.info("Reading trace file and inserting annotations into table...")
21-
with open(trace_file_path, mode="r", newline="", encoding="utf-8") as trace_file:
20+
with open(
21+
trace_file_path, mode="r", newline="", encoding="utf-8"
22+
) as trace_file:
2223
reader = csv.reader(trace_file)
2324
current_tc = EMPTY
2425
concat_summary = EMPTY
@@ -37,7 +38,9 @@ def trace_test_cases_to_annos(trace_file_path: Path):
3738
case_id = db.test_cases.get_or_insert(
3839
test_script=test_script, test_case=current_tc
3940
)
40-
annotation_id = db.annotations.get_or_insert(summary=concat_summary)
41+
annotation_id = db.annotations.get_or_insert(
42+
summary=concat_summary
43+
)
4144
insertions.append(
4245
db.cases_to_annos.insert(
4346
case_id=case_id, annotation_id=annotation_id

test2text/pages/reports/report_by_req.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
def make_a_report():
1717
with get_db_client() as db:
1818
from test2text.services.embeddings.embed import embed_requirement
19+
1920
st.header("Test2Text Report")
2021

2122
def write_annotations(current_annotations: set[tuple]):
@@ -127,7 +128,10 @@ def write_annotations(current_annotations: set[tuple]):
127128
radius, limit = st.columns(2)
128129
with radius:
129130
filter_radius = st.number_input(
130-
"Insert a radius", value=1.00, step=0.01, key="filter_radius"
131+
"Insert a radius",
132+
value=1.00,
133+
step=0.01,
134+
key="filter_radius",
131135
)
132136
st.info("Max distance to annotation")
133137
with limit:
@@ -188,9 +192,12 @@ def write_annotations(current_annotations: set[tuple]):
188192
)
189193
return None
190194

191-
for (req_id, req_external_id, req_summary, req_embedding), group in groupby(
192-
rows, lambda x: x[0:4]
193-
):
195+
for (
196+
req_id,
197+
req_external_id,
198+
req_summary,
199+
req_embedding,
200+
), group in groupby(rows, lambda x: x[0:4]):
194201
st.divider()
195202
with st.container():
196203
st.subheader(f" Inspect Requirement {req_external_id}")
@@ -209,7 +216,9 @@ def write_annotations(current_annotations: set[tuple]):
209216
test_script,
210217
test_case,
211218
) in group:
212-
current_annotation = current_test_cases.get(test_case, set())
219+
current_annotation = current_test_cases.get(
220+
test_case, set()
221+
)
213222
current_test_cases.update({test_case: current_annotation})
214223
current_test_cases[test_case].add(
215224
(anno_id, anno_summary, anno_embedding, distance)
@@ -242,7 +251,9 @@ def write_annotations(current_annotations: set[tuple]):
242251
with anno:
243252
with st.container(border=True):
244253
st.write("Annotations")
245-
st.info("List of Annotations for chosen Test case")
254+
st.info(
255+
"List of Annotations for chosen Test case"
256+
)
246257
write_annotations(
247258
current_annotations=current_test_cases[
248259
st.session_state["radio_choice"]
@@ -266,7 +277,9 @@ def write_annotations(current_annotations: set[tuple]):
266277
annotation_vectors = np.array(anno_embeddings)
267278
if select == "2D":
268279
plot_2_sets_in_one_2d(
269-
minifold_vectors_2d(requirement_vectors),
280+
minifold_vectors_2d(
281+
requirement_vectors
282+
),
270283
minifold_vectors_2d(annotation_vectors),
271284
"Requirement",
272285
"Annotations",

test2text/pages/reports/report_by_tc.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
def make_a_tc_report():
1717
with get_db_client() as db:
1818
from test2text.services.embeddings.embed import embed_requirement
19+
1920
st.header("Test2Text Report")
2021

2122
def write_requirements(current_requirements: set[tuple]):
@@ -27,7 +28,13 @@ def write_requirements(current_requirements: set[tuple]):
2728
with dist:
2829
st.write("Distance")
2930

30-
for req_id, req_external_id, req_summary, _, distance in current_requirements:
31+
for (
32+
req_id,
33+
req_external_id,
34+
req_summary,
35+
_,
36+
distance,
37+
) in current_requirements:
3138
req, summary, dist = st.columns(3)
3239
with req:
3340
st.write(f"#{req_id} Requirement {req_external_id}")
@@ -126,7 +133,10 @@ def write_requirements(current_requirements: set[tuple]):
126133
radius, limit = st.columns(2)
127134
with radius:
128135
filter_radius = st.number_input(
129-
"Insert a radius", value=1.00, step=0.01, key="filter_radius"
136+
"Insert a radius",
137+
value=1.00,
138+
step=0.01,
139+
key="filter_radius",
130140
)
131141
st.info("Max distance to annotation")
132142
with limit:
@@ -212,7 +222,9 @@ def write_requirements(current_requirements: set[tuple]):
212222
current_reqs = current_annotations.get(
213223
current_annotation, set()
214224
)
215-
current_annotations.update({current_annotation: current_reqs})
225+
current_annotations.update(
226+
{current_annotation: current_reqs}
227+
)
216228
current_annotations[current_annotation].add(
217229
(
218230
req_id,
@@ -262,9 +274,13 @@ def write_requirements(current_requirements: set[tuple]):
262274
with anno:
263275
with st.container(border=True):
264276
st.write("Requirements")
265-
st.info("Found Requirements for chosen annotation")
277+
st.info(
278+
"Found Requirements for chosen annotation"
279+
)
266280
write_requirements(
267-
current_annotations[reqs_by_anno[radio_choice]]
281+
current_annotations[
282+
reqs_by_anno[radio_choice]
283+
]
268284
)
269285
with viz:
270286
with st.container(border=True):
@@ -285,7 +301,9 @@ def write_requirements(current_requirements: set[tuple]):
285301
if select == "2D":
286302
plot_2_sets_in_one_2d(
287303
minifold_vectors_2d(annotation_vectors),
288-
minifold_vectors_2d(requirement_vectors),
304+
minifold_vectors_2d(
305+
requirement_vectors
306+
),
289307
"Annotation",
290308
"Requirements",
291309
first_color="red",

test2text/services/db/client.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ def get_table_names(self):
8888
cursor.close()
8989
return tables
9090

91-
9291
@property
9392
def get_db_full_info(self):
9493
"""
@@ -102,16 +101,20 @@ def get_db_full_info(self):
102101
table_names = self.get_table_names()
103102
for table_name in table_names:
104103
row_count = self.count_all_entries_in_table(table_name)
105-
db_tables_info.update({
106-
table_name: row_count,
107-
})
104+
db_tables_info.update(
105+
{
106+
table_name: row_count,
107+
}
108+
)
108109
return db_tables_info
109110

110111
def count_all_entries_in_table(self, table: str) -> int:
111112
count = self.conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0]
112113
return count
113114

114-
def count_notnull_entries_in_table(self,column: str, table: str) -> Union[int, None]:
115+
def count_notnull_entries_in_table(
116+
self, column: str, table: str
117+
) -> Union[int, None]:
115118
if self.has_column(column, table):
116119
count = self.conn.execute(
117120
f"SELECT COUNT(*) FROM {table} WHERE {column} IS NOT NULL"
@@ -130,4 +133,4 @@ def has_column(self, column_name: str, table_name: str) -> bool:
130133
cursor = self.conn.execute(f'PRAGMA table_info("{table_name}")')
131134
columns = [row[1] for row in cursor.fetchall()] # row[1] is the column name
132135
cursor.close()
133-
return column_name in columns
136+
return column_name in columns

test2text/services/embeddings/annotation_embeddings_controls.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,13 @@
55
BATCH_SIZE = 30
66

77

8-
9-
10-
118
OnProgress = Callable[[float], None]
129

1310

1411
def embed_annotations(*_, embed_all=False, on_progress: OnProgress = None):
1512
with get_db_client() as db:
1613
from .embed import embed_annotations_batch
14+
1715
annotations_count = db.count_all_entries_in_table("Annotations")
1816
embedded_annotations_count = db.count_embedded_entries_in_table("Annotations")
1917
if embed_all:
@@ -24,7 +22,9 @@ def embed_annotations(*_, embed_all=False, on_progress: OnProgress = None):
2422
batch = []
2523

2624
def write_batch(batch: list[tuple[int, str]]):
27-
embeddings = embed_annotations_batch([annotation for _, annotation in batch])
25+
embeddings = embed_annotations_batch(
26+
[annotation for _, annotation in batch]
27+
)
2828
for i, (anno_id, annotation) in enumerate(batch):
2929
embedding = embeddings[i]
3030
db.annotations.set_embedding(anno_id, embedding)

test2text/services/embeddings/cache_distances.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@ def refresh_and_get_distances() -> list[float]:
2525
current_req_annos = 0
2626
if current_req_annos < 5 or distance < 0.7:
2727
db.annos_to_reqs.insert(
28-
annotation_id=anno_id, requirement_id=req_id, cached_distance=distance
28+
annotation_id=anno_id,
29+
requirement_id=req_id,
30+
cached_distance=distance,
2931
)
3032
current_req_annos += 1
3133
return distances

test2text/services/loaders/convert_trace_annos.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ def write_table_row(*args, **kwargs):
2727

2828
def trace_test_cases_to_annos(trace_files: list):
2929
with get_db_client() as db:
30-
3130
st.info(
3231
"Reading trace files and inserting test case + annotations pairs into database..."
3332
)
@@ -57,7 +56,9 @@ def trace_test_cases_to_annos(trace_files: list):
5756
case_id = db.test_cases.get_or_insert(
5857
test_script=test_script, test_case=current_tc
5958
)
60-
annotation_id = db.annotations.get_or_insert(summary=concat_summary)
59+
annotation_id = db.annotations.get_or_insert(
60+
summary=concat_summary
61+
)
6162
insertions.append(
6263
db.cases_to_annos.insert(
6364
case_id=case_id, annotation_id=annotation_id
@@ -78,4 +79,3 @@ def trace_test_cases_to_annos(trace_files: list):
7879
sum(insertions),
7980
len(insertions) - sum(insertions),
8081
)
81-

test2text/services/loaders/index_annotations.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
def index_annotations_from_files(files: list, *_, on_file_start: OnFileStart = None):
1515
with get_db_client() as db:
16-
1716
for i, file in enumerate(files):
1817
file_counter = None
1918
if on_file_start:

test2text/services/visualisation/visualize_vectors.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,9 @@ def visualize_vectors():
213213
)
214214
progress_bar.progress(80, "Plotted in 3D")
215215

216-
anno_vectors_2d = minifold_vectors_2d(extract_closest_annotation_vectors(db))
216+
anno_vectors_2d = minifold_vectors_2d(
217+
extract_closest_annotation_vectors(db)
218+
)
217219

218220
plot_2_sets_in_one_2d(
219221
reqs_vectors_2d, anno_vectors_2d, "Requerements", "Annotations"

0 commit comments

Comments
 (0)