@@ -1,6 +1,4 @@
-import signal
-import sys
-from typing import List, Tuple, Dict
+from typing import List, Tuple, Dict, Set
 import re
 from pathlib import Path
 
@@ -57,7 +55,6 @@ class Doc2Query:
     batch_size: int
     input_file: Path
     output_file: Path
-    output_df: pd.DataFrame
    pattern: re.Pattern
 
     def __init__(
@@ -90,13 +87,6 @@ def __init__(
         output_file : Path, optional
             The output file, by default OUTPUT_FILE
         """
-        signal.signal(
-            signal.SIGINT, lambda signo, _: self.__del__() and sys.exit(signo)
-        )
-        signal.signal(
-            signal.SIGTERM, lambda signo, _: self.__del__() and sys.exit(signo)
-        )
-
         self.device = torch.device(device)
         self.model = (
             T5ForConditionalGeneration.from_pretrained(model_name)
@@ -114,16 +104,6 @@ def __init__(
         self.output_file = output_file
         assert input_file.exists()
         self.input_file = input_file
-        if self.output_file.exists():
-            self.output_df = pd.read_table(
-                self.output_file,
-                names=["docid"] + [f"query_{i}" for i in range(num_samples)],
-                header=None,
-            )
-        else:
-            self.output_df = pd.DataFrame(
-                columns=["docid"] + [f"query_{i}" for i in range(num_samples)]
-            )
         self.pattern = re.compile("^\\s*http\\S+")
 
     def add_new_queries(self, new_queries: List[Tuple[str, List[str]]]):
@@ -144,44 +124,24 @@ def add_new_queries(self, new_queries: List[Tuple[str, List[str]]]):
 
         for docid, queries in new_queries:
             assert 1 <= len(queries) <= self.num_samples
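+            # Assumes each docid arrives with exactly num_samples queries; fewer would
+            # leave the columns ragged when write_output builds the DataFrame.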
+            new_data["docid"].append(docid)
+            for i, query in enumerate(queries):
+                new_data[f"query_{i}"].append(query)
 
-            if len(queries) == self.num_samples:
-                # We do not append the queries if they are already in the output dataframe
-                # remove docid from output_df
-                if docid in self.output_df["docid"].values:
-                    self.output_df = self.output_df[self.output_df["docid"] != docid]
-                new_data["docid"].append(docid)
-                for i, query in enumerate(queries):
-                    new_data[f"query_{i}"].append(query)
-            else:
-                assert docid in self.output_df["docid"].values
-                # We append the queries if they are not in the output dataframe
-                existing_queries: List[str] = []
-                for i in range(self.num_samples):
-                    # fetch the existing queries, if they are not NaN / None / strip() == "" etc.
-                    query = self.output_df[self.output_df["docid"] == docid][
-                        f"query_{i}"
-                    ].values[0]
-                    if query is not None and query.strip() != "":
-                        existing_queries.append(query)
-
-                assert len(existing_queries) + len(queries) == self.num_samples
-
-                # remove docid from output_df
-                self.output_df = self.output_df[self.output_df["docid"] != docid]
-                new_data["docid"].append(docid)
-                for i, query in enumerate(existing_queries + queries):
-                    new_data[f"query_{i}"].append(query)
-
-        self.output_df = pd.concat(
-            [self.output_df, pd.DataFrame(new_data)], ignore_index=True
-        )
+        self.write_output(new_data)
+
+    def write_output(self, new_data: Dict[str, List[str]]):
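+        # Append-mode write: each batch lands on disk immediately, so an
+        # interrupted run keeps everything written so far.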
+        df = pd.DataFrame(new_data)
+        df.to_csv(self.output_file, mode="a", sep="\t", index=False, header=False)
 
-    def write_output(self):
-        self.output_df.to_csv(self.output_file, sep="\t", index=False, header=False)
+    def _already_processed_docids(self) -> Set[int]:
+        """Get set of docids that have already been processed."""
+        if not self.output_file.exists():
+            return set()
 
-    def __del__(self):
-        self.write_output()
+        with open(self.output_file, "r") as f:
+            # Read only the docid column (first tab-separated field) and return it as a set
+            return set(int(line.split("\t")[0]) for line in f)
 
     def generate_queries(self):
         """
@@ -190,17 +150,12 @@ def generate_queries(self):
         input_df = pd.read_table(
             self.input_file, names=["docid", "document"], header=None
         )
-        # remove docids that are already in the output dataframe, that do not have any NaN/None/strip() == "" values
-        skipping_ids: int = 0
-        valid_docids = set(self.output_df["docid"])
-        for _, row in self.output_df.iterrows():
-            for i in range(self.num_samples):
-                query = row[f"query_{i}"]
-                if query is None or query.strip() == "":
-                    valid_docids.remove(row["docid"])
-                    skipping_ids += 1
-                    break
-        input_df = input_df[~input_df["docid"].isin(valid_docids)]
+
+        processed_docids = self._already_processed_docids()
+
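+        # Derive the skip count from the drop in unique docids after filtering.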
+        skipping_ids = input_df["docid"].nunique()
+        input_df = input_df[~input_df["docid"].isin(processed_docids)]
+        skipping_ids -= input_df["docid"].nunique()
 
         print(
             f"Processing {len(input_df)} documents (skipping {skipping_ids}). Minimum ID: {input_df['docid'].min()}, maximum ID: {input_df['docid'].max()}"
@@ -234,7 +189,6 @@ def _generate_queries(self, documents: List[Tuple[str, str]]):
         queries = self._doc2query(docs)
         new_queries: List[Tuple[str, List[str]]] = list(zip(docids, queries))
         self.add_new_queries(new_queries)
-        self.write_output()
 
     def _doc2query(self, texts: List[str]) -> List[List[str]]:
         """
@@ -271,6 +225,39 @@ def _doc2query(self, texts: List[str]) -> List[List[str]]:
         rtr = [gens for gens in chunked(outputs, self.num_samples)]
         return rtr
 
+    def sort_output_file(self):
+        """Sort the output file by docid."""
+        if not self.output_file.exists():
+            print("Output file does not exist.")
+            return
+
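+        # Note: this reads the whole TSV into memory, which is assumed
+        # acceptable for the corpus sizes this script targets.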
+        df = pd.read_table(self.output_file, names=["docid"] + [f"query_{i}" for i in range(self.num_samples)], header=None)
+        df = df.sort_values(by="docid")
+        df.to_csv(self.output_file, sep="\t", index=False, header=False)
+
+    def verify_output(self):
+        """Check if all docids from the input_file have queries in the output_file."""
+        # Check if output file exists
+        if not self.output_file.exists():
+            print("Output file does not exist. Verification failed!")
+            return
+
+        input_df = pd.read_table(self.input_file, names=["docid", "document"], header=None)
+        output_df = pd.read_table(self.output_file, names=["docid"] + [f"query_{i}" for i in range(self.num_samples)], header=None)
+
+        input_docids = set(input_df["docid"].values)
+        output_docids = set(output_df["docid"].values)
+
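+        # Presence check only: this does not confirm that each row carries
+        # num_samples non-empty queries.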
+        missing_docids = input_docids - output_docids
+
+        if not missing_docids:
+            print("All docids from input_file have corresponding queries in the output_file.")
+        else:
+            print(f"Missing queries for {len(missing_docids)} docids in the output_file.")
+            print("Some of the missing docids are:", list(missing_docids)[:10])
 
 if __name__ == "__main__":
-    Doc2Query().generate_queries()
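+    # Generate queries, then sort the output by docid and verify coverage.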
+    d2q = Doc2Query()
+    d2q.generate_queries()
+    d2q.sort_output_file()
+    d2q.verify_output()