
Commit 25436ff

Making the script a lot more memory and IO (file write) efficient
1 parent 3a27836 commit 25436ff

File tree

1 file changed: +56, -69 lines changed


scripts/doc2query-t5.py

Lines changed: 56 additions & 69 deletions
@@ -1,6 +1,4 @@
-import signal
-import sys
-from typing import List, Tuple, Dict
+from typing import List, Tuple, Dict, Set
 import re
 from pathlib import Path
 
@@ -57,7 +55,6 @@ class Doc2Query:
     batch_size: int
     input_file: Path
     output_file: Path
-    output_df: pd.DataFrame
     pattern: re.Pattern
 
     def __init__(
@@ -90,13 +87,6 @@ def __init__(
         output_file : Path, optional
             The output file, by default OUTPUT_FILE
         """
-        signal.signal(
-            signal.SIGINT, lambda signo, _: self.__del__() and sys.exit(signo)
-        )
-        signal.signal(
-            signal.SIGTERM, lambda signo, _: self.__del__() and sys.exit(signo)
-        )
-
         self.device = torch.device(device)
         self.model = (
             T5ForConditionalGeneration.from_pretrained(model_name)
@@ -114,16 +104,6 @@ def __init__(
         self.output_file = output_file
         assert input_file.exists()
         self.input_file = input_file
-        if self.output_file.exists():
-            self.output_df = pd.read_table(
-                self.output_file,
-                names=["docid"] + [f"query_{i}" for i in range(num_samples)],
-                header=None,
-            )
-        else:
-            self.output_df = pd.DataFrame(
-                columns=["docid"] + [f"query_{i}" for i in range(num_samples)]
-            )
         self.pattern = re.compile("^\\s*http\\S+")
 
     def add_new_queries(self, new_queries: List[Tuple[str, List[str]]]):
@@ -144,44 +124,24 @@ def add_new_queries(self, new_queries: List[Tuple[str, List[str]]]):
 
         for docid, queries in new_queries:
             assert 1 <= len(queries) <= self.num_samples
+            new_data["docid"].append(docid)
+            for i, query in enumerate(queries):
+                new_data[f"query_{i}"].append(query)
 
-            if len(queries) == self.num_samples:
-                # We do not append the queries if they are already in the output dataframe
-                # remove docid from output_df
-                if docid in self.output_df["docid"].values:
-                    self.output_df = self.output_df[self.output_df["docid"] != docid]
-                new_data["docid"].append(docid)
-                for i, query in enumerate(queries):
-                    new_data[f"query_{i}"].append(query)
-            else:
-                assert docid in self.output_df["docid"].values
-                # We append the queries if they are not in the output dataframe
-                existing_queries: List[str] = []
-                for i in range(self.num_samples):
-                    # fetch the existing queries, if they are not NaN / None / strip() == "" etc.
-                    query = self.output_df[self.output_df["docid"] == docid][
-                        f"query_{i}"
-                    ].values[0]
-                    if query is not None and query.strip() != "":
-                        existing_queries.append(query)
-
-                assert len(existing_queries) + len(queries) == self.num_samples
-
-                # remove docid from output_df
-                self.output_df = self.output_df[self.output_df["docid"] != docid]
-                new_data["docid"].append(docid)
-                for i, query in enumerate(existing_queries + queries):
-                    new_data[f"query_{i}"].append(query)
-
-        self.output_df = pd.concat(
-            [self.output_df, pd.DataFrame(new_data)], ignore_index=True
-        )
+        self.write_output(new_data)
+
+    def write_output(self, new_data: Dict[str, List[str]]):
+        df = pd.DataFrame(new_data)
+        df.to_csv(self.output_file, mode="a", sep="\t", index=False, header=False)
 
-    def write_output(self):
-        self.output_df.to_csv(self.output_file, sep="\t", index=False, header=False)
+    def _already_processed_docids(self) -> Set[int]:
+        """Get set of docids that have already been processed."""
+        if not self.output_file.exists():
+            return set()
 
-    def __del__(self):
-        self.write_output()
+        with open(self.output_file, "r") as f:
+            # Reading only the docid column (1st column) and returning it as a set
+            return set(int(line.split("\t")[0]) for line in f.readlines())
 
     def generate_queries(self):
         """
@@ -190,17 +150,12 @@ def generate_queries(self):
         input_df = pd.read_table(
             self.input_file, names=["docid", "document"], header=None
         )
-        # remove docids that are already in the output dataframe, that do not have any NaN/None/strip() == "" values
-        skipping_ids: int = 0
-        valid_docids = set(self.output_df["docid"])
-        for _, row in self.output_df.iterrows():
-            for i in range(self.num_samples):
-                query = row[f"query_{i}"]
-                if query is None or query.strip() == "":
-                    valid_docids.remove(row["docid"])
-                    skipping_ids += 1
-                    break
-        input_df = input_df[~input_df["docid"].isin(valid_docids)]
+
+        processed_docids = self._already_processed_docids()
+
+        skipping_ids = input_df["docid"].nunique()
+        input_df = input_df[~input_df["docid"].isin(processed_docids)]
+        skipping_ids -= input_df["docid"].nunique()
 
         print(
             f"Processing {len(input_df)} documents (skipping {skipping_ids}). Minimum ID: {input_df['docid'].min()}, maximum ID: {input_df['docid'].max()}"
@@ -234,7 +189,6 @@ def _generate_queries(self, documents: List[Tuple[str, str]]):
         queries = self._doc2query(docs)
         new_queries: List[Tuple[str, List[str]]] = list(zip(docids, queries))
         self.add_new_queries(new_queries)
-        self.write_output()
 
     def _doc2query(self, texts: List[str]) -> List[List[str]]:
         """
@@ -271,6 +225,39 @@ def _doc2query(self, texts: List[str]) -> List[List[str]]:
         rtr = [gens for gens in chunked(outputs, self.num_samples)]
         return rtr
 
+    def sort_output_file(self):
+        """Sort the output file by docid."""
+        if not self.output_file.exists():
+            print("Output file does not exist.")
+            return
+
+        df = pd.read_table(self.output_file, names=["docid"] + [f"query_{i}" for i in range(self.num_samples)], header=None)
+        df = df.sort_values(by="docid")
+        df.to_csv(self.output_file, sep="\t", index=False, header=False)
+
+    def verify_output(self):
+        """Check if all docids from the input_file have queries in the output_file."""
+        # Check if output file exists
+        if not self.output_file.exists():
+            print("Output file does not exist. Verification failed!")
+            return
+
+        input_df = pd.read_table(self.input_file, names=["docid", "document"], header=None)
+        output_df = pd.read_table(self.output_file, names=["docid"] + [f"query_{i}" for i in range(self.num_samples)], header=None)
+
+        input_docids = set(input_df["docid"].values)
+        output_docids = set(output_df["docid"].values)
+
+        missing_docids = input_docids - output_docids
+
+        if not missing_docids:
+            print("All docids from input_file have corresponding queries in the output_file.")
+        else:
+            print(f"Missing queries for {len(missing_docids)} docids in the output_file.")
+            print("Some of the missing docids are:", list(missing_docids)[:10])
 
 if __name__ == "__main__":
-    Doc2Query().generate_queries()
+    d2q = Doc2Query()
+    d2q.generate_queries()
+    d2q.sort_output_file()
+    d2q.verify_output()
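
The core of this change is to stop holding the full output DataFrame in memory and rewriting the whole TSV after every batch, and instead append each batch of rows and resume by reading back only the docid column. Below is a minimal, self-contained sketch of that append-and-resume pattern, mirroring the new write_output and _already_processed_docids methods; it assumes only pandas and the tab-separated docid / query_i layout used by scripts/doc2query-t5.py. The standalone function names, the queries.tsv path, and the NUM_SAMPLES value are illustrative placeholders, not part of the actual script.

from pathlib import Path
from typing import Dict, List, Set

import pandas as pd

OUTPUT_FILE = Path("queries.tsv")  # placeholder path, not the script's default
NUM_SAMPLES = 3  # placeholder; the script takes num_samples as a parameter


def write_output(new_data: Dict[str, List[str]]) -> None:
    """Append one batch of rows to the TSV instead of rewriting the whole file."""
    df = pd.DataFrame(new_data)
    # mode="a" writes only the new rows, so IO per batch is proportional to the
    # batch size rather than to everything written so far.
    df.to_csv(OUTPUT_FILE, mode="a", sep="\t", index=False, header=False)


def already_processed_docids() -> Set[int]:
    """Resume support: collect docids by scanning only the first TSV column."""
    if not OUTPUT_FILE.exists():
        return set()
    with open(OUTPUT_FILE, "r") as f:
        # Iterating line by line keeps memory bounded by the number of docids,
        # not by the full query text that follows each docid.
        return {int(line.split("\t")[0]) for line in f}


if __name__ == "__main__":
    done = already_processed_docids()
    docid = 42  # placeholder document id
    if docid not in done:
        batch = {
            "docid": [docid],
            **{f"query_{i}": [f"example query {i}"] for i in range(NUM_SAMPLES)},
        }
        write_output(batch)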
