import re
from multiprocessing import Process, Queue
from pathlib import Path
from queue import Empty
from typing import Dict, List, Set, Tuple

import torch
from tqdm import tqdm
1920OUTPUT_FILE : Path = Path ("data/doc2query.tsv" )
2021
2122
def batched_writer(queue: Queue, output_file: Path, num_samples: int):
    """Drain (docid, queries) tuples from *queue* and append them to *output_file*.

    Intended to run in a dedicated writer process: it blocks until at least one
    item is available, then opportunistically collects everything else already
    queued so each pandas append writes a whole batch of rows at once.
    Terminates when the sentinel string "STOP" is received.

    Parameters
    ----------
    queue : Queue
        Source of (docid, list-of-query-strings) tuples plus the "STOP" sentinel.
    output_file : Path
        TSV file rows are appended to (no header is ever written).
    num_samples : int
        Number of query_* columns per row; shorter query lists are padded
        with empty strings, longer ones are truncated.
    """
    while True:
        # Block until the producer hands us something.
        item = queue.get()
        if item == "STOP":
            break

        batch = [item]

        # Greedily gather whatever is already queued so we write in batches.
        # FIX: multiprocessing.Queue.empty() is documented as unreliable, so
        # get_nowait() may raise Empty even right after empty() returned False.
        # The original unguarded call could crash the writer process and
        # silently drop every subsequent row — treat Empty as "batch done".
        stop_after_flush = False
        while not queue.empty():
            try:
                item = queue.get_nowait()
            except Empty:
                break
            if item == "STOP":
                stop_after_flush = True
                break
            batch.append(item)

        # Assemble the batch into columns: docid, query_0 .. query_{n-1}.
        new_data = {"docid": [docid for docid, _ in batch]}
        for i in range(num_samples):
            new_data[f"query_{i}"] = [
                queries[i] if i < len(queries) else "" for _, queries in batch
            ]

        # Append without header so successive flushes form one contiguous TSV.
        pd.DataFrame(new_data).to_csv(
            output_file, mode="a", sep="\t", index=False, header=False
        )

        if stop_after_flush:
            break
59+
60+
2261class Doc2Query :
2362 """
2463 A class for generating queries from documents using T5.
@@ -45,6 +84,10 @@ class Doc2Query:
4584 The output dataframe
4685 pattern : re.Pattern
4786 The pattern to remove URLs from the input
87+ write_queue : Queue
88+ Queue of (docid, queries) tuples handed off to the background writer
89+ writer_process : Process
90+ Background process that drains write_queue and appends rows to the output file
4891 """
4992
5093 model : T5ForConditionalGeneration
@@ -56,6 +99,8 @@ class Doc2Query:
5699 input_file : Path
57100 output_file : Path
58101 pattern : re .Pattern
102+ write_queue : Queue
103+ writer_process : Process
59104
60105 def __init__ (
61106 self ,
@@ -105,6 +150,12 @@ def __init__(
105150 assert input_file .exists ()
106151 self .input_file = input_file
107152 self .pattern = re .compile ("^\\ s*http\\ S+" )
153+ self .write_queue = Queue ()
154+ self .writer_process = Process (
155+ target = batched_writer ,
156+ args = (self .write_queue , self .output_file , self .num_samples ),
157+ )
158+ self .writer_process .start ()
108159
def add_new_queries(self, new_queries: List[Tuple[str, List[str]]]):
    """Hand freshly generated queries off to the background writer process.

    Parameters
    ----------
    new_queries : List[Tuple[str, List[str]]]
        (docid, queries) pairs; each is enqueued as-is and later appended
        to the output file by the writer process.
    """
    enqueue = self.write_queue.put
    for pair in new_queries:
        enqueue(pair)
136171
137172 def _already_processed_docids (self ) -> Set [int ]:
138173 """Get set of docids that have already been processed."""
@@ -152,7 +187,7 @@ def generate_queries(self):
152187 )
153188
154189 processed_docids = self ._already_processed_docids ()
155-
190+
156191 skipping_ids = input_df ["docid" ].nunique ()
157192 input_df = input_df [~ input_df ["docid" ].isin (processed_docids )]
158193 skipping_ids -= input_df ["docid" ].nunique ()
@@ -189,6 +224,11 @@ def _generate_queries(self, documents: List[Tuple[str, str]]):
189224 queries = self ._doc2query (docs )
190225 new_queries : List [Tuple [str , List [str ]]] = list (zip (docids , queries ))
191226 self .add_new_queries (new_queries )
227+ self .write_queue .put ("STOP" )
228+
def __del__(self):
    """Shut down the background writer: send the STOP sentinel and wait for it.

    FIX: the original accessed self.write_queue / self.writer_process
    unconditionally, but __del__ can run on a partially constructed instance
    (e.g. if __init__ raised on the input-file existence assert before the
    queue and process were created), which raised AttributeError during
    finalization.  getattr guards make teardown safe in that case.
    """
    write_queue = getattr(self, "write_queue", None)
    writer_process = getattr(self, "writer_process", None)
    if write_queue is not None:
        # A duplicate STOP (one may already have been sent by
        # _generate_queries) is harmless: the writer exits on the first.
        write_queue.put("STOP")
    if writer_process is not None:
        writer_process.join()
192232
193233 def _doc2query (self , texts : List [str ]) -> List [List [str ]]:
194234 """
@@ -231,7 +271,11 @@ def sort_output_file(self):
231271 print ("Output file does not exist." )
232272 return
233273
234- df = pd .read_table (self .output_file , names = ["docid" ] + [f"query_{ i } " for i in range (self .num_samples )], header = None )
274+ df = pd .read_table (
275+ self .output_file ,
276+ names = ["docid" ] + [f"query_{ i } " for i in range (self .num_samples )],
277+ header = None ,
278+ )
235279 df = df .sort_values (by = "docid" )
236280 df .to_csv (self .output_file , sep = "\t " , index = False , header = False )
237281
@@ -242,20 +286,31 @@ def verify_output(self):
242286 print ("Output file does not exist. Verification failed!" )
243287 return
244288
245- input_df = pd .read_table (self .input_file , names = ["docid" , "document" ], header = None )
246- output_df = pd .read_table (self .output_file , names = ["docid" ] + [f"query_{ i } " for i in range (self .num_samples )], header = None )
247-
289+ input_df = pd .read_table (
290+ self .input_file , names = ["docid" , "document" ], header = None
291+ )
292+ output_df = pd .read_table (
293+ self .output_file ,
294+ names = ["docid" ] + [f"query_{ i } " for i in range (self .num_samples )],
295+ header = None ,
296+ )
297+
248298 input_docids = set (input_df ["docid" ].values )
249299 output_docids = set (output_df ["docid" ].values )
250300
251301 missing_docids = input_docids - output_docids
252302
253303 if not missing_docids :
254- print ("All docids from input_file have corresponding queries in the output_file." )
304+ print (
305+ "All docids from input_file have corresponding queries in the output_file."
306+ )
255307 else :
256- print (f"Missing queries for { len (missing_docids )} docids in the output_file." )
308+ print (
309+ f"Missing queries for { len (missing_docids )} docids in the output_file."
310+ )
257311 print ("Some of the missing docids are:" , list (missing_docids )[:10 ])
258312
313+
if __name__ == "__main__":
    # Script entry point: build the generator and produce queries for the
    # configured input file.
    doc2query = Doc2Query()
    doc2query.generate_queries()
0 commit comments