Skip to content

Commit 4d7ef94

Browse files
author
anna.yamkovaya
committed
fixed batch
1 parent 5b8c7d4 commit 4d7ef94

File tree

1 file changed: +31 -11 lines changed

test2text/services/loaders/index_annotations.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,40 @@
1111

1212
def index_annotations_from_files(files: list):
1313
db = DbClient("./private/requirements.db")
14+
st.subheader("Processing files:")
15+
file_order, file_name, proc_file = st.columns(3)
16+
with file_order:
17+
st.write("Number")
18+
with file_name:
19+
st.write("File name")
20+
with proc_file:
21+
st.write("Processing file")
1422

1523
for i, file in enumerate(files):
16-
st.info(f"Processing file {i + 1}: {file.name}")
24+
file_order, file_name, proc_file = st.columns(3)
25+
with file_order:
26+
st.write(f"{i + 1}/{len(files)}")
27+
with file_name:
28+
st.write(f"{file.name}")
1729
stringio = io.StringIO(file.getvalue().decode("utf-8"))
1830
reader = csv.reader(stringio)
31+
insertions = []
1932

2033
if not list(reader):
21-
st.warning(f"The uploaded CSV file {file.name} is empty.")
34+
with proc_file:
35+
st.warning(f"The uploaded CSV file {file.name} is empty.")
2236
continue
23-
for row in reader:
37+
with proc_file:
38+
file_progress_bar = st.progress(0)
39+
count_rows = len(list(reader))
40+
for i, row in enumerate(reader):
41+
file_progress_bar.progress(round(i*100/count_rows))
2442
[summary, _, test_script, test_case, *_] = row
2543
anno_id = db.annotations.get_or_insert(summary=summary)
2644
tc_id = db.test_cases.get_or_insert(
2745
test_script=test_script, test_case=test_case
2846
)
29-
db.cases_to_annos.insert(case_id=tc_id, annotation_id=anno_id)
47+
insertions.append(db.cases_to_annos.insert(case_id=tc_id, annotation_id=anno_id))
3048

3149
db.conn.commit()
3250
# Embed annotations
@@ -39,8 +57,7 @@ def index_annotations_from_files(files: list):
3957

4058
batch = []
4159

42-
def write_batch():
43-
global batch
60+
def write_batch(batch):
4461
embeddings = embed_annotations_batch([annotation for _, annotation in batch])
4562
for i, (anno_id, annotation) in enumerate(batch):
4663
embedding = embeddings[i]
@@ -53,15 +70,18 @@ def write_batch():
5370
(embedding, anno_id),
5471
)
5572
db.conn.commit()
56-
batch = []
5773

74+
st.subheader("Processing annotations:")
75+
progress_bar = st.progress(0, text="Processing annotation:")
5876
for i, (anno_id, summary) in enumerate(annotations.fetchall()):
59-
if i % 100 == 0:
60-
st.info(f"Processing annotation {i + 1}/{annotations_count}")
77+
progress_bar.progress(round((i+1) * 100/annotations_count), text="Processing annotation:")
6178
batch.append((anno_id, summary))
6279
if len(batch) == BATCH_SIZE:
63-
write_batch()
64-
write_batch()
80+
write_batch(batch)
81+
batch = []
82+
83+
write_batch(batch)
84+
6585
# Check annotations
6686
cursor = db.conn.execute("""
6787
SELECT COUNT(*) FROM Annotations

0 commit comments

Comments
 (0)