Skip to content

Commit b9338c6

Browse files
committed
fixed new methods in client.py
1 parent bc466b8 commit b9338c6

File tree

3 files changed

+8
-8
lines changed

3 files changed

+8
-8
lines changed

test2text/pages/controls/controls_page.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@ def controls_page():
1313
def refresh_counts():
1414
with get_db_client() as db:
1515
st.session_state["all_annotations_count"] = (
16-
db.count_all_entries_in_table("Annotations")
16+
db.count_all_entries("Annotations")
1717
)
1818
st.session_state["embedded_annotations_count"] = (
19-
db.count_embedded_entries_in_table("Annotations")
19+
db.count_notnull_entries("embedding",from_table="Annotations")
2020
)
2121

2222
refresh_counts()

test2text/services/db/client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def get_table_names(self):
8888
cursor.close()
8989
return tables
9090

91-
def get_column_values(self, columns: list[str], from_table: str):
91+
def get_column_values(self, *columns: str, from_table: str):
9292
cursor = self.conn.execute(f"SELECT {', '.join(columns)} FROM {from_table}")
9393
return cursor.fetchall()
9494

@@ -117,10 +117,10 @@ def count_all_entries(self, from_table: str) -> int:
117117
return count
118118

119119
def count_notnull_entries(
120-
self, columns: list[str], from_table: str
121-
) -> Union[int, None]:
120+
self, *columns: str, from_table: str
121+
) -> int:
122122
count = self.conn.execute(
123-
f"SELECT COUNT(*) FROM {from_table} WHERE {', '.join(columns)} IS NOT NULL"
123+
f"SELECT COUNT(*) FROM {from_table} WHERE {' AND '.join([column + ' IS NOT NULL' for column in columns])}"
124124
).fetchone()[0]
125125
return count
126126

test2text/services/embeddings/annotation_embeddings_controls.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ def embed_annotations(*_, embed_all=False, on_progress: OnProgress = None):
1212
with get_db_client() as db:
1313
from .embed import embed_annotations_batch
1414

15-
annotations_count = db.count_all_entries_in_table("Annotations")
16-
embedded_annotations_count = db.count_embedded_entries_in_table("Annotations")
15+
annotations_count = db.count_all_entries("Annotations")
16+
embedded_annotations_count = db.count_notnull_entries("embedding",from_table="Annotations")
1717
if embed_all:
1818
annotations_to_embed = annotations_count
1919
else:

0 commit comments

Comments
 (0)