Skip to content

Commit 56d9e17

Browse files
authored
Fix Resume RAGulate Queries (#552)
preset the query count (_finished_queries) to the count of completed queries and update tqdm
1 parent a624749 commit 56d9e17

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

libs/ragulate/ragstack_ragulate/pipelines/query_pipeline.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def __init__(
9090
# database.
9191
self._tru.reset_database()
9292

93+
total_existing_queries = 0
9394
for dataset in datasets:
9495
queries, golden_set = dataset.get_queries_and_golden_set()
9596
if self.sample_percent < 1.0:
@@ -105,6 +106,8 @@ def __init__(
105106
app_ids=[dataset.name]
106107
)
107108
existing_queries = existing_records["input"].dropna().tolist()
109+
total_existing_queries += len(existing_queries)
110+
108111
queries = [query for query in queries if query not in existing_queries]
109112

110113
self._queries[dataset.name] = queries
@@ -114,6 +117,9 @@ def __init__(
114117
metric_count = 4
115118
self._total_feedbacks = self._total_queries * metric_count
116119

120+
# Set finished queries count to total existing queries
121+
self._finished_queries = total_existing_queries
122+
117123
def signal_handler(self, _, __):
118124
"""Handle SIGINT signal."""
119125
self._sigint_received = True
@@ -202,7 +208,10 @@ def query(self):
202208
"(r)unning, (w)aiting, (f)ailed, (s)kipped"
203209
)
204210

205-
self._progress = tqdm(total=(self._total_queries + self._total_feedbacks))
211+
self._progress = tqdm(
212+
total=(self._total_queries + self._total_feedbacks),
213+
initial=self._finished_queries,
214+
)
206215

207216
for dataset_name in self._queries:
208217
feedback_functions = [

0 commit comments

Comments
 (0)