Commit fe1bb57

Further optimizing doc2query script
1 parent 25436ff

File tree

1 file changed (+79, -24)

scripts/doc2query-t5.py

Lines changed: 79 additions & 24 deletions
@@ -1,6 +1,7 @@
 from typing import List, Tuple, Dict, Set
 import re
 from pathlib import Path
+from multiprocessing import Process, Queue
 
 from tqdm import tqdm
 import torch
@@ -19,6 +20,44 @@
 OUTPUT_FILE: Path = Path("data/doc2query.tsv")
 
 
+def batched_writer(queue: Queue, output_file: Path, num_samples: int):
+    while True:
+        items_to_write = []
+
+        # Wait till told to capture new queries
+        item = queue.get()
+        if item == "STOP":
+            break
+
+        # Append the first item
+        items_to_write.append(item)
+
+        # Now, try to get other items without blocking if the queue is empty
+        break_after = False
+        while not queue.empty():
+            item = queue.get_nowait()
+            if item == "STOP":
+                break_after = True
+                break
+            items_to_write.append(item)
+
+        # Write all collected items to the output file
+        new_data = {
+            "docid": [docid for docid, _ in items_to_write],
+        }
+
+        for i in range(num_samples):
+            new_data[f"query_{i}"] = [
+                queries[i] if i < len(queries) else "" for _, queries in items_to_write
+            ]
+
+        df = pd.DataFrame(new_data)
+        df.to_csv(output_file, mode="a", sep="\t", index=False, header=False)
+
+        if break_after:
+            break
+
+
 class Doc2Query:
     """
     A class for generating queries from documents using T5.
@@ -45,6 +84,10 @@ class Doc2Query:
         The output dataframe
     pattern : re.Pattern
         The pattern to remove URLs from the input
+    write_queue : Queue
+        The queue to write to
+    writer_process : Process
+        The process to write to the output file
     """
 
     model: T5ForConditionalGeneration
@@ -56,6 +99,8 @@ class Doc2Query:
     input_file: Path
     output_file: Path
     pattern: re.Pattern
+    write_queue: Queue
+    writer_process: Process
 
     def __init__(
         self,
@@ -105,6 +150,12 @@ def __init__(
         assert input_file.exists()
         self.input_file = input_file
         self.pattern = re.compile("^\\s*http\\S+")
+        self.write_queue = Queue()
+        self.writer_process = Process(
+            target=batched_writer,
+            args=(self.write_queue, self.output_file, self.num_samples),
+        )
+        self.writer_process.start()
 
     def add_new_queries(self, new_queries: List[Tuple[str, List[str]]]):
         """
@@ -115,24 +166,8 @@ def add_new_queries(self, new_queries: List[Tuple[str, List[str]]]):
         new_queries : List[Tuple[str, List[str]]]
             The new queries to add: (docid, queries)
         """
-        new_data: Dict[str, List[str]] = {
-            "docid": [],
-            "query_0": [],
-            "query_1": [],
-            "query_2": [],
-        }
-
         for docid, queries in new_queries:
-            assert 1 <= len(queries) <= self.num_samples
-            new_data["docid"].append(docid)
-            for i, query in enumerate(queries):
-                new_data[f"query_{i}"].append(query)
-
-        self.write_output(new_data)
-
-    def write_output(self, new_data: Dict[str, List[str]]):
-        df = pd.DataFrame(new_data)
-        df.to_csv(self.output_file, mode="a", sep="\t", index=False, header=False)
+            self.write_queue.put((docid, queries))
 
     def _already_processed_docids(self) -> Set[int]:
         """Get set of docids that have already been processed."""
@@ -152,7 +187,7 @@ def generate_queries(self):
         )
 
         processed_docids = self._already_processed_docids()
-
+
         skipping_ids = input_df["docid"].nunique()
         input_df = input_df[~input_df["docid"].isin(processed_docids)]
         skipping_ids -= input_df["docid"].nunique()
@@ -189,6 +224,11 @@ def _generate_queries(self, documents: List[Tuple[str, str]]):
             queries = self._doc2query(docs)
             new_queries: List[Tuple[str, List[str]]] = list(zip(docids, queries))
             self.add_new_queries(new_queries)
+        self.write_queue.put("STOP")
+
+    def __del__(self):
+        self.write_queue.put("STOP")
+        self.writer_process.join()
 
     def _doc2query(self, texts: List[str]) -> List[List[str]]:
         """
@@ -231,7 +271,11 @@ def sort_output_file(self):
             print("Output file does not exist.")
             return
 
-        df = pd.read_table(self.output_file, names=["docid"] + [f"query_{i}" for i in range(self.num_samples)], header=None)
+        df = pd.read_table(
+            self.output_file,
+            names=["docid"] + [f"query_{i}" for i in range(self.num_samples)],
+            header=None,
+        )
         df = df.sort_values(by="docid")
         df.to_csv(self.output_file, sep="\t", index=False, header=False)
 
@@ -242,20 +286,31 @@ def verify_output(self):
             print("Output file does not exist. Verification failed!")
             return
 
-        input_df = pd.read_table(self.input_file, names=["docid", "document"], header=None)
-        output_df = pd.read_table(self.output_file, names=["docid"] + [f"query_{i}" for i in range(self.num_samples)], header=None)
-
+        input_df = pd.read_table(
+            self.input_file, names=["docid", "document"], header=None
+        )
+        output_df = pd.read_table(
+            self.output_file,
+            names=["docid"] + [f"query_{i}" for i in range(self.num_samples)],
+            header=None,
+        )
+
         input_docids = set(input_df["docid"].values)
         output_docids = set(output_df["docid"].values)
 
         missing_docids = input_docids - output_docids
 
         if not missing_docids:
-            print("All docids from input_file have corresponding queries in the output_file.")
+            print(
+                "All docids from input_file have corresponding queries in the output_file."
+            )
         else:
-            print(f"Missing queries for {len(missing_docids)} docids in the output_file.")
+            print(
+                f"Missing queries for {len(missing_docids)} docids in the output_file."
+            )
             print("Some of the missing docids are:", list(missing_docids)[:10])
 
+
 if __name__ == "__main__":
     d2q = Doc2Query()
     d2q.generate_queries()
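
The heart of this change: add_new_queries no longer builds a DataFrame and appends to the TSV on every batch; it only enqueues (docid, queries) tuples, while the new batched_writer process drains the queue opportunistically and performs the file appends, exiting when it dequeues the "STOP" sentinel. The sketch below is a minimal driver for that pattern, not code from the repo: it assumes batched_writer (as defined in this diff) and pandas imported as pd are in scope in the same module, and the demo.tsv path and toy docids are placeholders for illustration.

    from multiprocessing import Process, Queue
    from pathlib import Path

    if __name__ == "__main__":
        queue: Queue = Queue()
        # 3 = num_samples, i.e. the maximum number of queries per docid
        writer = Process(target=batched_writer, args=(queue, Path("demo.tsv"), 3))
        writer.start()

        # Producer side: the GPU-bound generation loop only enqueues tuples;
        # DataFrame construction and disk appends happen in the writer process.
        for docid in ("doc-1", "doc-2"):
            queue.put((docid, [f"sample query for {docid}"]))  # placeholder payloads

        queue.put("STOP")  # sentinel: the writer flushes its current batch and exits
        writer.join()      # block until the last rows are on disk

Batching like this amortizes the pd.DataFrame and to_csv overhead across many results and keeps disk I/O off the inference path. Note that the committed version sends the sentinel twice, once at the end of _generate_queries and again in __del__; that is harmless here, since the writer exits on the first "STOP" it dequeues.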
