Skip to content

Commit 3f4f11e

Browse files
committed
added count property for all tables, added docstring to table's classes
1 parent b689f2f commit 3f4f11e

File tree

11 files changed

+194
-33
lines changed

11 files changed

+194
-33
lines changed

test2text/pages/controls/controls_page.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,7 @@ def controls_page():
1212

1313
def refresh_counts():
1414
with get_db_client() as db:
15-
st.session_state["all_annotations_count"] = db.count_all_entries(
16-
"Annotations"
17-
)
15+
st.session_state["all_annotations_count"] = db.annotations.count
1816
st.session_state["embedded_annotations_count"] = (
1917
db.count_notnull_entries("embedding", from_table="Annotations")
2018
)

test2text/services/db/client.py

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
7474
def __enter__(self):
7575
return self
7676

77-
def get_table_names(self):
77+
def get_table_names(self) -> list[str]:
7878
"""
7979
Returns a list of all user-defined tables in the database.
8080
@@ -87,7 +87,13 @@ def get_table_names(self):
8787
cursor.close()
8888
return tables
8989

90-
def get_column_values(self, *columns: str, from_table: str):
90+
def get_column_values(self, *columns: str, from_table: str) -> list[tuple]:
91+
"""
92+
Returns the values of the specified columns from the specified table.
93+
:param columns: list of column names
94+
:param from_table: name of the table
95+
:return: list of tuples containing the values of the specified columns
96+
"""
9197
cursor = self.conn.execute(f"SELECT {', '.join(columns)} FROM {from_table}")
9298
return cursor.fetchall()
9399

@@ -116,6 +122,11 @@ def count_all_entries(self, from_table: str) -> int:
116122
return count
117123

118124
def count_notnull_entries(self, *columns: str, from_table: str) -> int:
125+
"""
126+
Count the number of non-null entries in the specified columns of the specified table.
127+
:param columns: list of column names
128+
:param from_table: name of the table
129+
"""
119130
count = self.conn.execute(
120131
f"SELECT COUNT(*) FROM {from_table} WHERE {' AND '.join([column + ' IS NOT NULL' for column in columns])}"
121132
).fetchone()[0]
@@ -135,6 +146,9 @@ def has_column(self, column_name: str, table_name: str) -> bool:
135146
return column_name in columns
136147

137148
def get_null_entries(self, from_table: str) -> list:
149+
"""
150+
Returns values (id and summary) witch has null values in its embedding column.
151+
"""
138152
cursor = self.conn.execute(
139153
f"SELECT id, summary FROM {from_table} WHERE embedding IS NULL"
140154
)
@@ -174,8 +188,8 @@ def join_all_tables_by_requirements(
174188
self, where_clauses="", params=None
175189
) -> list[tuple]:
176190
"""
177-
Join all tables related to requirements based on the provided where clauses and parameters.
178-
return a list of tuples containing :
191+
Extract values from requirements with related annotations and their test cases based on the provided where clauses and parameters.
192+
Return a list of tuples containing :
179193
req_id,
180194
req_external_id,
181195
req_summary,
@@ -222,6 +236,14 @@ def join_all_tables_by_requirements(
222236
def get_ordered_values_from_requirements(
223237
self, distance_sql="", where_clauses="", distance_order_sql="", params=None
224238
) -> list[tuple]:
239+
"""
240+
Extracted values from Requirements table based on the provided where clauses and specified parameters ordered by distance and id.
241+
Return a list of tuples containing :
242+
req_id,
243+
req_external_id,
244+
req_summary,
245+
distance between annotation and requirement embeddings,
246+
"""
225247
where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
226248
sql = f"""
227249
SELECT
@@ -241,6 +263,14 @@ def get_ordered_values_from_requirements(
241263
def get_ordered_values_from_test_cases(
242264
self, distance_sql="", where_clauses="", distance_order_sql="", params=None
243265
) -> list[tuple]:
266+
"""
267+
Extracted values from TestCases table based on the provided where clauses and specified parameters ordered by distance and id.
268+
Return a list of tuples containing :
269+
case_id,
270+
test_script,
271+
test_case,
272+
distance between test case and typed by user text embeddings if it is specified,
273+
"""
244274
where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
245275
sql = f"""
246276
SELECT
@@ -260,6 +290,21 @@ def get_ordered_values_from_test_cases(
260290
def join_all_tables_by_test_cases(
261291
self, where_clauses="", params=None
262292
) -> list[tuple]:
293+
"""
294+
Join all tables related to test cases based on the provided where clauses and specified parameters.
295+
Return a list of tuples containing :
296+
case_id,
297+
test_script,
298+
test_case,
299+
anno_id,
300+
anno_summary,
301+
anno_embedding,
302+
distance between annotation and requirement embeddings,
303+
req_id,
304+
req_external_id,
305+
req_summary,
306+
req_embedding
307+
"""
263308
where_sql = ""
264309
if where_clauses:
265310
where_sql = f"WHERE {' AND '.join(where_clauses)}"
@@ -294,7 +339,10 @@ def join_all_tables_by_test_cases(
294339
data = self.conn.execute(sql, params)
295340
return data.fetchall()
296341

297-
def get_embeddings_by_id(self, id1: int, from_table: str):
342+
def get_embeddings_by_id(self, id1: int, from_table: str) -> float:
343+
"""
344+
Returns the embedding of the specified id from the specified table.
345+
"""
298346
cursor = self.conn.execute(
299347
f"SELECT embedding FROM {from_table} WHERE id = ?", (id1,)
300348
)

test2text/services/db/streamlit_conn.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33

44
def get_db_client() -> DbClient:
5+
"""
6+
Returns a DbClient instance connected to the database where requirements, annotations, test cases and their relations are stored.
7+
:return: DbClient instance
8+
"""
59
from test2text.services.utils import res_folder
610

711
return DbClient(res_folder.get_file_path("db.sqlite3"))

test2text/services/db/tables/annos_to_reqs.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,14 @@
33

44

55
class AnnotationsToRequirementsTable(AbstractTable):
6-
def init_table(self):
6+
"""
7+
This class represents the relationship between annotations and requirements in the database by closest distance between them.
8+
"""
9+
10+
def init_table(self) -> None:
11+
"""
12+
Creates the AnnotationsToRequirements table in the database if it does not already exist.
13+
"""
714
self.connection.execute("""
815
CREATE TABLE IF NOT EXISTS AnnotationsToRequirements (
916
annotation_id INTEGER NOT NULL,
@@ -15,7 +22,10 @@ def init_table(self):
1522
)
1623
""")
1724

18-
def recreate_table(self):
25+
def recreate_table(self) -> None:
26+
"""
27+
Drops the AnnotationsToRequirements table if it exists and recreates it.
28+
"""
1929
self.connection.execute("""
2030
DROP TABLE IF EXISTS AnnotationsToRequirements
2131
""")
@@ -24,6 +34,13 @@ def recreate_table(self):
2434
def insert(
2535
self, annotation_id: int, requirement_id: int, cached_distance: float
2636
) -> bool:
37+
"""
38+
Inserts a new entry into the AnnotationsToRequirements table.
39+
:param annotation_id: The ID of the annotation
40+
:param requirement_id: The ID of the requirement
41+
:param cached_distance: The cached distance between the annotation and the requirement
42+
:return: True if the insertion was successful, False otherwise.
43+
"""
2744
try:
2845
cursor = self.connection.execute(
2946
"""
@@ -42,7 +59,12 @@ def insert(
4259
pass
4360
return False
4461

62+
@property
4563
def count(self) -> int:
64+
"""
65+
Returns the number of entries in the AnnotationsToRequirements table.
66+
:return: int - the number of entries in the table.
67+
"""
4668
cursor = self.connection.execute(
4769
"SELECT COUNT(*) FROM AnnotationsToRequirements"
4870
)

test2text/services/db/tables/annotations.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,18 @@
77

88

99
class AnnotationsTable(AbstractTable):
10+
"""
11+
This class represents the annotations of test cases in the database.
12+
"""
13+
1014
def __init__(self, connection: Connection, embedding_size: int):
1115
super().__init__(connection)
1216
self.embedding_size = embedding_size
1317

1418
def init_table(self):
19+
"""
20+
Creates the Annotations table in the database if it does not already exist.
21+
"""
1522
self.connection.execute(
1623
Template("""
1724
CREATE TABLE IF NOT EXISTS Annotations (
@@ -29,6 +36,12 @@ def init_table(self):
2936
)
3037

3138
def insert(self, summary: str, embedding: list[float] = None) -> Optional[int]:
39+
"""
40+
Inserts a new annotation into the database. If the annotation already exists, it updates the existing record.
41+
:param summary: The summary of the annotation
42+
:param embedding: The embedding of the annotation (optional)
43+
:return: The ID of the inserted or updated annotation, or None if the annotation already exists and was updated.
44+
"""
3245
cursor = self.connection.execute(
3346
"""
3447
INSERT OR IGNORE INTO Annotations (summary, embedding)
@@ -45,6 +58,12 @@ def insert(self, summary: str, embedding: list[float] = None) -> Optional[int]:
4558
return None
4659

4760
def get_or_insert(self, summary: str, embedding: list[float] = None) -> int:
61+
"""
62+
Inserts a new annotation into the database if it does not already exist, otherwise returns the existing annotation's ID.
63+
:param summary: The summary of the annotation
64+
:param embedding: The embedding of the annotation (optional)
65+
:return: The ID of the inserted or existing annotation.
66+
"""
4867
inserted_id = self.insert(summary, embedding)
4968
if inserted_id is not None:
5069
return inserted_id
@@ -61,6 +80,11 @@ def get_or_insert(self, summary: str, embedding: list[float] = None) -> int:
6180
return result[0]
6281

6382
def set_embedding(self, anno_id: int, embedding: list[float]) -> None:
83+
"""
84+
Sets the embedding for a given annotation ID.
85+
:param anno_id: The ID of the annotation
86+
:param embedding: The new embedding for the annotation
87+
"""
6488
if len(embedding) != self.embedding_size:
6589
raise ValueError(
6690
f"Embedding size must be {self.embedding_size}, got {len(embedding)}"
@@ -74,3 +98,12 @@ def set_embedding(self, anno_id: int, embedding: list[float]) -> None:
7498
""",
7599
(serialized_embedding, anno_id),
76100
)
101+
102+
@property
103+
def count(self) -> int:
104+
"""
105+
Returns the number of entries in the Annotations table.
106+
:return: int - the number of entries in the table.
107+
"""
108+
cursor = self.connection.execute("SELECT COUNT(*) FROM Annotations")
109+
return cursor.fetchone()[0]

test2text/services/db/tables/cases_to_annos.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44

55

66
class TestCasesToAnnotationsTable(AbstractTable):
7-
def init_table(self):
7+
"""
8+
This class represents the relationship between test cases and annotations in the database.
9+
"""
10+
11+
def init_table(self) -> None:
812
self.connection.execute("""
913
CREATE TABLE IF NOT EXISTS CasesToAnnos (
1014
case_id INTEGER NOT NULL,
@@ -40,6 +44,11 @@ def insert(self, case_id: int, annotation_id: int) -> bool:
4044
pass
4145
return False
4246

47+
@property
4348
def count(self) -> int:
49+
"""
50+
Returns the number of entries in the CasesToAnnos table.
51+
:return: int - the number of entries in the table.
52+
"""
4453
cursor = self.connection.execute("SELECT COUNT(*) FROM CasesToAnnos")
4554
return cursor.fetchone()[0]

test2text/services/db/tables/requirements.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,18 @@
77

88

99
class RequirementsTable(AbstractTable):
10+
"""
11+
This class represents the requirements for test cases in the database.
12+
"""
13+
1014
def __init__(self, connection: Connection, embedding_size: int):
1115
super().__init__(connection)
1216
self.embedding_size = embedding_size
1317

14-
def init_table(self):
18+
def init_table(self) -> None:
19+
"""
20+
Creates the Requirements table in the database if it does not already exist.
21+
"""
1522
self.connection.execute(
1623
Template("""
1724
CREATE TABLE IF NOT EXISTS Requirements (
@@ -32,6 +39,13 @@ def init_table(self):
3239
def insert(
3340
self, summary: str, embedding: list[float] = None, external_id: str = None
3441
) -> Optional[int]:
42+
"""
43+
Inserts a new requirement into the database. If the requirement already exists, it updates the existing record.
44+
:param summary: The summary of the requirement
45+
:param embedding: The embedding of the requirement (optional)
46+
:param external_id: The external ID of the requirement (optional)
47+
:return: The ID of the inserted or updated requirement, or None if the requirement already exists and was updated.
48+
"""
3549
cursor = self.connection.execute(
3650
"""
3751
INSERT OR IGNORE INTO Requirements (summary, embedding, external_id)
@@ -49,3 +63,11 @@ def insert(
4963
return result[0]
5064
else:
5165
return None
66+
67+
def count(self) -> int:
68+
"""
69+
Returns the number of entries in the Requirements table.
70+
:return: int - the number of entries in the table.
71+
"""
72+
cursor = self.connection.execute("SELECT COUNT(*) FROM Requirements")
73+
return cursor.fetchone()[0]

test2text/services/db/tables/test_case.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ def __init__(self, connection: Connection, embedding_size: int):
1212
self.embedding_size = embedding_size
1313

1414
def init_table(self):
15+
"""
16+
Creates the TestCases table in the database if it does not already exist.
17+
"""
1518
self.connection.execute(
1619
Template("""
1720
@@ -34,6 +37,13 @@ def init_table(self):
3437
def insert(
3538
self, test_script: str, test_case: str, embedding: list[float] = None
3639
) -> Optional[int]:
40+
"""
41+
Inserts a new test case into the database. If the test case already exists, it updates the existing record.
42+
:param test_script: The test script of the test case
43+
:param test_case: The test case of the test case
44+
:param embedding: The embedding of the test case (optional)
45+
:return: The ID of the inserted or updated test case, or None if the test case already exists and was updated.
46+
"""
3747
cursor = self.connection.execute(
3848
"""
3949
INSERT OR IGNORE INTO TestCases (test_script, test_case, embedding)
@@ -54,6 +64,12 @@ def insert(
5464
return None
5565

5666
def get_or_insert(self, test_script: str, test_case: str) -> int:
67+
"""
68+
Inserts a new test case into the database if it does not already exist, otherwise returns the existing test case's ID.
69+
:param test_script: The test script of the test case
70+
:param test_case: The test case of the test case
71+
:return: The ID of the inserted or existing test case.
72+
"""
5773
inserted_id = self.insert(test_script, test_case)
5874
if inserted_id is not None:
5975
return inserted_id
@@ -68,3 +84,12 @@ def get_or_insert(self, test_script: str, test_case: str) -> int:
6884
result = cursor.fetchone()
6985
cursor.close()
7086
return result[0]
87+
88+
@property
89+
def count(self) -> int:
90+
"""
91+
Returns the number of entries in the TestCases table.
92+
:return: int - the number of entries in the table.
93+
"""
94+
cursor = self.connection.execute("SELECT COUNT(*) FROM TestCases")
95+
return cursor.fetchone()[0]

test2text/services/loaders/index_requirements.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,4 +61,4 @@ def write_batch():
6161
write_batch()
6262
write_batch()
6363
# Check requirements
64-
return db.count_all_entries(from_table="Requirements")
64+
return db.requirements.count

0 commit comments

Comments
 (0)